Initial commit: Pixel AI comic/video creation platform
- FastAPI backend with SQLModel, Alembic migrations, AgentScope agents
- Next.js 15 frontend with React 19, Tailwind, Zustand, React Flow
- Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.)
- All HTTP clients migrated from sync requests to async httpx
- Admin-managed API keys via environment variables
- SSRF vulnerability fixed in ensure_url()
.gitignore (vendored, new file, 41 lines)
@@ -0,0 +1,41 @@
# Python
__pycache__/
*.pyc
.venv/
*.egg-info/
.pytest_cache/

# Database
*.db
*.sqlite
*.sqlite3

# Node
node_modules/
.next/
tsconfig.tsbuildinfo

# Environment & Secrets
.env
.env.local
.env.*.local
backend/src/config/storage.json

# IDE
.vscode/
.idea/

# OS
.DS_Store

# Logs
logs/
*.log

# Build
dist/
build/

# Claude Code
.omc/
.claude/

CLAUDE.md (new file, 125 lines)
@@ -0,0 +1,125 @@

# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Pixel is an AI-powered platform for creating comics and videos from scripts. It uses multi-agent systems (AgentScope) and multiple AI providers for image/video/audio generation.

## Repository Structure

- `backend/` — FastAPI (Python 3.12+), SQLModel/SQLAlchemy, Alembic migrations, AgentScope agents
- `frontend/` — Next.js 15 (App Router), React 19, Tailwind CSS, Zustand, @xyflow/react (canvas), TanStack Query

## Common Commands

### Backend

```bash
cd backend
uv sync                                           # Install dependencies
uv run uvicorn src.main:app --reload --port 8000  # Start dev server
./start.sh                                        # Alternative start script
alembic upgrade head                              # Apply DB migrations
alembic revision --autogenerate -m "desc"         # Create migration

# Tests (pytest with asyncio_mode=auto)
pytest                       # All tests
pytest tests/test_models.py  # Single test file
pytest -k "test_name"        # Single test by name
pytest --cov=src             # With coverage

# Code quality
black src/ && isort src/     # Format
flake8 src/ && mypy src/     # Lint
```

### Frontend

```bash
cd frontend
pnpm install        # Install dependencies
pnpm dev            # Start dev server (localhost:3000)
pnpm build          # Production build
pnpm lint           # ESLint
pnpm type-check     # tsc --noEmit
pnpm format         # ESLint --fix

# Tests
pnpm test           # Vitest (watch mode)
pnpm test:run       # Vitest (single run)
pnpm test:coverage  # With coverage
pnpm e2e            # Playwright E2E tests

# Regenerate API types after backend changes
pnpm gen:api
```

### Docker

```bash
docker-compose up -d  # Start all services (backend, frontend, redis, postgres)
docker-compose down   # Stop services
```

## Architecture

### Backend Three-Layer Pattern

All backend code follows: **API Layer** → **Service Layer** → **Repository Layer**

- **API** (`src/api/`): FastAPI routers, request validation, response formatting. All endpoints prefixed with `/api/v1`.
- **Services** (`src/services/`): Business logic. Key services: `project_service`, `task_service`, `storyboard_service`, `storage_service`.
- **Repositories** (`src/repositories/`): Data access via SQLModel. Use `AsyncBaseRepository` (from `base_async.py`) for new code; the sync `BaseRepository` is deprecated.
- **Models** (`src/models/`): `entities.py` has SQLModel table definitions; `schemas/` has Pydantic request/response schemas.
- **Mappers** (`src/mappers/`): Convert between DB entities and API schemas.
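As a rough illustration of this layering (every name here is hypothetical and storage is in-memory rather than SQLModel), the flow from endpoint to repository looks like:

```python
# Minimal sketch of the three-layer flow: API -> Service -> Repository.
# All names are illustrative, not the project's actual classes.
from dataclasses import dataclass


@dataclass
class Project:
    id: int
    name: str


class ProjectRepository:
    """Repository layer: owns data access and nothing else."""

    def __init__(self) -> None:
        self._rows: dict[int, Project] = {}
        self._next_id = 1

    def add(self, name: str) -> Project:
        project = Project(id=self._next_id, name=name)
        self._rows[project.id] = project
        self._next_id += 1
        return project


class ProjectService:
    """Service layer: business rules, delegating persistence to the repository."""

    def __init__(self, repo: ProjectRepository) -> None:
        self._repo = repo

    def create_project(self, name: str) -> Project:
        if not name.strip():
            raise ValueError("project name must not be empty")
        return self._repo.add(name.strip())


# The API layer (a FastAPI router in the real codebase) would only validate
# the request and call the service:
def create_project_endpoint(payload: dict, service: ProjectService) -> dict:
    project = service.create_project(payload["name"])
    return {"id": project.id, "name": project.name}
```

The point of the split is that each layer can be tested and replaced independently: the service never touches storage directly, and the endpoint never contains business rules.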

### AI Provider System

Providers are registered via JSON config files in `src/config/services/`. The `ModelRegistry` (in `src/services/provider/registry.py`) is a thread-safe factory registry. Model IDs use a composite format: `provider/model_key` (e.g., `dashscope/qwen-image`).

Provider implementations live in `src/services/provider/<provider_name>/`; each provider module has an `image.py`, `video.py`, etc. Provider types: dashscope, volcengine, kling, minimax, modelscope, midjourney, openai, google.

The `ModelRegistry.get()` method creates a fresh instance on each call (factory pattern) to avoid shared state, and supports variant resolution and fallback models.
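A minimal sketch of such a factory registry follows; the real `ModelRegistry` API is richer (variants, fallbacks, JSON-driven registration), and the `register` method and locking shown here are assumptions:

```python
# Hypothetical sketch of a thread-safe factory registry keyed by
# composite model IDs of the form "provider/model_key".
import threading
from typing import Callable


class ModelRegistry:
    def __init__(self) -> None:
        self._factories: dict[str, Callable[[], object]] = {}
        self._lock = threading.Lock()

    def register(self, model_id: str, factory: Callable[[], object]) -> None:
        with self._lock:
            self._factories[model_id] = factory

    def get(self, model_id: str) -> object:
        # Validate the composite format before lookup.
        provider, _, model_key = model_id.partition("/")
        if not provider or not model_key:
            raise ValueError(f"expected 'provider/model_key', got {model_id!r}")
        with self._lock:
            factory = self._factories[model_id]
        return factory()  # fresh instance per call: no shared state
```

Returning `factory()` rather than a cached instance is what makes the registry safe to use from concurrent tasks: callers never share a client object.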

### Task Management

The `UnifiedTaskManager` (`src/services/task_manager/`) handles async generation tasks with priority queues, concurrency control, exponential-backoff retries, and Prometheus metrics. Tasks are persisted in `TaskDB` with statuses: pending, processing, success, failed, timeout, retrying.
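The retry delays of an exponential-backoff scheme can be sketched as follows; the base delay and cap here are illustrative values, not the task manager's actual configuration:

```python
# Exponential backoff: delay doubles per attempt, clamped to a cap.
def backoff_delays(attempts: int, base: float = 1.0, cap: float = 60.0) -> list[float]:
    """Delay (in seconds) to wait before each retry: base * 2**n, capped."""
    return [min(base * (2 ** n), cap) for n in range(attempts)]
```

With `base=1.0` and `cap=60.0`, seven attempts wait 1, 2, 4, 8, 16, 32, then 60 seconds; production schemes often add jitter on top to avoid thundering herds.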

### Agent Engine

`src/services/agent_engine/` contains an AgentScope-based multi-agent system for script analysis and creative workflows. Skills are organized under `skills/` into film-production (storyboarding, cinematography, screenwriting, sound design) and general (canvas workflow, project management, creative generation) categories.

### Frontend Architecture

- **App Router** (`src/app/`): Pages organized by feature — admin, canvas, login, register, etc.
- **State**: Zustand stores in `src/lib/store/` — `canvasStore.ts` (with slices: nodes, edges, groups, history, selection, persistence), `authStore.ts`, `modelStore.ts`, `uiStore.ts`.
- **API Client**: Auto-generated from the OpenAPI spec into `src/lib/api/`. Services in `src/lib/api/services/` wrap the generated client. The auth token is resolved from localStorage via `src/lib/client.ts`.
- **Canvas**: React Flow based infinite canvas (`src/components/canvas/`). Canvas state is persisted via `persistenceSlice`.
- **UI Components**: Radix UI primitives in `src/components/ui/`, styled with Tailwind + class-variance-authority.

### API Proxy

In development, Next.js rewrites (`next.config.mjs`) proxy `/api/*`, `/files/*`, `/uploads/*`, `/chat/*`, and `/health` to the backend at `API_URL` (default `http://localhost:8000`). In production, set `NEXT_PUBLIC_API_URL` for direct browser-to-backend calls.

### Database

- **Dev**: SQLite (default, no config needed)
- **Prod**: PostgreSQL 15+ via `DATABASE_URL`
- **ORM**: SQLModel (Pydantic v2 + SQLAlchemy)
- **Migrations**: Alembic in `backend/alembic/`
- **Cache**: Optional Redis for caching and rate limiting

### Middleware Stack (backend)

Applied in order: error handler → request tracking → response formatter → security headers → security → rate limiter → performance → metrics → tracing → GZip → CORS.
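That ordering can be illustrated with a plain-Python sketch (the real stack uses Starlette middleware; `make_tagger` and `build_app` are illustrative helpers): each middleware wraps the next handler, so the first entry in the chain sees the request first and the response last.

```python
# Plain-Python sketch of middleware ordering as nested wrappers.
from typing import Callable

Handler = Callable[[str], str]


def make_tagger(tag: str) -> Callable[[Handler], Handler]:
    """A toy middleware that tags the response it passes back up."""
    def middleware(next_handler: Handler) -> Handler:
        def handle(request: str) -> str:
            return f"{tag}({next_handler(request)})"
        return handle
    return middleware


def build_app(endpoint: Handler, chain: list[Callable[[Handler], Handler]]) -> Handler:
    handler = endpoint
    for mw in reversed(chain):  # wrap inside-out so chain[0] is outermost
        handler = mw(handler)
    return handler
```

Building with `chain=[error_handler, gzip]` yields `error_handler(gzip(endpoint))`: the error handler is outermost, matching "error handler first" in the list above.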

## Key Conventions

- Backend uses `uv` as its package manager, not pip/poetry
- Frontend uses `pnpm`, not npm/yarn
- All generation endpoints return a `task_id`; poll `GET /api/v1/tasks/{task_id}` for results
- Model IDs must use the composite format `provider/model_key` — never pass a separate `provider` parameter
- JSON fields in SQLite use `sa_column=Column(JSON)` on SQLModel fields
- Timestamps are stored as Unix floats (not datetime objects), with `TimestampMixin` providing ISO conversion
- Soft delete pattern: set a `deleted_at` field rather than physically deleting rows

README.md (new file, 343 lines)
@@ -0,0 +1,343 @@

# Pixel - AI Video Creation Platform

[中文文档](README_zh-CN.md)

Pixel is an intelligent platform for creating comics and videos from scripts using AI. It streamlines the workflow from scriptwriting to asset management, storyboard generation, and final video production, leveraging Multi-Agent Systems and advanced generative models.

## ✨ Features

- **Intelligent Script Analysis**: Automatically parse scripts to identify characters, scenes, and props using LLMs.
- **Agent-based Workflow**: Utilizes specialized agents (powered by AgentScope) for deep script understanding and breakdown.
- **Multi-Provider AI Support**:
  - **LLM**: DashScope (Qwen), Google (Gemini), VolcEngine (Doubao).
  - **Image Generation**: Flux 1.1 Pro, Wanx, Kolors (ModelScope).
  - **Video Generation**: Kling 1.5, Hailuo (MiniMax), CogVideoX, Wan 2.1.
- **Asset Management**: Centralized "Material Center" to manage and edit creative assets.
- **AI Storyboarding**: Generate visual storyboards from script descriptions using AI image generation with style-consistency control.
- **Video Generation & Editing**:
  - Transform static storyboards into dynamic videos.
  - **Fine-grained Control**: First/last-frame control, camera motion (zoom, pan, tilt), and Motion Bucket settings.
- **Infinite Canvas**: A node-based free creation workspace (powered by React Flow) supporting multi-selection, smooth zooming, and flexible node connections.
- **Project Management**: Organize your creative works in a structured workspace with support for Episodes and Scenes.

## 🏗 Architecture

The project is structured as a monorepo with a clean three-layer architecture:

- **`frontend/`**: A **Next.js 15** (App Router) application using **React 19**, Tailwind CSS, and `@xyflow/react` for the canvas interface.
- **`backend/`**: A **FastAPI** service with a three-layer architecture (API Layer, Service Layer, Repository Layer). It uses **AgentScope** for agent orchestration and supports multiple model providers via a plugin system.

## 📚 Documentation

Comprehensive documentation is available:

- **[docs/API.md](docs/API.md)**: Complete API reference with examples
- **[docs/development-guide.md](docs/development-guide.md)**: Development best practices and guidelines
- **[docs/FRONTEND_OPTIMIZATION.md](docs/FRONTEND_OPTIMIZATION.md)**: Frontend optimization strategies

## 🚀 Quick Start

### Prerequisites

- **Node.js**: v20 or higher (required for Next.js 15)
- **Python**: v3.12 or higher
- **Package Manager**: `pnpm` (frontend), `uv` (backend, recommended)
- **Redis**: v7+ (optional but recommended for caching)
- **PostgreSQL**: v15+ (for production; SQLite for development)
- **API Keys**: Aliyun DashScope, VolcEngine, or Google AI Studio keys, depending on the models used

### Installation

#### 1. Clone the Repository

```bash
git clone <repository-url>
cd pixel
```

#### 2. Backend Setup

```bash
cd backend

# Install uv if you haven't already
pip install uv

# Install dependencies
uv sync

# Configure environment variables
cp .env.example .env
# Edit .env with your API keys

# Run database migrations
alembic upgrade head

# Start the backend server
./start.sh
# Or manually: uv run uvicorn src.main:app --reload --port 8000
```

The backend will be available at http://localhost:8000.

#### 3. Frontend Setup

```bash
cd frontend

# Install dependencies
pnpm install

# Start the development server
pnpm dev
```

The frontend will be available at http://localhost:3000.

### Environment Configuration

Create a `.env` file in the `backend/` directory:

```env
# AI Providers
DASHSCOPE_API_KEY=your_dashscope_key
VOLCENGINE_ACCESS_KEY=your_volcengine_key
VOLCENGINE_SECRET_KEY=your_volcengine_secret
GOOGLE_API_KEY=your_google_key
KLING_API_KEY=your_kling_key

# Storage
OSS_ACCESS_KEY_ID=your_oss_key
OSS_ACCESS_KEY_SECRET=your_oss_secret

# Database (optional - defaults to SQLite)
DATABASE_URL=postgresql://user:pass@localhost/pixel

# Redis (optional but recommended)
REDIS_URL=redis://localhost:6379
REDIS_ENABLED=1

# CORS
# Production: set explicit comma-separated origins
CORS_ALLOWED_ORIGINS=https://your-app.example.com
# Development fallback origins
CORS_DEV_ALLOWED_ORIGINS=http://localhost:3000,http://127.0.0.1:3000

# Monitoring
ENABLE_METRICS=true
LOG_LEVEL=INFO
```
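The README does not show how the backend consumes the two CORS variables; a plausible sketch of resolving a comma-separated origin list with a dev fallback (function name and defaults are assumptions, not the actual config code) is:

```python
# Hypothetical sketch: prefer explicit production origins, fall back to
# the dev origin list, and split on commas with whitespace tolerated.
def resolve_cors_origins(env: dict[str, str]) -> list[str]:
    raw = env.get("CORS_ALLOWED_ORIGINS") or env.get(
        "CORS_DEV_ALLOWED_ORIGINS", "http://localhost:3000"
    )
    return [origin.strip() for origin in raw.split(",") if origin.strip()]
```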

### Docker Deployment (Optional)

```bash
# Build and start all services
docker-compose up -d

# View logs
docker-compose logs -f

# Stop services
docker-compose down
```

## 🔌 API Documentation

The backend provides RESTful APIs with comprehensive documentation:

- **[API Documentation](docs/API.md)**: Complete API reference with examples
- **Interactive Docs**: http://localhost:8000/docs (Swagger UI)
- **ReDoc**: http://localhost:8000/redoc
- **OpenAPI Spec**: http://localhost:8000/openapi.json

### Core Endpoints

* **Image Generation**: `POST /api/v1/generations/image`
* **Video Generation**: `POST /api/v1/generations/video`
* **Script Analysis**: `POST /api/v1/script/analyze`
* **Task Status**: `GET /api/v1/tasks/{task_id}`
* **Project Management**: `/api/v1/projects/*`
* **Canvas Operations**: `/api/v1/canvas/*`
All generation endpoints support an `extra_params` field to pass model-specific arguments directly to the underlying SDK.
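The submit-then-poll pattern these endpoints imply can be sketched as follows. The HTTP call is injected as `fetch_status` so the loop itself is transport-agnostic; the helper name, interval, and poll limit are illustrative, and "pending", "processing", and "retrying" are treated as still-running states:

```python
# Sketch of polling GET /api/v1/tasks/{task_id} until a task finishes.
import time
from typing import Callable


def wait_for_task(
    fetch_status: Callable[[str], dict],
    task_id: str,
    interval: float = 2.0,
    max_polls: int = 150,
) -> dict:
    """Poll until the task leaves a running state, or give up."""
    for _ in range(max_polls):
        task = fetch_status(task_id)
        if task["status"] not in ("pending", "processing", "retrying"):
            return task  # success, failed, or timeout: caller inspects it
        time.sleep(interval)
    raise TimeoutError(f"task {task_id} still running after {max_polls} polls")
```

In practice `fetch_status` would be a thin wrapper around an HTTP client hitting `GET /api/v1/tasks/{task_id}`; injecting it also makes the loop trivial to unit-test with a stub.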

## 🧪 Development

### Running Tests

**Backend:**
```bash
cd backend
pytest                            # Run all tests
pytest tests/test_integration.py  # Run integration tests
pytest --cov=src                  # Run with coverage
```

**Frontend:**
```bash
cd frontend
pnpm test           # Run all tests
pnpm test:watch     # Run in watch mode
pnpm test:coverage  # Run with coverage
```

### Code Quality

**Backend:**
```bash
# Format code
black src/
isort src/

# Lint
flake8 src/
mypy src/
```

**Frontend:**
```bash
# Lint
pnpm lint

# Type check
pnpm type-check

# Format
pnpm format
```

### Database Migrations

```bash
cd backend

# Create a new migration
alembic revision --autogenerate -m "description"

# Apply migrations
alembic upgrade head

# Rollback one migration
alembic downgrade -1
```

### API Type Generation

When the backend API changes, regenerate the frontend types:

```bash
cd frontend
pnpm gen:api
```

## 📊 Monitoring

### Health Check

```bash
curl http://localhost:8000/health
```

### Metrics (Prometheus)

```bash
curl http://localhost:8000/metrics
```

### Logs

Backend logs are in JSON format for easy parsing:

```bash
tail -f backend/logs/app.log | jq
```

## 🤝 Contributing

1. Fork the repository
2. Create a feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request

### Development Guidelines

- Follow the three-layer architecture (API, Service, Repository)
- Write tests for new features
- Update documentation
- Follow code style guidelines
- Use conventional commits

## 🐛 Troubleshooting

### Backend Issues

**Redis Connection Error:**
```bash
# Check if Redis is running
redis-cli ping
# Should return: PONG
```

**Database Connection Error:**
```bash
# Check the database connection
psql -h localhost -U user -d pixel
```

**Import Errors:**
```bash
# Reinstall dependencies
uv sync --reinstall
```

### Frontend Issues

**Module Not Found:**
```bash
# Clear cache and reinstall
rm -rf node_modules .next
pnpm install
```

**API Type Mismatch:**
```bash
# Regenerate API types
pnpm gen:api
```

## 📦 Project Structure

```
pixel/
├── backend/               # FastAPI backend
│   ├── src/
│   │   ├── controllers/   # API Layer (HTTP handlers)
│   │   ├── services/      # Service Layer (business logic)
│   │   ├── repositories/  # Repository Layer (data access)
│   │   ├── models/        # Data models (entities & schemas)
│   │   ├── middlewares/   # API middlewares
│   │   └── utils/         # Utilities
│   ├── tests/             # Test suite
│   └── alembic/           # Database migrations
├── frontend/              # Next.js frontend
│   ├── src/
│   │   ├── app/           # Next.js pages (App Router)
│   │   ├── components/    # React components
│   │   ├── lib/           # Utilities and services
│   │   └── store/         # State management (Zustand)
│   └── tests/             # Test suite
├── docs/                  # Documentation
├── ARCHITECTURE.md        # Architecture documentation
└── docker-compose.yml     # Docker configuration
```

## 📄 License

MIT

README_zh-CN.md (new file, 336 lines)
@@ -0,0 +1,336 @@

# Pixel - AI 视频创作平台

[English Documentation](README.md)

Pixel 是一个智能平台,利用 AI 从剧本创作漫画和视频。它简化了从剧本编写到素材管理、分镜生成以及最终视频制作的整个工作流,利用多智能体系统(Multi-Agent Systems)和先进的生成式模型提供强大的创作支持。

## ✨ 功能特性

- **智能剧本分析**:自动解析剧本,识别角色、场景和道具。
- **Agent 工作流**:利用基于 AgentScope 的专用智能体进行深度的剧本理解和拆解。
- **多模型 AI 支持**:
  - **LLM**: DashScope (通义千问), Google (Gemini), VolcEngine (豆包)。
  - **生图**: Flux 1.1 Pro, 通义万相 (Wanx), 可图 (Kolors)。
  - **生视频**: 可灵 (Kling 1.5), 海螺 (Hailuo/MiniMax), CogVideoX, 通义万相 2.1。
- **素材管理**:集中的“素材中心”,用于管理和编辑创意素材。
- **AI 分镜**:利用 AI 图像生成技术,根据剧本描述生成可视化分镜,并支持风格一致性控制。
- **视频生成与编辑**:
  - 将静态分镜转化为动态视频。
  - **精细控制**:支持首/尾帧控制、运镜控制(推拉摇移)、以及 Motion Bucket 参数调节。
- **无限画布**:基于节点的自由创作工作区(由 React Flow 驱动),支持多选拖拽、丝滑缩放和灵活的节点连接。
- **项目管理**:在结构化的工作区中组织您的创意作品,支持集(Episode)和场(Scene)管理。

## 🏗 架构

本项目采用 Monorepo 结构,后端采用清晰的三层架构:

- **`frontend/`**: 基于 **Next.js 15** (App Router) 的前端应用,使用 **React 19**、Tailwind CSS 和 `@xyflow/react` 构建画布界面。
- **`backend/`**: 基于 **FastAPI** 的后端服务,采用三层架构(API层、服务层、仓储层)。使用 **AgentScope** 进行智能体编排,并通过插件系统支持多种 AI 模型提供商。

## 📚 文档

完整的技术文档:

- **[ARCHITECTURE.md](ARCHITECTURE.md)**: 完整的系统架构、设计决策和技术栈
- **[backend/README.md](backend/README.md)**: 后端设置、API文档和开发指南
- **[frontend/README.md](frontend/README.md)**: 前端设置、组件结构和开发指南

## 🚀 快速开始

### 前置要求

- **Node.js**: v20 或更高版本(Next.js 15 必需)
- **Python**: v3.12 或更高版本
- **包管理器**: `pnpm`(前端),`uv`(后端,推荐)
- **Redis**: v7+(可选,但推荐用于缓存)
- **PostgreSQL**: v15+(生产环境;开发环境使用 SQLite)
- **API Keys**: 根据使用的模型准备阿里云 DashScope、火山引擎 VolcEngine 或 Google AI Studio 的 Key

### 安装步骤

#### 1. 克隆仓库

```bash
git clone <repository-url>
cd pixel
```

#### 2. 后端设置

```bash
cd backend

# 安装 uv(如果尚未安装)
pip install uv

# 安装依赖
uv sync

# 配置环境变量
cp .env.example .env
# 编辑 .env 文件,填入您的 API Keys

# 运行数据库迁移
alembic upgrade head

# 启动后端服务器
./start.sh
# 或手动运行: uv run uvicorn src.main:app --reload --port 8000
```

后端将在 http://localhost:8000 运行。

#### 3. 前端设置

```bash
cd frontend

# 安装依赖
pnpm install

# 启动开发服务器
pnpm dev
```

前端将在 http://localhost:3000 运行。

### 环境配置

在 `backend/` 目录中创建 `.env` 文件:

```env
# AI 提供商
DASHSCOPE_API_KEY=your_dashscope_key
VOLCENGINE_ACCESS_KEY=your_volcengine_key
VOLCENGINE_SECRET_KEY=your_volcengine_secret
GOOGLE_API_KEY=your_google_key
KLING_API_KEY=your_kling_key

# 存储
OSS_ACCESS_KEY_ID=your_oss_key
OSS_ACCESS_KEY_SECRET=your_oss_secret

# 数据库(可选 - 默认使用 SQLite)
DATABASE_URL=postgresql://user:pass@localhost/pixel

# Redis(可选但推荐)
REDIS_URL=redis://localhost:6379

# 监控
ENABLE_METRICS=true
LOG_LEVEL=INFO
```

### Docker 部署(可选)

```bash
# 构建并启动所有服务
docker-compose up -d

# 查看日志
docker-compose logs -f

# 停止服务
docker-compose down
```

## 🔌 API 文档

后端提供完整的 RESTful API 文档:

- **[API 文档](docs/API.md)**: 完整的 API 参考和示例
- **交互式文档**: http://localhost:8000/docs (Swagger UI)
- **ReDoc**: http://localhost:8000/redoc
- **OpenAPI 规范**: http://localhost:8000/openapi.json

### 核心端点

* **生图**: `POST /api/v1/generations/image`
* **生视频**: `POST /api/v1/generations/video`
* **剧本分析**: `POST /api/v1/script/analyze`
* **任务状态**: `GET /api/v1/tasks/{task_id}`
* **项目管理**: `/api/v1/projects/*`
* **画布操作**: `/api/v1/canvas/*`
所有生成接口均支持 `extra_params` 字段,用于直接向底层 SDK 传递模型特定参数。

## 🧪 开发

### 运行测试

**后端:**
```bash
cd backend
pytest                            # 运行所有测试
pytest tests/test_integration.py  # 运行集成测试
pytest --cov=src                  # 运行并生成覆盖率报告
```

**前端:**
```bash
cd frontend
pnpm test           # 运行所有测试
pnpm test:watch     # 监听模式运行
pnpm test:coverage  # 运行并生成覆盖率报告
```

### 代码质量

**后端:**
```bash
# 格式化代码
black src/
isort src/

# 代码检查
flake8 src/
mypy src/
```

**前端:**
```bash
# 代码检查
pnpm lint

# 类型检查
pnpm type-check

# 格式化
pnpm format
```

### 数据库迁移

```bash
cd backend

# 创建新迁移
alembic revision --autogenerate -m "描述"

# 应用迁移
alembic upgrade head

# 回滚一个迁移
alembic downgrade -1
```

### API 类型生成

当后端 API 变更时,重新生成前端类型:

```bash
cd frontend
pnpm gen:api
```

## 📊 监控

### 健康检查

```bash
curl http://localhost:8000/health
```

### 指标(Prometheus)

```bash
curl http://localhost:8000/metrics
```

### 日志

后端日志采用 JSON 格式,便于解析:

```bash
tail -f backend/logs/app.log | jq
```

## 🤝 贡献

1. Fork 本仓库
2. 创建特性分支 (`git checkout -b feature/amazing-feature`)
3. 提交更改 (`git commit -m 'Add amazing feature'`)
4. 推送到分支 (`git push origin feature/amazing-feature`)
5. 开启 Pull Request

### 开发指南

- 遵循三层架构(API层、服务层、仓储层)
- 为新功能编写测试
- 更新文档
- 遵循代码风格指南
- 使用约定式提交

## 🐛 故障排除

### 后端问题

**Redis 连接错误:**
```bash
# 检查 Redis 是否运行
redis-cli ping
# 应返回: PONG
```

**数据库连接错误:**
```bash
# 检查数据库连接
psql -h localhost -U user -d pixel
```

**导入错误:**
```bash
# 重新安装依赖
uv sync --reinstall
```

### 前端问题

**模块未找到:**
```bash
# 清除缓存并重新安装
rm -rf node_modules .next
pnpm install
```

**API 类型不匹配:**
```bash
# 重新生成 API 类型
pnpm gen:api
```

## 📦 项目结构

```
pixel/
├── backend/               # FastAPI 后端
│   ├── src/
│   │   ├── controllers/   # API 层(HTTP 处理)
│   │   ├── services/      # 服务层(业务逻辑)
│   │   ├── repositories/  # 仓储层(数据访问)
│   │   ├── models/        # 数据模型(实体和模式)
│   │   ├── middlewares/   # API 中间件
│   │   └── utils/         # 工具函数
│   ├── tests/             # 测试套件
│   └── alembic/           # 数据库迁移
├── frontend/              # Next.js 前端
│   ├── src/
│   │   ├── app/           # Next.js 页面(App Router)
│   │   ├── components/    # React 组件
│   │   ├── lib/           # 工具和服务
│   │   └── store/         # 状态管理(Zustand)
│   └── tests/             # 测试套件
├── docs/                  # 文档
├── ARCHITECTURE.md        # 架构文档
└── docker-compose.yml     # Docker 配置
```

## 📄 许可证

MIT

backend/.env.example (new file, 108 lines)
@@ -0,0 +1,108 @@
# ===========================================
# Pixel Backend Environment Configuration
# ===========================================
# Copy this file to .env and fill in your values.

# ---- Server ----
NODE_ENV=development
PY_PORT=8000

# ---- Database ----
# Default: SQLite (backend/data/pixel.db)
# DATABASE_URL=postgresql://user:pass@localhost:5432/pixel
# DB_PATH=     # override SQLite file path
# DATA_DIR=    # override data directory

# Database connection pool
# DB_POOL_SIZE=20
# DB_MAX_OVERFLOW=10
# DB_POOL_TIMEOUT=30
# DB_POOL_RECYCLE=3600
# DB_POOL_PRE_PING=true
# SLOW_QUERY_THRESHOLD=1.0

# ---- Redis ----
REDIS_URL=redis://localhost:6379
REDIS_ENABLED=1

# ---- JWT Auth ----
# Auto-generated in dev; MUST set in production
# JWT_SECRET_KEY=your-secret-key-here
ACCESS_TOKEN_EXPIRE_MINUTES=30
REFRESH_TOKEN_EXPIRE_DAYS=7

# ---- Encryption Key (for user API key storage) ----
# Auto-generated in dev; MUST set in production
# MASTER_ENCRYPTION_KEY=your-fernet-key-here

# ---- CORS ----
# CORS_ALLOWED_ORIGINS=https://your-app.example.com
# CORS_DEV_ALLOWED_ORIGINS=http://localhost:3000,http://127.0.0.1:3000
# ALLOW_DEV_ORIGINS=1

# ---- Storage (OSS) ----
STORAGE_TYPE=local
# OSS_REGION=oss-cn-shanghai
# OSS_ENDPOINT=oss-cn-shanghai.aliyuncs.com
# OSS_BUCKET=your-bucket-name
# ALIBABA_CLOUD_ACCESS_KEY_ID=your_key
# ALIBABA_CLOUD_ACCESS_KEY_SECRET=your_secret

# ---- Email (SMTP) ----
# SMTP_HOST=
# SMTP_PORT=587
# SMTP_USER=
# SMTP_PASSWORD=
# SMTP_FROM=
# SMTP_TLS=true
# FRONTEND_URL=http://localhost:3000

# ===========================================
# AI Provider API Keys
# All users share these system-level keys.
# ===========================================

# DashScope (Qwen LLM, Wanx Image, Z-Image)
# DASHSCOPE_API_KEY=sk-xxx
# DASHSCOPE_BASE_URL=    # optional

# VolcEngine / 火山引擎 (Doubao LLM, video)
# VOLCENGINE_API_KEY=xxx

# Google (Gemini LLM)
# GOOGLE_API_KEY=xxx

# OpenAI
# OPENAI_API_KEY=sk-xxx
# OPENAI_BASE_URL=       # optional, for proxies

# MiniMax / 海螺 (video, audio, music)
# MINIMAX_API_KEY=xxx
# MINIMAX_GROUP_ID=xxx

# Kling / 可灵 (video) — requires both access_key and secret_key
# KLING_ACCESS_KEY=xxx
# KLING_SECRET_KEY=xxx
# KLING_API_BASE=https://api-beijing.klingai.com/v1

# Midjourney / 有川 (image)
# MIDJOURNEY_API_KEY=xxx
# MIDJOURNEY_PROXY_URL=xxx
# YOUCHUAN_APP_ID=xxx
# YOUCHUAN_SECRET_KEY=xxx

# ModelScope (image, video)
# MODELSCOPE_API_TOKEN=xxx

# ---- Script Agent (AgentScope) ----
# Override keys specifically for script analysis agents
# SCRIPT_DASHSCOPE_API_KEY=xxx
# SCRIPT_DASHSCOPE_BASE_URL=xxx
# SCRIPT_OPENAI_API_KEY=xxx
# SCRIPT_OPENAI_BASE_URL=xxx

# ---- Monitoring ----
# ENABLE_METRICS=true
# LOG_LEVEL=INFO
# TRACING_ENABLED=0
# OTLP_ENDPOINT=http://localhost:4317
|
||||||
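The comments above note that `JWT_SECRET_KEY` and `MASTER_ENCRYPTION_KEY` are auto-generated in dev but must be set in production. A minimal stdlib-only sketch for generating suitable values; the exact formats (a url-safe token for the JWT secret, a Fernet-style key for the encryption key) are assumptions based on the placeholders above:

```python
# Sketch: generate candidate values for JWT_SECRET_KEY and MASTER_ENCRYPTION_KEY.
# Assumption: MASTER_ENCRYPTION_KEY is a Fernet key (32 url-safe base64 bytes),
# as the "your-fernet-key-here" placeholder suggests.
import base64
import secrets

jwt_secret = secrets.token_urlsafe(32)  # 43-char url-safe string
fernet_key = base64.urlsafe_b64encode(secrets.token_bytes(32)).decode()  # 44 chars

print(f"JWT_SECRET_KEY={jwt_secret}")
print(f"MASTER_ENCRYPTION_KEY={fernet_key}")
```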
40
backend/Dockerfile
Normal file
@@ -0,0 +1,40 @@
FROM python:3.12-slim

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    git \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# Install uv
RUN pip install uv

# Set working directory
WORKDIR /app

# Copy dependency files
COPY pyproject.toml uv.lock ./

# Create virtual environment and install dependencies
# Using --frozen to ensure strict adherence to uv.lock
RUN uv sync --frozen --no-install-project

# Add virtual environment to PATH
ENV PATH="/app/.venv/bin:$PATH"

# Copy source code
COPY src ./src

# Install the project itself (if needed)
RUN uv sync --frozen

# Create directories for data/uploads if they don't exist
RUN mkdir -p data/uploads

# Expose port
EXPOSE 8000

# Start command
CMD ["uv", "run", "uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
147
backend/alembic.ini
Normal file
@@ -0,0 +1,147 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions

# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
#    "version_path_separator" key, which if absent then falls back to the legacy
#    behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
#    behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
sqlalchemy.url = driver://user:pass@localhost/dbname


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
1
backend/alembic/README
Normal file
@@ -0,0 +1 @@
Generic single-database configuration.
94
backend/alembic/env.py
Normal file
@@ -0,0 +1,94 @@
import sys
import os
from logging.config import fileConfig

from sqlalchemy import engine_from_config
from sqlalchemy import pool
from sqlmodel import SQLModel

from alembic import context

# Add project root to sys.path to allow importing backend
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

# Import models so their tables register on SQLModel.metadata
from backend.src.models.entities import *
from backend.src.models.session import *
from backend.src.config.settings import DB_PATH, DATABASE_URL

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Override sqlalchemy.url with the one from config
if DATABASE_URL:
    config.set_main_option("sqlalchemy.url", DATABASE_URL)
else:
    config.set_main_option("sqlalchemy.url", f"sqlite:///{DB_PATH}")

# add your model's MetaData object here
# for 'autogenerate' support
target_metadata = SQLModel.metadata

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(
            connection=connection, target_metadata=target_metadata
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
28
backend/alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    """Upgrade schema."""
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    """Downgrade schema."""
    ${downgrades if downgrades else "pass"}
339
backend/alembic/verify_migration.py
Normal file
@@ -0,0 +1,339 @@
#!/usr/bin/env python3
"""
Migration Verification Script for Canvas Metadata Table

This script verifies that the canvas_metadata migration was successful by:
1. Checking that the canvas_metadata table exists
2. Verifying all required columns exist with correct types
3. Checking that all indexes were created
4. Validating data migration from general_canvases
5. Validating data migration from asset canvases
6. Validating data migration from storyboard canvases
7. Checking data integrity and consistency
"""

import sys
import os

from sqlalchemy import create_engine, inspect, text

# Add project root to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from backend.src.config.settings import DB_PATH


class MigrationVerifier:
    def __init__(self, db_path: str):
        self.db_path = db_path
        self.engine = create_engine(f"sqlite:///{db_path}")
        self.inspector = inspect(self.engine)
        self.errors = []
        self.warnings = []
        self.success_count = 0
        self.total_checks = 0

    def check(self, condition: bool, success_msg: str, error_msg: str):
        """Helper method to track check results"""
        self.total_checks += 1
        if condition:
            self.success_count += 1
            print(f"✅ {success_msg}")
        else:
            self.errors.append(error_msg)
            print(f"❌ {error_msg}")

    def warn(self, message: str):
        """Helper method to track warnings"""
        self.warnings.append(message)
        print(f"⚠️  {message}")

    def verify_table_exists(self) -> bool:
        """Verify that canvas_metadata table exists"""
        print("\n=== Checking Table Existence ===")
        tables = self.inspector.get_table_names()
        self.check(
            'canvas_metadata' in tables,
            "canvas_metadata table exists",
            "canvas_metadata table does not exist"
        )
        return 'canvas_metadata' in tables

    def verify_columns(self) -> bool:
        """Verify all required columns exist with correct types"""
        print("\n=== Checking Columns ===")

        required_columns = {
            'id': 'VARCHAR',
            'project_id': 'VARCHAR',
            'canvas_type': 'VARCHAR',
            'related_entity_type': 'VARCHAR',
            'related_entity_id': 'VARCHAR',
            'name': 'VARCHAR',
            'description': 'VARCHAR',
            'order_index': 'INTEGER',
            'is_pinned': 'BOOLEAN',
            'tags': 'JSON',
            'node_count': 'INTEGER',
            'last_accessed_at': 'FLOAT',
            'access_count': 'INTEGER',
            'created_at': 'FLOAT',
            'updated_at': 'FLOAT',
            'deleted_at': 'FLOAT',
            'legacy_id': 'VARCHAR'
        }

        columns = self.inspector.get_columns('canvas_metadata')
        column_dict = {col['name']: col for col in columns}

        all_columns_exist = True
        for col_name, expected_type in required_columns.items():
            if col_name in column_dict:
                col_type = str(column_dict[col_name]['type']).upper()
                # SQLite stores JSON as TEXT, so we need to check for that
                if expected_type == 'JSON' and 'TEXT' in col_type:
                    print(f"✅ Column '{col_name}' exists with type {col_type} (JSON stored as TEXT)")
                elif expected_type in col_type or col_type in expected_type:
                    print(f"✅ Column '{col_name}' exists with type {col_type}")
                else:
                    self.warn(f"Column '{col_name}' exists but type is {col_type}, expected {expected_type}")
            else:
                all_columns_exist = False
                self.errors.append(f"Column '{col_name}' is missing")
                print(f"❌ Column '{col_name}' is missing")

        return all_columns_exist

    def verify_indexes(self) -> bool:
        """Verify all required indexes exist"""
        print("\n=== Checking Indexes ===")

        required_indexes = [
            'ix_canvas_metadata_project_id',
            'ix_canvas_metadata_canvas_type',
            'ix_canvas_metadata_related_entity_id',
            'ix_canvas_metadata_legacy_id',
            'ix_canvas_metadata_project_type',
            'ix_canvas_metadata_type_entity'
        ]

        indexes = self.inspector.get_indexes('canvas_metadata')
        index_names = [idx['name'] for idx in indexes]

        all_indexes_exist = True
        for idx_name in required_indexes:
            if idx_name in index_names:
                print(f"✅ Index '{idx_name}' exists")
            else:
                all_indexes_exist = False
                self.errors.append(f"Index '{idx_name}' is missing")
                print(f"❌ Index '{idx_name}' is missing")

        return all_indexes_exist

    def verify_foreign_keys(self) -> bool:
        """Verify foreign key constraints"""
        print("\n=== Checking Foreign Keys ===")

        fks = self.inspector.get_foreign_keys('canvas_metadata')

        has_project_fk = any(
            fk['referred_table'] == 'projects' and 'project_id' in fk['constrained_columns']
            for fk in fks
        )

        self.check(
            has_project_fk,
            "Foreign key to projects table exists",
            "Foreign key to projects table is missing"
        )

        return has_project_fk

    def verify_data_migration(self) -> bool:
        """Verify data was migrated correctly"""
        print("\n=== Checking Data Migration ===")

        with self.engine.connect() as conn:
            # Check if any canvas_metadata records exist
            result = conn.execute(text("SELECT COUNT(*) FROM canvas_metadata"))
            count = result.scalar()

            if count > 0:
                print(f"✅ Found {count} canvas metadata records")

                # Check general canvases
                result = conn.execute(text(
                    "SELECT COUNT(*) FROM canvas_metadata WHERE canvas_type = 'general'"
                ))
                general_count = result.scalar()
                print(f"  - General canvases: {general_count}")

                # Check asset canvases
                result = conn.execute(text(
                    "SELECT COUNT(*) FROM canvas_metadata WHERE canvas_type = 'asset'"
                ))
                asset_count = result.scalar()
                print(f"  - Asset canvases: {asset_count}")

                # Check storyboard canvases
                result = conn.execute(text(
                    "SELECT COUNT(*) FROM canvas_metadata WHERE canvas_type = 'storyboard'"
                ))
                storyboard_count = result.scalar()
                print(f"  - Storyboard canvases: {storyboard_count}")

                # Verify legacy_id mapping for migrated canvases
                result = conn.execute(text(
                    "SELECT COUNT(*) FROM canvas_metadata WHERE legacy_id IS NOT NULL"
                ))
                legacy_count = result.scalar()
                if legacy_count > 0:
                    print(f"✅ Found {legacy_count} canvases with legacy_id mapping")

                return True
            else:
                self.warn("No canvas metadata records found (this is OK if database is empty)")
                return True

    def verify_data_integrity(self) -> bool:
        """Verify data integrity constraints"""
        print("\n=== Checking Data Integrity ===")

        with self.engine.connect() as conn:
            # Check if canvases table exists
            tables = self.inspector.get_table_names()
            if 'canvases' not in tables:
                self.warn("canvases table does not exist yet - skipping canvas content check")
                orphaned_metadata = 0
            else:
                # Check that all canvas_metadata records have corresponding canvas content
                result = conn.execute(text("""
                    SELECT COUNT(*)
                    FROM canvas_metadata cm
                    LEFT JOIN canvases c ON cm.id = c.id
                    WHERE c.id IS NULL
                """))
                orphaned_metadata = result.scalar()

                self.check(
                    orphaned_metadata == 0,
                    "All canvas metadata records have corresponding canvas content",
                    f"Found {orphaned_metadata} canvas metadata records without canvas content"
                )

            # Check that related_entity_id is set for asset and storyboard canvases
            result = conn.execute(text("""
                SELECT COUNT(*)
                FROM canvas_metadata
                WHERE canvas_type IN ('asset', 'storyboard')
                AND related_entity_id IS NULL
            """))
            missing_entity_id = result.scalar()

            self.check(
                missing_entity_id == 0,
                "All asset/storyboard canvases have related_entity_id",
                f"Found {missing_entity_id} asset/storyboard canvases without related_entity_id"
            )

            # Check that general canvases don't have related_entity_id
            result = conn.execute(text("""
                SELECT COUNT(*)
                FROM canvas_metadata
                WHERE canvas_type = 'general'
                AND related_entity_id IS NOT NULL
            """))
            invalid_general = result.scalar()

            self.check(
                invalid_general == 0,
                "General canvases don't have related_entity_id",
                f"Found {invalid_general} general canvases with related_entity_id"
            )

            return orphaned_metadata == 0 and missing_entity_id == 0 and invalid_general == 0

    def verify_project_relationship(self) -> bool:
        """Verify that ProjectDB relationship is working"""
        print("\n=== Checking Project Relationship ===")

        with self.engine.connect() as conn:
            # Check that all canvas_metadata records reference valid projects
            result = conn.execute(text("""
                SELECT COUNT(*)
                FROM canvas_metadata cm
                LEFT JOIN projects p ON cm.project_id = p.id
                WHERE p.id IS NULL
            """))
            orphaned_canvases = result.scalar()

            self.check(
                orphaned_canvases == 0,
                "All canvas metadata records reference valid projects",
                f"Found {orphaned_canvases} canvas metadata records with invalid project_id"
            )

            return orphaned_canvases == 0

    def run_all_checks(self) -> bool:
        """Run all verification checks"""
        print("=" * 60)
        print("Canvas Metadata Migration Verification")
        print("=" * 60)
        print(f"Database: {self.db_path}")

        if not os.path.exists(self.db_path):
            print(f"\n❌ Database file does not exist: {self.db_path}")
            return False

        # Run all checks
        table_exists = self.verify_table_exists()
        if not table_exists:
            print("\n❌ Cannot continue verification - table does not exist")
            return False

        self.verify_columns()
        self.verify_indexes()
        self.verify_foreign_keys()
        self.verify_data_migration()
        self.verify_data_integrity()
        self.verify_project_relationship()

        # Print summary
        print("\n" + "=" * 60)
        print("Verification Summary")
        print("=" * 60)
        print(f"Total checks: {self.total_checks}")
        print(f"Passed: {self.success_count}")
        print(f"Failed: {len(self.errors)}")
        print(f"Warnings: {len(self.warnings)}")

        if self.errors:
            print("\n❌ Errors:")
            for error in self.errors:
                print(f"  - {error}")

        if self.warnings:
            print("\n⚠️  Warnings:")
            for warning in self.warnings:
                print(f"  - {warning}")

        if len(self.errors) == 0:
            print("\n✅ All checks passed! Migration is successful.")
            return True
        else:
            print("\n❌ Migration verification failed. Please review the errors above.")
            return False


def main():
    """Main entry point"""
    verifier = MigrationVerifier(DB_PATH)
    success = verifier.run_all_checks()
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()
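The verification script above leans on SQLAlchemy's `inspect()` to read schema metadata. The same table-existence check can be sketched with only the standard library, since SQLite exposes its catalog through `sqlite_master`; the table and path below are hypothetical stand-ins, not the project's real database:

```python
import os
import sqlite3
import tempfile

def table_exists(db_path: str, table: str) -> bool:
    # SQLite keeps its schema catalog in sqlite_master; SQLAlchemy's
    # inspector ultimately reads the same information.
    conn = sqlite3.connect(db_path)
    try:
        row = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
            (table,),
        ).fetchone()
        return row is not None
    finally:
        conn.close()

# Demo against a throwaway database file with a stand-in schema.
fd, path = tempfile.mkstemp(suffix=".db")
os.close(fd)
conn = sqlite3.connect(path)
conn.execute("CREATE TABLE canvas_metadata (id TEXT PRIMARY KEY)")
conn.commit()
conn.close()

found = table_exists(path, "canvas_metadata")
missing = table_exists(path, "no_such_table")
os.remove(path)
print(found, missing)  # True False
```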
127
backend/alembic/versions/72f609dd9e66_initial_schema.py
Normal file
@@ -0,0 +1,127 @@
"""Initial schema

Revision ID: 72f609dd9e66
Revises:
Create Date: 2026-01-08 09:52:59.473436

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
import sqlmodel


# revision identifiers, used by Alembic.
revision: str = '72f609dd9e66'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('projects',
    sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('status', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('created_at', sa.Float(), nullable=False),
    sa.Column('updated_at', sa.Float(), nullable=False),
    sa.Column('resolution', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('ratio', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('style_preset', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('style_params', sa.JSON(), nullable=True),
    sa.Column('chapters', sa.JSON(), nullable=True),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('tasks',
    sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('status', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('created_at', sa.Float(), nullable=False),
    sa.Column('updated_at', sa.Float(), nullable=False),
    sa.Column('model', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('params', sa.JSON(), nullable=True),
    sa.Column('provider_task_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('result', sa.JSON(), nullable=True),
    sa.Column('error', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('assets',
    sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('project_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('desc', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('tags', sa.JSON(), nullable=True),
    sa.Column('image_url', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('image_urls', sa.JSON(), nullable=True),
    sa.Column('video_urls', sa.JSON(), nullable=True),
    sa.Column('extra_data', sa.JSON(), nullable=True),
    sa.Column('generations', sa.JSON(), nullable=True),
    sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_assets_project_id'), 'assets', ['project_id'], unique=False)
    op.create_table('episodes',
    sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('project_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('order_index', sa.Integer(), nullable=False),
    sa.Column('title', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('desc', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('content', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('status', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_episodes_project_id'), 'episodes', ['project_id'], unique=False)
    op.create_table('storyboards',
    sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('project_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('episode_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('order_index', sa.Integer(), nullable=False),
    sa.Column('shot', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('desc', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('duration', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
    sa.Column('scene_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('character_ids', sa.JSON(), nullable=True),
    sa.Column('prop_ids', sa.JSON(), nullable=True),
    sa.Column('voiceover', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('audio_desc', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('audio_url', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('camera_movement', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('transition', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('visual_anchor', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('visual_dynamics', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('director_note', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('image_prompt', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
    sa.Column('video_script', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||||
|
sa.Column('image_urls', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('video_urls', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('generations', sa.JSON(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['episode_id'], ['episodes.id'], ),
|
||||||
|
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_storyboards_episode_id'), 'storyboards', ['episode_id'], unique=False)
|
||||||
|
op.create_index(op.f('ix_storyboards_project_id'), 'storyboards', ['project_id'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
"""Downgrade schema."""
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f('ix_storyboards_project_id'), table_name='storyboards')
|
||||||
|
op.drop_index(op.f('ix_storyboards_episode_id'), table_name='storyboards')
|
||||||
|
op.drop_table('storyboards')
|
||||||
|
op.drop_index(op.f('ix_episodes_project_id'), table_name='episodes')
|
||||||
|
op.drop_table('episodes')
|
||||||
|
op.drop_index(op.f('ix_assets_project_id'), table_name='assets')
|
||||||
|
op.drop_table('assets')
|
||||||
|
op.drop_table('tasks')
|
||||||
|
op.drop_table('projects')
|
||||||
|
# ### end Alembic commands ###
|
||||||
245
backend/alembic/versions/add_canvas_metadata_table.py
Normal file
@@ -0,0 +1,245 @@
"""add canvas metadata table

Revision ID: add_canvas_metadata
Revises: bfac9b8e32f5
Create Date: 2026-01-17 10:00:00.000000

"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import sqlmodel
import json
import uuid
from datetime import datetime

# revision identifiers, used by Alembic.
revision: str = 'add_canvas_metadata'
down_revision: Union[str, Sequence[str], None] = ('add_progress_tracking', 'add_prompt_fields')
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    # 1. Create the canvas_metadata table
    op.create_table(
        'canvas_metadata',
        sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('project_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('canvas_type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('related_entity_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('related_entity_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.Column('order_index', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('is_pinned', sa.Boolean(), nullable=False, server_default='0'),
        sa.Column('tags', sa.JSON(), nullable=True),
        sa.Column('node_count', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('last_accessed_at', sa.Float(), nullable=True),
        sa.Column('access_count', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('created_at', sa.Float(), nullable=False),
        sa.Column('updated_at', sa.Float(), nullable=False),
        sa.Column('deleted_at', sa.Float(), nullable=True),
        sa.Column('legacy_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.ForeignKeyConstraint(['project_id'], ['projects.id'])
    )

    # 2. Create indexes
    op.create_index('ix_canvas_metadata_project_id', 'canvas_metadata', ['project_id'])
    op.create_index('ix_canvas_metadata_canvas_type', 'canvas_metadata', ['canvas_type'])
    op.create_index('ix_canvas_metadata_related_entity_id', 'canvas_metadata', ['related_entity_id'])
    op.create_index('ix_canvas_metadata_legacy_id', 'canvas_metadata', ['legacy_id'])
    op.create_index('ix_canvas_metadata_project_type', 'canvas_metadata', ['project_id', 'canvas_type'])
    op.create_index('ix_canvas_metadata_type_entity', 'canvas_metadata', ['canvas_type', 'related_entity_id'])

    # 3. Migrate data
    migrate_general_canvases()
    migrate_asset_canvases()
    migrate_storyboard_canvases()


def migrate_general_canvases():
    """Migrate general canvas data."""
    conn = op.get_bind()

    # Fetch general_canvases for every project
    try:
        projects = conn.execute(sa.text("SELECT id, general_canvases FROM projects")).fetchall()
    except Exception:
        # Skip if the general_canvases column does not exist
        return

    for project in projects:
        project_id = project[0]
        general_canvases_json = project[1]

        if not general_canvases_json:
            continue

        try:
            canvases = json.loads(general_canvases_json) if isinstance(general_canvases_json, str) else general_canvases_json
        except Exception:
            continue

        if not isinstance(canvases, list):
            continue

        for idx, canvas in enumerate(canvases):
            canvas_id = canvas.get('id')
            if not canvas_id:
                continue

            # Insert into canvas_metadata
            conn.execute(sa.text("""
                INSERT INTO canvas_metadata (
                    id, project_id, canvas_type, name, order_index,
                    created_at, updated_at
                ) VALUES (
                    :id, :project_id, 'general', :name, :order_index,
                    :created_at, :updated_at
                )
            """), {
                'id': canvas_id,
                'project_id': project_id,
                'name': canvas.get('name', f'Canvas {idx + 1}'),
                'order_index': idx,
                'created_at': canvas.get('createdAt', datetime.now().timestamp()),
                'updated_at': canvas.get('updatedAt', datetime.now().timestamp())
            })

    conn.commit()


def migrate_asset_canvases():
    """Migrate asset canvas data."""
    conn = op.get_bind()

    # Find all canvases whose IDs start with canvas-asset-
    try:
        canvases = conn.execute(sa.text("""
            SELECT id, project_id, updated_at
            FROM canvases
            WHERE id LIKE 'canvas-asset-%'
        """)).fetchall()
    except Exception:
        return

    for canvas in canvases:
        old_id = canvas[0]
        project_id = canvas[1]
        updated_at = canvas[2]

        # Extract the asset_id
        asset_id = old_id.replace('canvas-asset-', '')

        # Look up the corresponding asset
        try:
            asset = conn.execute(sa.text("""
                SELECT name FROM assets WHERE id = :asset_id
            """), {'asset_id': asset_id}).fetchone()
        except Exception:
            continue

        if not asset:
            continue

        # Generate a new UUID
        new_id = str(uuid.uuid4())

        # Insert the metadata row
        conn.execute(sa.text("""
            INSERT INTO canvas_metadata (
                id, project_id, canvas_type, related_entity_type,
                related_entity_id, name, created_at, updated_at, legacy_id
            ) VALUES (
                :id, :project_id, 'asset', 'asset',
                :asset_id, :name, :created_at, :updated_at, :legacy_id
            )
        """), {
            'id': new_id,
            'project_id': project_id,
            'asset_id': asset_id,
            'name': asset[0],
            'created_at': updated_at,
            'updated_at': updated_at,
            'legacy_id': old_id
        })

        # Update the ID in the canvases table
        conn.execute(sa.text("""
            UPDATE canvases SET id = :new_id WHERE id = :old_id
        """), {'new_id': new_id, 'old_id': old_id})

    conn.commit()


def migrate_storyboard_canvases():
    """Migrate storyboard canvas data."""
    conn = op.get_bind()

    # Find all canvases whose IDs start with canvas-storyboard-
    try:
        canvases = conn.execute(sa.text("""
            SELECT id, project_id, updated_at
            FROM canvases
            WHERE id LIKE 'canvas-storyboard-%'
        """)).fetchall()
    except Exception:
        return

    for canvas in canvases:
        old_id = canvas[0]
        project_id = canvas[1]
        updated_at = canvas[2]

        storyboard_id = old_id.replace('canvas-storyboard-', '')

        try:
            storyboard = conn.execute(sa.text("""
                SELECT shot FROM storyboards WHERE id = :storyboard_id
            """), {'storyboard_id': storyboard_id}).fetchone()
        except Exception:
            continue

        if not storyboard:
            continue

        new_id = str(uuid.uuid4())

        conn.execute(sa.text("""
            INSERT INTO canvas_metadata (
                id, project_id, canvas_type, related_entity_type,
                related_entity_id, name, created_at, updated_at, legacy_id
            ) VALUES (
                :id, :project_id, 'storyboard', 'storyboard',
                :storyboard_id, :name, :created_at, :updated_at, :legacy_id
            )
        """), {
            'id': new_id,
            'project_id': project_id,
            'storyboard_id': storyboard_id,
            'name': storyboard[0],
            'created_at': updated_at,
            'updated_at': updated_at,
            'legacy_id': old_id
        })

        conn.execute(sa.text("""
            UPDATE canvases SET id = :new_id WHERE id = :old_id
        """), {'new_id': new_id, 'old_id': old_id})

    conn.commit()


def downgrade() -> None:
    """Downgrade schema."""
    # Roll back the schema changes
    op.drop_index('ix_canvas_metadata_type_entity', 'canvas_metadata')
    op.drop_index('ix_canvas_metadata_project_type', 'canvas_metadata')
    op.drop_index('ix_canvas_metadata_legacy_id', 'canvas_metadata')
    op.drop_index('ix_canvas_metadata_related_entity_id', 'canvas_metadata')
    op.drop_index('ix_canvas_metadata_canvas_type', 'canvas_metadata')
    op.drop_index('ix_canvas_metadata_project_id', 'canvas_metadata')
    op.drop_table('canvas_metadata')
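The data migration above keys everything off a legacy ID convention (`canvas-asset-<id>`, `canvas-storyboard-<id>`) and preserves the old ID in `legacy_id` after reassigning a UUID. A minimal standalone sketch of that mapping (the helper names are illustrative, not part of the codebase):

```python
import uuid

# Legacy canvas IDs encode the related entity in a prefix, as in the
# migration above: "canvas-asset-<asset_id>" / "canvas-storyboard-<storyboard_id>".
def split_legacy_canvas_id(canvas_id: str):
    """Return (canvas_type, entity_id) for a legacy canvas ID, or None."""
    for prefix, canvas_type in (("canvas-asset-", "asset"),
                                ("canvas-storyboard-", "storyboard")):
        if canvas_id.startswith(prefix):
            return canvas_type, canvas_id[len(prefix):]
    return None

def new_metadata_row(canvas_id: str, project_id: str, name: str) -> dict:
    """Build a canvas_metadata row: fresh UUID id, old id kept in legacy_id."""
    parsed = split_legacy_canvas_id(canvas_id)
    canvas_type, entity_id = parsed if parsed else ("general", None)
    return {
        "id": str(uuid.uuid4()),
        "project_id": project_id,
        "canvas_type": canvas_type,
        "related_entity_id": entity_id,
        "name": name,
        "legacy_id": canvas_id,
    }
```

Keeping `legacy_id` indexed (see `ix_canvas_metadata_legacy_id`) lets old prefixed IDs still resolve after the UUID swap.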
41
backend/alembic/versions/add_cinematic_fields.py
Normal file
@@ -0,0 +1,41 @@
"""add cinematic and professional fields to assets and storyboards

Revision ID: add_cinematic_fields
Revises: add_canvas_metadata
Create Date: 2026-01-20

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'add_cinematic_fields'
down_revision = 'add_canvas_metadata'
branch_labels = None
depends_on = None


def upgrade() -> None:
    # The assets table already stores these fields in extra_data, so for now we
    # choose not to add dedicated columns: Asset fields such as emotion,
    # environment_type and weather live in the extra_data JSON.
    # If indexed queries are needed later, we can add:
    # op.add_column('assets', sa.Column('emotion', sa.String(), nullable=True))
    # op.add_column('assets', sa.Column('environment_type', sa.String(), nullable=True))
    # op.add_column('assets', sa.Column('weather', sa.String(), nullable=True))

    # Add cinematic control fields to storyboards table
    op.add_column('storyboards', sa.Column('camera_angle', sa.String(), nullable=True))
    op.add_column('storyboards', sa.Column('lens', sa.String(), nullable=True))
    op.add_column('storyboards', sa.Column('focus', sa.String(), nullable=True))
    op.add_column('storyboards', sa.Column('lighting', sa.String(), nullable=True))
    op.add_column('storyboards', sa.Column('color_style', sa.String(), nullable=True))


def downgrade() -> None:
    # Remove cinematic fields from storyboards
    op.drop_column('storyboards', 'color_style')
    op.drop_column('storyboards', 'lighting')
    op.drop_column('storyboards', 'focus')
    op.drop_column('storyboards', 'lens')
    op.drop_column('storyboards', 'camera_angle')
100
backend/alembic/versions/add_indexes_and_optimizations.py
Normal file
@@ -0,0 +1,100 @@
"""Add indexes and database optimizations

Revision ID: add_indexes_opt
Revises: bfac9b8e32f5
Create Date: 2026-01-14 10:00:00.000000

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'add_indexes_opt'
down_revision: Union[str, Sequence[str], None] = 'bfac9b8e32f5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Add indexes, soft delete columns, and full-text search support."""

    # Add soft delete columns
    op.add_column('projects', sa.Column('deleted_at', sa.Float(), nullable=True))
    op.add_column('assets', sa.Column('deleted_at', sa.Float(), nullable=True))
    op.add_column('episodes', sa.Column('deleted_at', sa.Float(), nullable=True))
    op.add_column('storyboards', sa.Column('deleted_at', sa.Float(), nullable=True))
    op.add_column('tasks', sa.Column('deleted_at', sa.Float(), nullable=True))

    # Add indexes for frequently queried fields on projects
    op.create_index('idx_projects_created_at', 'projects', ['created_at'])
    op.create_index('idx_projects_updated_at', 'projects', ['updated_at'])
    op.create_index('idx_projects_status', 'projects', ['status'])
    op.create_index('idx_projects_deleted_at', 'projects', ['deleted_at'])

    # Add indexes for tasks
    op.create_index('idx_tasks_status', 'tasks', ['status'])
    op.create_index('idx_tasks_type', 'tasks', ['type'])
    op.create_index('idx_tasks_created_at', 'tasks', ['created_at'])
    op.create_index('idx_tasks_type_status', 'tasks', ['type', 'status'])
    op.create_index('idx_tasks_deleted_at', 'tasks', ['deleted_at'])

    # Add indexes for assets
    op.create_index('idx_assets_type', 'assets', ['type'])
    op.create_index('idx_assets_deleted_at', 'assets', ['deleted_at'])

    # Add indexes for episodes
    op.create_index('idx_episodes_status', 'episodes', ['status'])
    op.create_index('idx_episodes_order_index', 'episodes', ['order_index'])
    op.create_index('idx_episodes_deleted_at', 'episodes', ['deleted_at'])

    # Add indexes for storyboards
    op.create_index('idx_storyboards_type', 'storyboards', ['type'])
    op.create_index('idx_storyboards_order_index', 'storyboards', ['order_index'])
    op.create_index('idx_storyboards_deleted_at', 'storyboards', ['deleted_at'])

    # Note: SQLite doesn't support full-text search indexes like PostgreSQL.
    # For SQLite, we'll use the FTS5 virtual table approach in the application
    # layer, or use LIKE queries with indexes on the name columns.
    # Adding index on name columns for better LIKE query performance
    op.create_index('idx_projects_name', 'projects', ['name'])
    op.create_index('idx_assets_name', 'assets', ['name'])


def downgrade() -> None:
    """Remove indexes, soft delete columns, and full-text search support."""

    # Drop indexes
    op.drop_index('idx_assets_name', table_name='assets')
    op.drop_index('idx_projects_name', table_name='projects')

    op.drop_index('idx_storyboards_deleted_at', table_name='storyboards')
    op.drop_index('idx_storyboards_order_index', table_name='storyboards')
    op.drop_index('idx_storyboards_type', table_name='storyboards')

    op.drop_index('idx_episodes_deleted_at', table_name='episodes')
    op.drop_index('idx_episodes_order_index', table_name='episodes')
    op.drop_index('idx_episodes_status', table_name='episodes')

    op.drop_index('idx_assets_deleted_at', table_name='assets')
    op.drop_index('idx_assets_type', table_name='assets')

    op.drop_index('idx_tasks_deleted_at', table_name='tasks')
    op.drop_index('idx_tasks_type_status', table_name='tasks')
    op.drop_index('idx_tasks_created_at', table_name='tasks')
    op.drop_index('idx_tasks_type', table_name='tasks')
    op.drop_index('idx_tasks_status', table_name='tasks')

    op.drop_index('idx_projects_deleted_at', table_name='projects')
    op.drop_index('idx_projects_status', table_name='projects')
    op.drop_index('idx_projects_updated_at', table_name='projects')
    op.drop_index('idx_projects_created_at', table_name='projects')

    # Drop soft delete columns
    op.drop_column('tasks', 'deleted_at')
    op.drop_column('storyboards', 'deleted_at')
    op.drop_column('episodes', 'deleted_at')
    op.drop_column('assets', 'deleted_at')
    op.drop_column('projects', 'deleted_at')
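The comment in `upgrade()` above points to FTS5 virtual tables as SQLite's alternative to PostgreSQL full-text indexes. A self-contained sketch of that approach (the table and rows are illustrative, not the app's schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# An FTS5 virtual table mirroring a searchable "name" column.
conn.execute("CREATE VIRTUAL TABLE project_fts USING fts5(name)")
conn.executemany(
    "INSERT INTO project_fts (name) VALUES (?)",
    [("Space opera pilot",), ("Noir detective comic",), ("Space station drama",)],
)
# MATCH uses the full-text index instead of a LIKE table scan,
# and tokenizes case-insensitively by default.
rows = conn.execute(
    "SELECT name FROM project_fts WHERE name MATCH ?", ("space",)
).fetchall()
```

A plain `LIKE 'space%'` with `idx_projects_name` covers prefix search only; FTS5 also handles mid-string tokens and ranking, at the cost of keeping the virtual table in sync.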
30
backend/alembic/versions/add_progress_tracking.py
Normal file
@@ -0,0 +1,30 @@
"""add progress tracking fields

Revision ID: add_progress_tracking
Revises: add_task_mgmt_fields
Create Date: 2026-01-16 15:00:00.000000

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = 'add_progress_tracking'
down_revision = 'add_task_mgmt_fields'
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add progress and error columns to projects table
    with op.batch_alter_table('projects', schema=None) as batch_op:
        batch_op.add_column(sa.Column('progress', sa.JSON(), nullable=True))
        batch_op.add_column(sa.Column('error', sa.JSON(), nullable=True))


def downgrade() -> None:
    # Remove progress and error columns from projects table
    with op.batch_alter_table('projects', schema=None) as batch_op:
        batch_op.drop_column('error')
        batch_op.drop_column('progress')
36
backend/alembic/versions/add_prompt_fields.py
Normal file
@@ -0,0 +1,36 @@
"""add prompt fields to assets and storyboards

Revision ID: add_prompt_fields
Revises: add_task_mgmt_fields
Create Date: 2026-01-16

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'add_prompt_fields'
down_revision = 'add_task_mgmt_fields'
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add image_prompt to assets table
    op.add_column('assets', sa.Column('image_prompt', sa.String(), nullable=True))

    # Add prompt fields to storyboards table
    op.add_column('storyboards', sa.Column('original_text', sa.String(), nullable=True))
    op.add_column('storyboards', sa.Column('merge_image_prompt', sa.String(), nullable=True))
    op.add_column('storyboards', sa.Column('video_prompt', sa.String(), nullable=True))


def downgrade() -> None:
    # Remove fields from storyboards
    op.drop_column('storyboards', 'video_prompt')
    op.drop_column('storyboards', 'merge_image_prompt')
    op.drop_column('storyboards', 'original_text')

    # Remove field from assets
    op.drop_column('assets', 'image_prompt')
42
backend/alembic/versions/add_provider_to_tasks.py
Normal file
@@ -0,0 +1,42 @@
"""add provider to tasks

Revision ID: add_provider_to_tasks
Revises: bfac9b8e32f5
Create Date: 2024-02-11

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'add_provider_to_tasks'
down_revision = 'bfac9b8e32f5'
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add provider column to tasks table"""
    # Add provider column (nullable, indexed)
    op.add_column('tasks', sa.Column('provider', sa.String(), nullable=True))

    # Create index on provider column for faster queries
    op.create_index(op.f('ix_tasks_provider'), 'tasks', ['provider'], unique=False)

    # Optional: Migrate existing data by extracting provider from params.
    # This is a data migration that can be run separately if needed.
    op.execute("""
        UPDATE tasks
        SET provider = params->>'provider'
        WHERE params->>'provider' IS NOT NULL
    """)


def downgrade() -> None:
    """Remove provider column from tasks table"""
    # Drop index first
    op.drop_index(op.f('ix_tasks_provider'), table_name='tasks')

    # Drop column
    op.drop_column('tasks', 'provider')
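The backfill above relies on SQLite's JSON operators: `json_extract(params, '$.provider')` is the portable spelling of `params->>'provider'` (the `->>` operator needs SQLite 3.38+). A standalone sketch of the same backfill:

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tasks (id TEXT PRIMARY KEY, params TEXT, provider TEXT)")
conn.executemany(
    "INSERT INTO tasks VALUES (?, ?, NULL)",
    [
        ("t1", json.dumps({"provider": "kling", "prompt": "sunset"})),
        ("t2", json.dumps({"prompt": "no provider recorded"})),
    ],
)
# Same idea as the migration: copy params.provider into the new column,
# leaving rows without that key untouched (provider stays NULL).
conn.execute(
    "UPDATE tasks SET provider = json_extract(params, '$.provider') "
    "WHERE json_extract(params, '$.provider') IS NOT NULL"
)
providers = dict(conn.execute("SELECT id, provider FROM tasks"))
```

Once backfilled, `ix_tasks_provider` makes per-provider task queries index-backed instead of requiring JSON extraction per row.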
50
backend/alembic/versions/add_task_management_fields.py
Normal file
@@ -0,0 +1,50 @@
"""add task management fields

Revision ID: add_task_mgmt_fields
Revises: add_indexes_opt
Create Date: 2026-01-14

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = 'add_task_mgmt_fields'
down_revision = 'add_indexes_opt'
branch_labels = None
depends_on = None


def upgrade():
    # Add retry configuration fields
    op.add_column('tasks', sa.Column('retry_count', sa.Integer(), nullable=False, server_default='0'))
    op.add_column('tasks', sa.Column('max_retries', sa.Integer(), nullable=False, server_default='3'))

    # Add timestamp fields for task lifecycle
    op.add_column('tasks', sa.Column('started_at', sa.Float(), nullable=True))
    op.add_column('tasks', sa.Column('completed_at', sa.Float(), nullable=True))

    # Add user context fields
    op.add_column('tasks', sa.Column('user_id', sa.String(), nullable=True))
    op.add_column('tasks', sa.Column('project_id', sa.String(), nullable=True))

    # Add indexes for new fields
    op.create_index('idx_tasks_user_id', 'tasks', ['user_id'])
    op.create_index('idx_tasks_project_id', 'tasks', ['project_id'])

    # Note: deleted_at column already exists from previous migration


def downgrade():
    # Remove indexes
    op.drop_index('idx_tasks_project_id', table_name='tasks')
    op.drop_index('idx_tasks_user_id', table_name='tasks')

    # Remove columns
    op.drop_column('tasks', 'project_id')
    op.drop_column('tasks', 'user_id')
    op.drop_column('tasks', 'completed_at')
    op.drop_column('tasks', 'started_at')
    op.drop_column('tasks', 'max_retries')
    op.drop_column('tasks', 'retry_count')
54
backend/alembic/versions/add_user_sessions.py
Normal file
@@ -0,0 +1,54 @@
"""add user_sessions table

Revision ID: add_user_sessions
Revises: b546dbb9df98
Create Date: 2026-03-09

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


revision: str = 'add_user_sessions'
down_revision: Union[str, Sequence[str], None] = 'b546dbb9df98'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'user_sessions',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('user_id', sa.String(), nullable=False),
        sa.Column('session_family_id', sa.String(), nullable=False),
        sa.Column('refresh_token_hash', sa.String(), nullable=False),
        sa.Column('status', sa.String(), nullable=False, server_default='active'),
        sa.Column('created_at', sa.Float(), nullable=False),
        sa.Column('updated_at', sa.Float(), nullable=False),
        sa.Column('expires_at', sa.Float(), nullable=False),
        sa.Column('last_used_at', sa.Float(), nullable=True),
        sa.Column('revoked_at', sa.Float(), nullable=True),
        sa.Column('revoked_reason', sa.String(), nullable=True),
        sa.Column('replaced_by_session_id', sa.String(), nullable=True),
        sa.Column('ip_address', sa.String(), nullable=True),
        sa.Column('user_agent', sa.Text(), nullable=True),
        sa.Column('device_name', sa.String(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id'),
    )
    op.create_index(op.f('ix_user_sessions_user_id'), 'user_sessions', ['user_id'], unique=False)
    op.create_index(op.f('ix_user_sessions_session_family_id'), 'user_sessions', ['session_family_id'], unique=False)
    op.create_index(op.f('ix_user_sessions_refresh_token_hash'), 'user_sessions', ['refresh_token_hash'], unique=False)
    op.create_index(op.f('ix_user_sessions_status'), 'user_sessions', ['status'], unique=False)
    op.create_index(op.f('ix_user_sessions_revoked_at'), 'user_sessions', ['revoked_at'], unique=False)


def downgrade() -> None:
    op.drop_index(op.f('ix_user_sessions_revoked_at'), table_name='user_sessions')
    op.drop_index(op.f('ix_user_sessions_status'), table_name='user_sessions')
    op.drop_index(op.f('ix_user_sessions_refresh_token_hash'), table_name='user_sessions')
    op.drop_index(op.f('ix_user_sessions_session_family_id'), table_name='user_sessions')
    op.drop_index(op.f('ix_user_sessions_user_id'), table_name='user_sessions')
    op.drop_table('user_sessions')
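The `refresh_token_hash` and `session_family_id` columns above suggest refresh tokens are stored only as hashes, with rotation tracked per session family. A minimal sketch of that storage pattern (SHA-256 is an assumption here; the migration does not specify the app's actual hashing scheme):

```python
import hashlib
import secrets

def issue_refresh_token() -> tuple[str, str]:
    """Return (raw_token_for_client, hash_for_the_refresh_token_hash_column)."""
    raw = secrets.token_urlsafe(32)
    return raw, hashlib.sha256(raw.encode()).hexdigest()

def token_matches(raw: str, stored_hash: str) -> bool:
    """Compare a presented token against the stored hash (never the raw token)."""
    return hashlib.sha256(raw.encode()).hexdigest() == stored_hash
```

Storing only the hash means a leaked `user_sessions` table cannot be replayed as valid refresh tokens, while the `ix_user_sessions_refresh_token_hash` index still allows O(1) lookup of a presented token.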
@@ -0,0 +1,72 @@
"""add_users_and_api_keys_tables

Revision ID: b546dbb9df98
Revises: rename_style_preset
Create Date: 2026-02-14 13:01:36.394119

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'b546dbb9df98'
down_revision: Union[str, Sequence[str], None] = 'rename_style_preset'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema - Add users and user_api_keys tables."""
    # Create users table
    op.create_table(
        'users',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('username', sa.String(), nullable=False),
        sa.Column('email', sa.String(), nullable=True),
        sa.Column('password_hash', sa.String(), nullable=False),
        sa.Column('is_active', sa.Boolean(), nullable=False, server_default='1'),
        sa.Column('is_superuser', sa.Boolean(), nullable=False, server_default='0'),
        sa.Column('permissions', sa.JSON(), nullable=False, server_default='[]'),
        sa.Column('roles', sa.JSON(), nullable=False, server_default='[]'),
        sa.Column('created_at', sa.Float(), nullable=False),
        sa.Column('updated_at', sa.Float(), nullable=False),
        sa.Column('last_login', sa.Float(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('email'),
        sa.UniqueConstraint('username')
    )
    op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=False)
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)

    # Create user_api_keys table
    op.create_table(
        'user_api_keys',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('user_id', sa.String(), nullable=False),
        sa.Column('provider', sa.String(), nullable=False),
        sa.Column('encrypted_key', sa.String(), nullable=False),
        sa.Column('name', sa.String(), nullable=True),
        sa.Column('is_active', sa.Boolean(), nullable=False, server_default='1'),
        sa.Column('created_at', sa.Float(), nullable=False),
        sa.Column('updated_at', sa.Float(), nullable=False),
        sa.Column('last_used_at', sa.Float(), nullable=True),
        sa.Column('usage_count', sa.Integer(), nullable=False, server_default='0'),
        sa.Column('extra_config', sa.JSON(), nullable=True),
        sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_user_api_keys_provider'), 'user_api_keys', ['provider'], unique=False)
    op.create_index(op.f('ix_user_api_keys_user_id'), 'user_api_keys', ['user_id'], unique=False)


def downgrade() -> None:
    """Downgrade schema - Remove users and user_api_keys tables."""
    op.drop_index(op.f('ix_user_api_keys_user_id'), table_name='user_api_keys')
    op.drop_index(op.f('ix_user_api_keys_provider'), table_name='user_api_keys')
    op.drop_table('user_api_keys')
    op.drop_index(op.f('ix_users_username'), table_name='users')
    op.drop_index(op.f('ix_users_email'), table_name='users')
    op.drop_table('users')
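For context, here is a minimal sketch of the `users` schema this migration produces, exercised with stdlib `sqlite3`. This is illustrative only: the application goes through SQLModel, and the row values below are made up.

```python
import sqlite3
import time
import uuid

# In-memory table mirroring the columns created by op.create_table('users', ...)
conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE users (
        id TEXT PRIMARY KEY,
        username TEXT NOT NULL UNIQUE,
        email TEXT UNIQUE,
        password_hash TEXT NOT NULL,
        is_active BOOLEAN NOT NULL DEFAULT 1,
        is_superuser BOOLEAN NOT NULL DEFAULT 0,
        permissions TEXT NOT NULL DEFAULT '[]',
        roles TEXT NOT NULL DEFAULT '[]',
        created_at REAL NOT NULL,
        updated_at REAL NOT NULL,
        last_login REAL
    )
""")

# Timestamps are stored as float epoch seconds, matching sa.Float() above
now = time.time()
conn.execute(
    "INSERT INTO users (id, username, password_hash, created_at, updated_at) "
    "VALUES (?, ?, ?, ?, ?)",
    (str(uuid.uuid4()), "alice", "<password-hash>", now, now),
)

# Defaults kick in for the columns we omitted
row = conn.execute("SELECT username, permissions, is_superuser FROM users").fetchone()
print(row)  # ('alice', '[]', 0)
```

Note the `permissions`/`roles` JSON columns default to `'[]'`, so a freshly inserted user has no grants until the application assigns them.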
@@ -0,0 +1,30 @@
"""add location and time to storyboards

Revision ID: bfac9b8e32f5
Revises: 72f609dd9e66
Create Date: 2026-01-11 00:49:48.323949

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'bfac9b8e32f5'
down_revision: Union[str, Sequence[str], None] = '72f609dd9e66'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Upgrade schema."""
    op.add_column('storyboards', sa.Column('location', sa.String(), nullable=True))
    op.add_column('storyboards', sa.Column('time', sa.String(), nullable=True))


def downgrade() -> None:
    """Downgrade schema."""
    op.drop_column('storyboards', 'time')
    op.drop_column('storyboards', 'location')
34
backend/alembic/versions/rename_style_preset_to_style_id.py
Normal file
@@ -0,0 +1,34 @@
"""rename style_preset to style_id

Revision ID: rename_style_preset
Revises: add_cinematic_fields, add_provider_to_tasks
Create Date: 2024-02-11

"""
from alembic import op
import sqlalchemy as sa
from typing import Union, Sequence


# revision identifiers, used by Alembic.
revision = 'rename_style_preset'
down_revision: Union[str, Sequence[str], None] = ('add_cinematic_fields', 'add_provider_to_tasks')
branch_labels = None
depends_on = None


def upgrade():
    """Rename style_preset column to style_id in projects table"""
    # SQLite doesn't support ALTER COLUMN RENAME directly;
    # we need a workaround using batch-mode table recreation.
    with op.batch_alter_table('projects', schema=None) as batch_op:
        # Rename the column
        batch_op.alter_column('style_preset', new_column_name='style_id')


def downgrade():
    """Revert style_id column back to style_preset"""
    with op.batch_alter_table('projects', schema=None) as batch_op:
        # Rename back
        batch_op.alter_column('style_id', new_column_name='style_preset')
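The batch-mode workaround above can be sketched in plain SQL. This is a simplified illustration of the recreate-and-copy approach (Alembic's batch mode handles indexes, constraints, and foreign keys as well; the table contents here are invented):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE projects (id TEXT PRIMARY KEY, style_preset TEXT)")
conn.execute("INSERT INTO projects VALUES ('p1', 'noir')")

# Recreate the table under the new column name, copy rows, swap it in.
conn.executescript("""
    CREATE TABLE projects_new (id TEXT PRIMARY KEY, style_id TEXT);
    INSERT INTO projects_new (id, style_id) SELECT id, style_preset FROM projects;
    DROP TABLE projects;
    ALTER TABLE projects_new RENAME TO projects;
""")

renamed = conn.execute("SELECT style_id FROM projects").fetchone()
print(renamed)  # ('noir',)
```

Modern SQLite (3.25+) also supports `ALTER TABLE ... RENAME COLUMN`, but `batch_alter_table` remains the portable path across database backends.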
49
backend/pyproject.toml
Normal file
@@ -0,0 +1,49 @@
[project]
name = "backend"
version = "0.1.0"
description = "FastAPI backend for the Pixel AI comic/video creation platform"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "alibabacloud-tea-openapi>=0.4.2",
    "alibabacloud-tea-util>=0.3.14",
    "alibabacloud-videoenhan20200320>=4.0.0",
    "fastapi>=0.127.0",
    "oss2>=2.19.1",
    "pydantic>=2.12.5",
    "python-dotenv>=1.2.1",
    "python-multipart>=0.0.21",
    "requests>=2.32.5",
    "httpx>=0.27.0",
    "uvicorn>=0.40.0",
    "modelscope>=1.29.2",
    "volcengine>=1.0.100",
    "google-generativeai>=0.8.4",
    "sqlmodel>=0.0.31",
    "alembic>=1.17.2",
    "agentscope>=1.0.11",
    "redis>=5.0.0",
    "prometheus-client>=0.20.0",
    "opentelemetry-api>=1.25.0",
    "opentelemetry-sdk>=1.25.0",
    "opentelemetry-instrumentation-fastapi>=0.46b0",
    "opentelemetry-exporter-otlp-proto-grpc>=1.25.0",
    "fastmcp>=2.0.0",
    "tenacity>=8.2.3",
    "psycopg2-binary>=2.9.9",
    "asyncpg>=0.29.0",
    "sqladmin[full]>=0.16.0",
    "itsdangerous>=2.1.2",
    "psutil>=5.9.0",
    "Pillow>=11.0.0",
]

[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"
pythonpath = ["."]

[dependency-groups]
dev = [
    "hypothesis>=6.151.5",
]
84
backend/scripts/generate_error_codes_ts.py
Normal file
@@ -0,0 +1,84 @@
#!/usr/bin/env python3
"""
Generate TypeScript ErrorCode enum from Python ErrorCode enum

Usage:
    python scripts/generate_error_codes_ts.py > ../frontend/src/lib/errors.ts

This script synchronizes the frontend ErrorCode enum with the backend definition
to ensure consistency across the project.
"""

import sys
from pathlib import Path

# Add the backend root to the path so `src.*` imports resolve
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.utils.errors import ErrorCode


def generate_typescript():
    lines = [
        "import { logger } from './utils/logger';",
        "",
        "/**",
        " * Frontend Error Handling Module",
        " * Provides standardized error codes, user-friendly messages, and retry logic",
        " *",
        " * Error codes are synchronized with backend:",
        " * Format: 4-digit string",
        " * - 0000: Success",
        " * - 1xxx: General errors",
        " * - 2xxx: Business errors",
        " * - 3xxx: Task errors",
        " * - 4xxx: AI service errors",
        " * - 5xxx: Storage errors",
        " */",
        "",
        "/**",
        " * Error codes matching backend error responses",
        " * Auto-synchronized with backend/src/utils/errors.py",
        " */",
        "export enum ErrorCode {",
    ]

    # Group error codes by category for better organization
    categories = {
        "0000": "Success",
        "1": "General errors",
        "2": "Business errors",
        "3": "Task errors",
        "4": "AI service errors",
        "5": "Storage errors",
    }

    current_category = None

    # Add backend error codes
    for code in ErrorCode:
        category = code.value[0] if code.value != "0000" else "0000"

        if category != current_category:
            if current_category is not None:
                lines.append("")
            if category in categories:
                lines.append(f"  // {categories[category]} ({category}xxx)")
            current_category = category

        lines.append(f"  {code.name} = '{code.value}',")

    # Add frontend-specific error codes
    lines.extend([
        "",
        "  // Frontend-specific errors (not from backend)",
        "  NETWORK_ERROR = 'NET01',",
        "  TIMEOUT_ERROR = 'TIM01',",
        "}",
    ])

    return "\n".join(lines)


if __name__ == "__main__":
    print(generate_typescript())
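The category-grouping loop in this script can be exercised standalone with a stand-in enum. The codes below are hypothetical (the real ones live in `backend/src/utils/errors.py`); the grouping logic is the same:

```python
from enum import Enum


class MockErrorCode(Enum):
    """Hypothetical error codes, for illustration only."""
    SUCCESS = "0000"
    INVALID_PARAM = "1001"
    PROJECT_NOT_FOUND = "2001"
    TASK_TIMEOUT = "3002"


categories = {
    "0000": "Success",
    "1": "General errors",
    "2": "Business errors",
    "3": "Task errors",
}

lines = []
current = None
for code in MockErrorCode:
    # The first digit selects the category; "0000" is its own category
    category = code.value[0] if code.value != "0000" else "0000"
    if category != current:
        if current is not None:
            lines.append("")  # blank line between category groups
        lines.append(f"  // {categories[category]} ({category}xxx)")
        current = category
    lines.append(f"  {code.name} = '{code.value}',")

print("\n".join(lines))
```

Each category change emits a comment header and a separating blank line, producing a readable TypeScript enum body.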
1
backend/src/__init__.py
Normal file
@@ -0,0 +1 @@
"""Backend application package."""
33
backend/src/admin_config.py
Normal file
@@ -0,0 +1,33 @@
from sqladmin import Admin, ModelView
from src.models.entities import ProjectDB, AssetDB, EpisodeDB, StoryboardDB


class ProjectAdmin(ModelView, model=ProjectDB):
    column_list = [ProjectDB.id, ProjectDB.name, ProjectDB.type, ProjectDB.status, ProjectDB.created_at]
    column_searchable_list = [ProjectDB.name, ProjectDB.id]
    column_sortable_list = [ProjectDB.created_at]
    icon = "fa-solid fa-diagram-project"


class AssetAdmin(ModelView, model=AssetDB):
    column_list = [AssetDB.id, AssetDB.name, AssetDB.type, AssetDB.project_id]
    column_searchable_list = [AssetDB.name, AssetDB.id, AssetDB.type]
    # column_filters = [AssetDB.type, AssetDB.project_id]
    icon = "fa-solid fa-cube"


class EpisodeAdmin(ModelView, model=EpisodeDB):
    column_list = [EpisodeDB.id, EpisodeDB.title, EpisodeDB.order_index, EpisodeDB.status, EpisodeDB.project_id]
    column_searchable_list = [EpisodeDB.title, EpisodeDB.id]
    # column_filters = [EpisodeDB.status, EpisodeDB.project_id]
    icon = "fa-solid fa-film"


class StoryboardAdmin(ModelView, model=StoryboardDB):
    column_list = [StoryboardDB.id, StoryboardDB.project_id, StoryboardDB.episode_id, StoryboardDB.order_index]
    # column_filters = [StoryboardDB.project_id, StoryboardDB.episode_id]
    icon = "fa-solid fa-image"


def setup_admin(app, engine):
    admin = Admin(app, engine, title="Pixel Admin Console")
    admin.add_view(ProjectAdmin)
    admin.add_view(AssetAdmin)
    admin.add_view(EpisodeAdmin)
    admin.add_view(StoryboardAdmin)
32
backend/src/api/admin/__init__.py
Normal file
@@ -0,0 +1,32 @@
"""
Admin API Package

Admin API modules, containing:
- dashboard: dashboard statistics and system resource routes
- users: user management routes
- projects: project management routes
- tasks: task management routes
- settings: system settings routes
"""

from fastapi import APIRouter

from .dashboard import router as dashboard_router
from .users import router as users_router
from .projects import router as projects_router
from .tasks import router as tasks_router
from .settings import router as settings_router

# Create the main router
router = APIRouter(prefix="/admin", tags=["admin"])

# Include all sub-routers
router.include_router(dashboard_router)
router.include_router(users_router)
router.include_router(projects_router)
router.include_router(tasks_router)
router.include_router(settings_router)

__all__ = [
    "router",
]
85
backend/src/api/admin/dashboard.py
Normal file
@@ -0,0 +1,85 @@
"""
Admin API - Dashboard Routes

Contains API routes for dashboard statistics and system resources.
"""

import logging
from fastapi import APIRouter, Depends, Query

from src.auth.dependencies import require_admin
from src.auth.models import UserAuth
from src.models.schemas import ResponseModel
from src.services.admin_service import admin_service
from src.utils.errors import ErrorCode, AppException

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/dashboard", tags=["admin-dashboard"])


@router.get("/stats", response_model=ResponseModel)
async def get_dashboard_stats(
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Get dashboard statistics

    Returns counts of users, projects, and tasks by status.
    Requires admin privileges.
    """
    try:
        stats = await admin_service.get_dashboard_stats()
        return ResponseModel(data=stats.model_dump())
    except Exception as e:
        logger.error(f"Error getting dashboard stats: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.get("/system", response_model=ResponseModel)
async def get_system_resources(
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Get system resource information

    Returns CPU, memory, disk usage and uptime.
    Requires admin privileges.
    """
    try:
        resources = await admin_service.get_system_resources()
        return ResponseModel(data=resources.model_dump())
    except Exception as e:
        logger.error(f"Error getting system resources: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.get("/activity", response_model=ResponseModel)
async def get_recent_activity(
    limit: int = Query(20, ge=1, le=100),
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Get recent system activity

    Returns recent user, project, and task activities.
    Requires admin privileges.
    """
    try:
        activity = await admin_service.get_recent_activity(limit=limit)
        return ResponseModel(data=activity.model_dump())
    except Exception as e:
        logger.error(f"Error getting recent activity: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )
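Every route in this module follows the same pass-through-then-wrap error pattern: domain `AppException`s bubble up unchanged, while unexpected exceptions are wrapped into a 500. A simplified stand-in for `src/utils/errors.AppException` (the real class may carry more fields) makes the shape clear:

```python
class AppException(Exception):
    """Simplified sketch of the project's domain exception (assumption)."""

    def __init__(self, code, message, status_code=500):
        super().__init__(message)
        self.code = code
        self.message = message
        self.status_code = status_code


def route_body():
    try:
        raise ValueError("db unreachable")  # simulated service failure
    except AppException:
        raise  # domain errors pass through untouched
    except Exception as e:
        # anything unexpected becomes an internal error with its message preserved
        raise AppException(code="1000", message=str(e), status_code=500)


try:
    route_body()
except AppException as exc:
    result = (exc.status_code, exc.message)
print(result)  # (500, 'db unreachable')
```

A FastAPI exception handler for `AppException` can then turn `code`/`message`/`status_code` into the standard `ResponseModel` error envelope.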
147
backend/src/api/admin/projects.py
Normal file
@@ -0,0 +1,147 @@
"""
Admin API - Projects Routes

Contains API routes for project management.
"""

import json
import logging
from typing import Optional

from fastapi import APIRouter, Depends, Query, Request

from src.auth.dependencies import require_admin
from src.auth.models import UserAuth
from src.models.schemas import ResponseModel
from src.services.admin_service import admin_service
from src.utils.pagination import Paginator
from src.utils.errors import ErrorCode, AppException

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/projects", tags=["admin-projects"])


@router.get("", response_model=ResponseModel)
async def list_projects(
    request: Request,
    page: int = Query(1, ge=1, description="Page number, starting from 1"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page"),
    sort: Optional[str] = Query(None, description="Sort field, format: field:asc or field:desc"),
    filter: Optional[str] = Query(None, description="Filter conditions, JSON format"),
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    List all projects with pagination

    Supports filtering by status, type, and search by name.
    Requires admin privileges.
    """
    try:
        # Parse filters
        filters = {}
        if filter:
            try:
                filters = json.loads(filter)
            except json.JSONDecodeError:
                pass

        # Parse sort
        sort_by = "created_at"
        sort_order = "desc"
        if sort:
            parts = sort.split(":")
            if len(parts) == 2:
                sort_by = parts[0]
                sort_order = parts[1]

        items, total = await admin_service.list_projects(
            page=page,
            page_size=page_size,
            filters=filters if filters else None,
            sort_by=sort_by,
            sort_order=sort_order,
        )

        paginator = Paginator(
            items=[item.model_dump(by_alias=True) for item in items],
            total=total,
            page=page,
            page_size=page_size,
        )

        return paginator.to_response(request)
    except Exception as e:
        logger.error(f"Error listing projects: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.get("/{project_id}", response_model=ResponseModel)
async def get_project(
    project_id: str,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Get project details by ID

    Requires admin privileges.
    """
    try:
        project = await admin_service.get_project(project_id)
        if not project:
            raise AppException(
                code=ErrorCode.NOT_FOUND,
                message="Project not found",
                status_code=404
            )
        return ResponseModel(data=project.model_dump())
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error getting project: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.delete("/{project_id}", response_model=ResponseModel)
async def delete_project(
    project_id: str,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Delete a project

    Requires admin privileges.
    """
    try:
        success = await admin_service.delete_project(project_id)
        if not success:
            raise AppException(
                code=ErrorCode.NOT_FOUND,
                message="Project not found",
                status_code=404
            )

        return ResponseModel(
            data={
                "id": project_id,
                "deleted": True,
                "message": "Project deleted successfully",
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error deleting project: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )
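The `sort`/`filter` query-parameter parsing in `list_projects` (and the identical code in `list_tasks`) can be factored into a small helper. A sketch of that shared logic, with the hypothetical name `parse_listing_params`:

```python
import json


def parse_listing_params(sort=None, filter=None):
    """Mirror of the inline parsing in the admin list endpoints (sketch)."""
    # Filters arrive as a JSON object string; malformed JSON silently
    # falls back to "no filters", as the routes do.
    filters = {}
    if filter:
        try:
            filters = json.loads(filter)
        except json.JSONDecodeError:
            pass

    # Sort arrives as "field:asc" / "field:desc"; anything else keeps the default.
    sort_by, sort_order = "created_at", "desc"
    if sort:
        parts = sort.split(":")
        if len(parts) == 2:
            sort_by, sort_order = parts

    return filters, sort_by, sort_order


print(parse_listing_params(sort="name:asc", filter='{"status": "active"}'))
# ({'status': 'active'}, 'name', 'asc')
```

Note the deliberately lenient behavior: bad input never raises, it just degrades to the defaults.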
81
backend/src/api/admin/settings.py
Normal file
@@ -0,0 +1,81 @@
"""
Admin API - Settings Routes

Contains API routes for system settings.
"""

import logging

from fastapi import APIRouter, Depends

from src.auth.dependencies import require_admin
from src.auth.models import UserAuth
from src.models.schemas import ResponseModel
from src.models.admin_schemas import SystemSettingUpdateRequest
from src.services.admin_service import admin_service
from src.utils.errors import ErrorCode, AppException

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/settings", tags=["admin-settings"])


@router.get("", response_model=ResponseModel)
async def get_system_settings(
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Get system settings

    Returns all configurable system settings.
    Requires admin privileges.
    """
    try:
        settings = await admin_service.get_system_settings()
        return ResponseModel(data=settings.model_dump())
    except Exception as e:
        logger.error(f"Error getting system settings: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.put("/{key}", response_model=ResponseModel)
async def update_system_setting(
    key: str,
    request: SystemSettingUpdateRequest,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Update a system setting

    Requires admin privileges.
    """
    try:
        setting = await admin_service.update_system_setting(key, request.value)
        if not setting:
            raise AppException(
                code=ErrorCode.NOT_FOUND,
                message="Setting not found",
                status_code=404
            )

        return ResponseModel(
            data={
                "key": key,
                "value": request.value,
                "updated_at": setting.updated_at,
                "message": "Setting updated successfully",
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error updating system setting: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )
443
backend/src/api/admin/tasks.py
Normal file
@@ -0,0 +1,443 @@
"""
Admin API - Tasks Routes

Contains API routes for task management.
"""

import json
import logging
from typing import Optional, List

from fastapi import APIRouter, Depends, Query, Request

from src.auth.dependencies import require_admin
from src.auth.models import UserAuth
from src.models.schemas import ResponseModel
from src.services.admin_service import admin_service
from src.utils.pagination import Paginator
from src.utils.errors import ErrorCode, AppException

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/tasks", tags=["admin-tasks"])


@router.get("", response_model=ResponseModel)
async def list_tasks(
    request: Request,
    page: int = Query(1, ge=1, description="Page number, starting from 1"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page"),
    sort: Optional[str] = Query(None, description="Sort field, format: field:asc or field:desc"),
    filter: Optional[str] = Query(None, description="Filter conditions, JSON format"),
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    List all tasks with pagination

    Supports filtering by status, type, provider, user_id, and project_id.
    Requires admin privileges.
    """
    try:
        # Parse filters
        filters = {}
        if filter:
            try:
                filters = json.loads(filter)
            except json.JSONDecodeError:
                pass

        # Parse sort
        sort_by = "created_at"
        sort_order = "desc"
        if sort:
            parts = sort.split(":")
            if len(parts) == 2:
                sort_by = parts[0]
                sort_order = parts[1]

        items, total = await admin_service.list_tasks(
            page=page,
            page_size=page_size,
            filters=filters if filters else None,
            sort_by=sort_by,
            sort_order=sort_order,
        )

        paginator = Paginator(
            items=[item.model_dump() for item in items],
            total=total,
            page=page,
            page_size=page_size,
        )

        return paginator.to_response(request)
    except Exception as e:
        logger.error(f"Error listing tasks: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.get("/stats", response_model=ResponseModel)
async def get_task_stats(
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Get task statistics

    Returns counts by status, type, provider, and success rate.
    Requires admin privileges.
    """
    try:
        stats = await admin_service.get_task_stats()
        return ResponseModel(data=stats.model_dump())
    except Exception as e:
        logger.error(f"Error getting task stats: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.get("/queue", response_model=ResponseModel)
async def get_task_queue_status(
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Get task queue status

    Returns queue length, processing count, and worker status.
    Requires admin privileges.
    """
    try:
        status = await admin_service.get_task_queue_status()
        return ResponseModel(data=status.model_dump())
    except Exception as e:
        logger.error(f"Error getting task queue status: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.get("/{task_id}", response_model=ResponseModel)
async def get_task_detail(
    task_id: str,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Get task detail by ID

    Returns detailed task information including user and project details.
    Requires admin privileges.
    """
    try:
        task = await admin_service.get_task_detail(task_id)
        if not task:
            raise AppException(
                code=ErrorCode.NOT_FOUND,
                message="Task not found",
                status_code=404
            )
        return ResponseModel(data=task.model_dump())
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error getting task detail: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.post("/{task_id}/retry", response_model=ResponseModel)
async def retry_task(
    task_id: str,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Retry a failed task

    Requires admin privileges.
    """
    try:
        task = await admin_service.retry_task(task_id)
        if not task:
            raise AppException(
                code=ErrorCode.NOT_FOUND,
                message="Task not found or cannot be retried",
                status_code=404
            )

        return ResponseModel(
            data={
                "id": task_id,
                "status": task.status,
                "message": "Task queued for retry",
                "retry_count": task.retry_count,
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error retrying task: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.delete("/{task_id}", response_model=ResponseModel)
async def delete_task(
    task_id: str,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Delete a task

    Requires admin privileges.
    """
    try:
        success = await admin_service.delete_task(task_id)
        if not success:
            raise AppException(
                code=ErrorCode.NOT_FOUND,
                message="Task not found",
                status_code=404
            )

        return ResponseModel(
            data={
                "id": task_id,
                "deleted": True,
                "message": "Task deleted successfully",
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error deleting task: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.post("/batch-retry", response_model=ResponseModel)
async def batch_retry_tasks(
    task_ids: List[str],
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Batch retry failed tasks

    Requires admin privileges.
    """
    try:
        from sqlmodel import Session
        from src.config.database import engine
        from src.models.entities import TaskDB
        from datetime import datetime

        retried = []
        failed = []

        with Session(engine) as session:
            for task_id in task_ids:
                task = session.get(TaskDB, task_id)
                if not task:
                    failed.append({"id": task_id, "reason": "Task not found"})
                    continue

                if task.status not in ["failed", "timeout"]:
                    failed.append({"id": task_id, "reason": f"Cannot retry task with status: {task.status}"})
                    continue

                task.status = "pending"
                task.retry_count = 0
                task.error = None
                task.updated_at = datetime.now().timestamp()
                session.add(task)
                retried.append(task_id)

            session.commit()

        return ResponseModel(
            data={
                "retried": retried,
                "failed": failed,
                "total": len(task_ids),
                "success_count": len(retried),
                "message": f"Successfully queued {len(retried)} tasks for retry",
            }
        )
    except Exception as e:
        logger.error(f"Error batch retrying tasks: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )
|
|
||||||
|
|
||||||
|
@router.post("/batch-cancel", response_model=ResponseModel)
|
||||||
|
async def batch_cancel_tasks(
|
||||||
|
task_ids: List[str],
|
||||||
|
current_user: UserAuth = Depends(require_admin),
|
||||||
|
) -> ResponseModel:
|
||||||
|
"""
|
||||||
|
Batch cancel pending/processing tasks
|
||||||
|
|
||||||
|
Requires admin privileges.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
from src.config.database import engine
|
||||||
|
from src.models.entities import TaskDB
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
cancelled = []
|
||||||
|
failed = []
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
for task_id in task_ids:
|
||||||
|
task = session.get(TaskDB, task_id)
|
||||||
|
if not task:
|
||||||
|
failed.append({"id": task_id, "reason": "Task not found"})
|
||||||
|
continue
|
||||||
|
|
||||||
|
if task.status not in ["pending", "processing"]:
|
||||||
|
failed.append({"id": task_id, "reason": f"Cannot cancel task with status: {task.status}"})
|
||||||
|
continue
|
||||||
|
|
||||||
|
task.status = "cancelled"
|
||||||
|
task.updated_at = datetime.now().timestamp()
|
||||||
|
session.add(task)
|
||||||
|
cancelled.append(task_id)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return ResponseModel(
|
||||||
|
data={
|
||||||
|
"cancelled": cancelled,
|
||||||
|
"failed": failed,
|
||||||
|
"total": len(task_ids),
|
||||||
|
"success_count": len(cancelled),
|
||||||
|
"message": f"Successfully cancelled {len(cancelled)} tasks",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error batch cancelling tasks: {e}")
|
||||||
|
raise AppException(
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
message=str(e),
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/batch-delete", response_model=ResponseModel)
|
||||||
|
async def batch_delete_tasks(
|
||||||
|
task_ids: List[str],
|
||||||
|
current_user: UserAuth = Depends(require_admin),
|
||||||
|
) -> ResponseModel:
|
||||||
|
"""
|
||||||
|
Batch delete tasks
|
||||||
|
|
||||||
|
Requires admin privileges.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
from src.config.database import engine
|
||||||
|
from src.models.entities import TaskDB
|
||||||
|
|
||||||
|
deleted = []
|
||||||
|
failed = []
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
for task_id in task_ids:
|
||||||
|
task = session.get(TaskDB, task_id)
|
||||||
|
if not task:
|
||||||
|
failed.append({"id": task_id, "reason": "Task not found"})
|
||||||
|
continue
|
||||||
|
|
||||||
|
session.delete(task)
|
||||||
|
deleted.append(task_id)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return ResponseModel(
|
||||||
|
data={
|
||||||
|
"deleted": deleted,
|
||||||
|
"failed": failed,
|
||||||
|
"total": len(task_ids),
|
||||||
|
"success_count": len(deleted),
|
||||||
|
"message": f"Successfully deleted {len(deleted)} tasks",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error batch deleting tasks: {e}")
|
||||||
|
raise AppException(
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
message=str(e),
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/cleanup-completed", response_model=ResponseModel)
|
||||||
|
async def cleanup_completed_tasks(
|
||||||
|
days: int = Query(30, ge=1, le=365, description="清理多少天前的已完成任务"),
|
||||||
|
current_user: UserAuth = Depends(require_admin),
|
||||||
|
) -> ResponseModel:
|
||||||
|
"""
|
||||||
|
Cleanup completed tasks older than specified days
|
||||||
|
|
||||||
|
Requires admin privileges.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from sqlmodel import Session, select
|
||||||
|
from src.config.database import engine
|
||||||
|
from src.models.entities import TaskDB
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
cutoff_timestamp = time.time() - (days * 24 * 60 * 60)
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
# Find completed tasks older than cutoff
|
||||||
|
old_tasks = session.exec(
|
||||||
|
select(TaskDB).where(
|
||||||
|
TaskDB.status == "success",
|
||||||
|
TaskDB.completed_at < cutoff_timestamp
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
for task in old_tasks:
|
||||||
|
session.delete(task)
|
||||||
|
deleted_count += 1
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return ResponseModel(
|
||||||
|
data={
|
||||||
|
"deleted_count": deleted_count,
|
||||||
|
"days": days,
|
||||||
|
"cutoff_date": datetime.fromtimestamp(cutoff_timestamp).isoformat(),
|
||||||
|
"message": f"Successfully deleted {deleted_count} completed tasks older than {days} days",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error cleaning up completed tasks: {e}")
|
||||||
|
raise AppException(
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
message=str(e),
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
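The batch-retry and batch-cancel endpoints above share one pattern: partition the requested IDs into an eligible list and a rejected list based on the task's current status. A minimal standalone sketch of that rule, with illustrative names not taken from the repository:

```python
# Statuses a task may be retried from (see batch_retry_tasks) and
# cancelled from (see batch_cancel_tasks).
RETRYABLE = {"failed", "timeout"}
CANCELLABLE = {"pending", "processing"}


def partition_by_eligibility(tasks: dict, allowed: set):
    """Split {task_id: status} into (eligible_ids, rejected) the way the
    batch endpoints build their retried/cancelled and failed lists."""
    eligible, rejected = [], []
    for task_id, status in tasks.items():
        if status in allowed:
            eligible.append(task_id)
        else:
            rejected.append({"id": task_id, "reason": f"Cannot process task with status: {status}"})
    return eligible, rejected
```

The endpoints inline this logic per request; factoring it out like this would let both routes share one tested helper.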
317
backend/src/api/admin/users.py
Normal file
@@ -0,0 +1,317 @@
"""
Admin API - Users Routes

API routes for user management.
"""

import logging
from typing import Optional

from fastapi import APIRouter, Depends, Query, Request

from src.auth.dependencies import require_admin
from src.auth.models import UserAuth
from src.models.schemas import ResponseModel, PaginationParams
from src.models.admin_schemas import (
    AdminUserCreateRequest,
    AdminUserUpdateRequest,
)
from src.services.admin_service import admin_service
from src.services.user_service import user_service
from src.utils.pagination import Paginator
from src.utils.errors import ErrorCode, AppException

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/users", tags=["admin-users"])


@router.get("", response_model=ResponseModel)
async def list_users(
    request: Request,
    page: int = Query(1, ge=1, description="Page number, starting from 1"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page"),
    sort: Optional[str] = Query(None, description="Sort field, format: field:asc or field:desc"),
    filter: Optional[str] = Query(None, description="Filter conditions, JSON format"),
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    List all users with pagination

    Supports filtering by is_active, is_superuser, and search by username/email.
    Requires admin privileges.
    """
    try:
        # Parse filters
        filters = {}
        if filter:
            import json
            try:
                filters = json.loads(filter)
            except json.JSONDecodeError:
                pass

        # Parse sort
        sort_by = "created_at"
        sort_order = "desc"
        if sort:
            parts = sort.split(":")
            if len(parts) == 2:
                sort_by = parts[0]
                sort_order = parts[1]

        items, total = await admin_service.list_users(
            page=page,
            page_size=page_size,
            filters=filters if filters else None,
            sort_by=sort_by,
            sort_order=sort_order,
        )

        paginator = Paginator(
            items=[item.model_dump() for item in items],
            total=total,
            page=page,
            page_size=page_size,
        )

        return paginator.to_response(request)
    except Exception as e:
        logger.error(f"Error listing users: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.get("/{user_id}", response_model=ResponseModel)
async def get_user(
    user_id: str,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Get user details by ID

    Requires admin privileges.
    """
    try:
        user = await admin_service.get_user(user_id)
        if not user:
            raise AppException(
                code=ErrorCode.NOT_FOUND,
                message="User not found",
                status_code=404
            )
        return ResponseModel(data=user.model_dump())
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error getting user: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.post("", response_model=ResponseModel)
async def create_user(
    request: AdminUserCreateRequest,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Create a new user

    Requires admin privileges.
    """
    try:
        # Check if username already exists
        existing_user = await user_service.get_user_by_username(request.username)
        if existing_user:
            raise AppException(
                code=ErrorCode.CONFLICT,
                message="Username already exists",
                status_code=400
            )

        # Check if email already exists
        existing_email = await user_service.get_user_by_email(request.email)
        if existing_email:
            raise AppException(
                code=ErrorCode.CONFLICT,
                message="Email already registered",
                status_code=400
            )

        # Create user using user_service
        user = await user_service.create_user(
            username=request.username,
            email=request.email,
            password=request.password
        )

        # Update additional fields if provided
        update_data = {}
        if request.is_active is not None:
            update_data["is_active"] = request.is_active
        if request.is_superuser is not None:
            update_data["is_superuser"] = request.is_superuser
        if request.roles:
            update_data["roles"] = request.roles
        if request.permissions:
            update_data["permissions"] = request.permissions

        if update_data:
            user = await admin_service.update_user(user.id, update_data)

        logger.info(f"Admin {current_user.username} created user: {user.username} ({user.id})")

        return ResponseModel(
            data=user.model_dump(),
            message="User created successfully"
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error creating user: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.put("/{user_id}", response_model=ResponseModel)
async def update_user(
    user_id: str,
    request: AdminUserUpdateRequest,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Update user information

    Requires admin privileges.
    """
    try:
        # Prevent self-demotion from superuser
        if user_id == current_user.id and request.is_superuser is False:
            raise AppException(
                code=ErrorCode.FORBIDDEN,
                message="Cannot remove your own superuser status",
                status_code=400
            )

        update_data = request.model_dump(exclude_unset=True)
        user = await admin_service.update_user(user_id, update_data)

        if not user:
            raise AppException(
                code=ErrorCode.NOT_FOUND,
                message="User not found",
                status_code=404
            )

        return ResponseModel(data=user.model_dump())
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error updating user: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.post("/{user_id}/toggle-active", response_model=ResponseModel)
async def toggle_user_active(
    user_id: str,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Toggle user active status

    Requires admin privileges.
    """
    try:
        # Prevent self-deactivation
        if user_id == current_user.id:
            raise AppException(
                code=ErrorCode.FORBIDDEN,
                message="Cannot deactivate your own account",
                status_code=400
            )

        user = await admin_service.get_user(user_id)
        if not user:
            raise AppException(
                code=ErrorCode.NOT_FOUND,
                message="User not found",
                status_code=404
            )

        new_status = not user.is_active
        updated_user = await admin_service.toggle_user_active(user_id, new_status)

        return ResponseModel(
            data={
                "id": user_id,
                "is_active": new_status,
                "message": f"User {'activated' if new_status else 'deactivated'} successfully",
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error toggling user active status: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )


@router.delete("/{user_id}", response_model=ResponseModel)
async def delete_user(
    user_id: str,
    current_user: UserAuth = Depends(require_admin),
) -> ResponseModel:
    """
    Delete a user

    Requires admin privileges.
    """
    try:
        # Prevent self-deletion
        if user_id == current_user.id:
            raise AppException(
                code=ErrorCode.FORBIDDEN,
                message="Cannot delete your own account",
                status_code=400
            )

        success = await admin_service.delete_user(user_id)
        if not success:
            raise AppException(
                code=ErrorCode.NOT_FOUND,
                message="User not found",
                status_code=404
            )

        return ResponseModel(
            data={
                "id": user_id,
                "deleted": True,
                "message": "User deleted successfully",
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error deleting user: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message=str(e),
            status_code=500
        )
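The `sort` query parameter in list_users above is parsed from a "field:asc" / "field:desc" string, falling back to defaults when the value is absent or malformed. A minimal sketch of that parsing as a standalone helper (the function name is illustrative, not from the repository):

```python
from typing import Optional, Tuple


def parse_sort(sort: Optional[str],
               default_by: str = "created_at",
               default_order: str = "desc") -> Tuple[str, str]:
    """Parse "field:asc" or "field:desc"; fall back to the defaults
    on None, an empty string, or any other malformed value."""
    if sort:
        parts = sort.split(":")
        if len(parts) == 2:
            return parts[0], parts[1]
    return default_by, default_order
```

Note that, like the inline code in list_users, this sketch does not validate that the order part is literally "asc" or "desc"; a stricter version could reject other values.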
206
backend/src/api/audit_logs.py
Normal file
@@ -0,0 +1,206 @@
"""
Audit Log API

API endpoints for operation audit logs.
"""

import logging
from typing import Optional, Dict, Any, List
from datetime import datetime
from fastapi import APIRouter, Depends, Query, Request
from pydantic import BaseModel, Field

from src.auth.dependencies import require_admin
from src.auth.models import UserAuth
from src.models.schemas import ResponseModel
from src.utils.pagination import Paginator
from src.services.audit_log_service import audit_log_service
from src.models.audit_log import AuditLogDB
from src.utils.errors import ErrorCode, AppException

router = APIRouter(prefix="/admin/audit-logs", tags=["admin-audit-logs"])
logger = logging.getLogger(__name__)


# ===== Response models =====

class AuditLogListItem(BaseModel):
    """Audit log list item"""
    id: str
    user_id: Optional[str]
    username: Optional[str]
    action: str
    resource_type: Optional[str]
    resource_id: Optional[str]
    ip_address: Optional[str]
    created_at: str  # ISO format


class AuditLogDetailResponse(BaseModel):
    """Audit log detail"""
    id: str
    user_id: Optional[str]
    username: Optional[str]
    action: str
    resource_type: Optional[str]
    resource_id: Optional[str]
    ip_address: Optional[str]
    user_agent: Optional[str]
    details: Optional[Dict[str, Any]]
    created_at: str


# ===== API endpoints =====

@router.get("", response_model=ResponseModel)
async def list_audit_logs(
    request: Request,
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page"),
    user_id: Optional[str] = Query(None, description="Filter by user ID"),
    action: Optional[str] = Query(None, description="Filter by action type"),
    resource_type: Optional[str] = Query(None, description="Filter by resource type"),
    resource_id: Optional[str] = Query(None, description="Filter by resource ID"),
    start_date: Optional[str] = Query(None, description="Start date (ISO format)"),
    end_date: Optional[str] = Query(None, description="End date (ISO format)"),
    sort_by: str = Query("created_at", description="Sort field"),
    sort_order: str = Query("desc", description="Sort direction"),
    current_user: UserAuth = Depends(require_admin),
):
    """
    List audit logs

    Supports advanced search: filter by user, action type, resource type, and time range.
    Requires admin privileges.
    """
    try:
        # Build filter conditions
        filters = {}
        if user_id:
            filters["user_id"] = user_id
        if action:
            filters["action"] = action
        if resource_type:
            filters["resource_type"] = resource_type
        if resource_id:
            filters["resource_id"] = resource_id
        if start_date:
            filters["start_date"] = start_date
        if end_date:
            filters["end_date"] = end_date

        logs, total = audit_log_service.list_logs(
            page=page,
            page_size=page_size,
            filters=filters,
            sort_by=sort_by,
            sort_order=sort_order,
        )

        # Convert to response format
        items = [
            AuditLogListItem(
                id=log.id,
                user_id=log.user_id,
                username=log.username,
                action=log.action,
                resource_type=log.resource_type,
                resource_id=log.resource_id,
                ip_address=log.ip_address,
                created_at=datetime.fromtimestamp(log.created_at).isoformat(),
            ).model_dump()
            for log in logs
        ]

        paginator = Paginator(
            items=items,
            total=total,
            page=page,
            page_size=page_size,
        )

        return paginator.to_response(request)
    except Exception as e:
        logger.error(f"Error listing audit logs: {e}")
        raise


@router.get("/{log_id}", response_model=ResponseModel)
async def get_audit_log(
    log_id: str,
    current_user: UserAuth = Depends(require_admin),
):
    """
    Get audit log details

    Requires admin privileges.
    """
    try:
        log = audit_log_service.get_log(log_id)
        if not log:
            raise AppException(
                message="Audit log not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        return ResponseModel(
            data=AuditLogDetailResponse(
                id=log.id,
                user_id=log.user_id,
                username=log.username,
                action=log.action,
                resource_type=log.resource_type,
                resource_id=log.resource_id,
                ip_address=log.ip_address,
                user_agent=log.user_agent,
                details=log.details,
                created_at=datetime.fromtimestamp(log.created_at).isoformat(),
            ).model_dump()
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error getting audit log: {e}")
        raise


@router.get("/export/csv", response_model=ResponseModel)
async def export_audit_logs(
    user_id: Optional[str] = Query(None, description="Filter by user ID"),
    action: Optional[str] = Query(None, description="Filter by action type"),
    resource_type: Optional[str] = Query(None, description="Filter by resource type"),
    start_date: Optional[str] = Query(None, description="Start date (ISO format)"),
    end_date: Optional[str] = Query(None, description="End date (ISO format)"),
    current_user: UserAuth = Depends(require_admin),
):
    """
    Export audit logs as CSV

    Requires admin privileges.
    """
    try:
        # Build filter conditions
        filters = {}
        if user_id:
            filters["user_id"] = user_id
        if action:
            filters["action"] = action
        if resource_type:
            filters["resource_type"] = resource_type
        if start_date:
            filters["start_date"] = start_date
        if end_date:
            filters["end_date"] = end_date

        csv_content = audit_log_service.export_logs(filters=filters)

        return ResponseModel(
            data={
                "content": csv_content,
                "filename": f"audit_logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
            }
        )
    except Exception as e:
        logger.error(f"Error exporting audit logs: {e}")
        raise
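Both audit-log endpoints above assemble a filters dict from optional query parameters with a run of `if x: filters["x"] = x` statements. The same truthiness-based assembly can be sketched as one small helper; this dict-comprehension form is an alternative for illustration, not the repository's code:

```python
def build_filters(**candidates) -> dict:
    """Keep only the query parameters that were actually supplied.

    Mirrors the truthiness checks in the endpoints: None and empty
    strings are both dropped, exactly as `if user_id:` would drop them.
    """
    return {key: value for key, value in candidates.items() if value}
```

With a helper like this, list_audit_logs and export_audit_logs could share one filter-building step instead of duplicating six conditionals.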
927
backend/src/api/auth.py
Normal file
@@ -0,0 +1,927 @@
|
|||||||
|
"""
|
||||||
|
认证 API 路由
|
||||||
|
|
||||||
|
提供用户登录、注册、获取当前用户信息等接口。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import io
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, status, Depends, UploadFile, File, Request
|
||||||
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from src.auth.jwt import (
|
||||||
|
create_token_pair,
|
||||||
|
verify_password,
|
||||||
|
get_password_hash,
|
||||||
|
verify_token,
|
||||||
|
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||||
|
)
|
||||||
|
from src.services.token_blacklist_service import token_blacklist_service
|
||||||
|
from src.services.session_service import session_service
|
||||||
|
from src.services.email_service import email_service
|
||||||
|
from src.config.settings import REDIS_ENABLED, NODE_ENV
|
||||||
|
from src.auth.dependencies import get_current_user, UserAuth
|
||||||
|
from src.auth.models import RefreshTokenRequest
|
||||||
|
from src.services.user_service import user_service
|
||||||
|
from src.services.storage_service import storage_manager
|
||||||
|
from src.models.schemas import ResponseModel
|
||||||
|
from src.utils.errors import ErrorCode, AppException
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter(prefix="/auth", tags=["认证"])
|
||||||
|
|
||||||
|
|
||||||
|
def _build_user_payload(user: UserAuth) -> dict:
|
||||||
|
return {
|
||||||
|
"id": user.id,
|
||||||
|
"username": user.username,
|
||||||
|
"email": user.email,
|
||||||
|
"is_active": user.is_active,
|
||||||
|
"is_superuser": user.is_superuser,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_bearer_token(request: Request) -> Optional[str]:
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if auth_header.startswith("Bearer "):
|
||||||
|
return auth_header[7:]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_access_payload(request: Request):
|
||||||
|
access_token = _extract_bearer_token(request)
|
||||||
|
if not access_token:
|
||||||
|
return None
|
||||||
|
return await verify_token(access_token, token_type="access")
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_session(session_db, current_session_id: Optional[str] = None) -> dict:
|
||||||
|
return {
|
||||||
|
"id": session_db.id,
|
||||||
|
"session_family_id": session_db.session_family_id,
|
||||||
|
"status": session_db.status,
|
||||||
|
"device_name": session_db.device_name,
|
||||||
|
"ip_address": session_db.ip_address,
|
||||||
|
"user_agent": session_db.user_agent,
|
||||||
|
"created_at": session_db.created_at,
|
||||||
|
"updated_at": session_db.updated_at,
|
||||||
|
"expires_at": session_db.expires_at,
|
||||||
|
"last_used_at": session_db.last_used_at,
|
||||||
|
"revoked_at": session_db.revoked_at,
|
||||||
|
"revoked_reason": session_db.revoked_reason,
|
||||||
|
"current": session_db.id == current_session_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
"""登录请求"""
|
||||||
|
|
||||||
|
username: str = Field(..., description="用户名或邮箱")
|
||||||
|
password: str = Field(..., description="密码")
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterRequest(BaseModel):
|
||||||
|
"""注册请求"""
|
||||||
|
|
||||||
|
username: str = Field(..., min_length=3, max_length=50, description="用户名")
|
||||||
|
email: EmailStr = Field(..., description="邮箱")
|
||||||
|
password: str = Field(..., min_length=6, description="密码")
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
"""Token 响应"""
|
||||||
|
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES * 60 # 从配置读取
|
||||||
|
|
||||||
|
|
||||||
|
class UserInfoResponse(BaseModel):
|
||||||
|
"""用户信息响应"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
username: str
|
||||||
|
email: Optional[str]
|
||||||
|
is_active: bool
|
||||||
|
is_superuser: bool
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=ResponseModel)
|
||||||
|
async def login(request: LoginRequest, http_request: Request):
|
||||||
|
"""
|
||||||
|
用户登录
|
||||||
|
|
||||||
|
支持使用用户名或邮箱登录。
|
||||||
|
成功返回 access_token 和 refresh_token。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 使用 authenticate_user 方法验证用户(包含密码验证)
|
||||||
|
user = await user_service.authenticate_user(request.username, request.password)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
raise AppException(
|
||||||
|
code=ErrorCode.UNAUTHORIZED,
|
||||||
|
message="Invalid credentials",
|
||||||
|
status_code=401,
|
||||||
|
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查用户是否激活
|
||||||
|
if not user.is_active:
|
||||||
|
raise AppException(
|
||||||
|
code=ErrorCode.FORBIDDEN,
|
||||||
|
message="User account is inactive",
|
||||||
|
status_code=403
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新最后登录时间
|
||||||
|
await user_service.update_last_login(user.id)
|
||||||
|
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
session_family_id = str(uuid.uuid4())
|
||||||
|
tokens = create_token_pair(
|
||||||
|
user_id=user.id,
|
||||||
|
scopes=["user"],
|
||||||
|
session_id=session_id,
|
||||||
|
session_family_id=session_family_id,
|
||||||
|
)
|
||||||
|
session_service.create_session(
|
||||||
|
user_id=user.id,
|
||||||
|
refresh_token=tokens.refresh_token,
|
||||||
|
session_id=session_id,
|
||||||
|
session_family_id=session_family_id,
|
||||||
|
ip_address=http_request.client.host if http_request.client else None,
|
||||||
|
user_agent=http_request.headers.get("user-agent"),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"User logged in: {user.username} ({user.id})")
|
||||||
|
|
||||||
|
return ResponseModel(
|
||||||
|
data={
|
||||||
|
"access_token": tokens.access_token,
|
||||||
|
"refresh_token": tokens.refresh_token,
|
||||||
|
"token_type": "bearer",
|
||||||
|
"expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||||
|
"session_id": session_id,
|
||||||
|
"user": _build_user_payload(user),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
except AppException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Login error: {e}")
|
||||||
|
raise AppException(
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
message="Internal server error",
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=ResponseModel)
async def register(request: RegisterRequest, http_request: Request):
    """Register a new user.

    Creates a new account; the username and email must be unique.
    """
    try:
        # Check whether the username is already taken
        existing_user = await user_service.get_user_by_username(request.username)
        if existing_user:
            raise AppException(
                code=ErrorCode.CONFLICT,
                message="Username already registered",
                status_code=400
            )

        # Check whether the email is already registered
        existing_email = await user_service.get_user_by_email(request.email)
        if existing_email:
            raise AppException(
                code=ErrorCode.CONFLICT,
                message="Email already registered",
                status_code=400
            )

        # Create the user
        user = await user_service.create_user(
            username=request.username, email=request.email, password=request.password
        )

        session_id = str(uuid.uuid4())
        session_family_id = str(uuid.uuid4())
        tokens = create_token_pair(
            user_id=user.id,
            scopes=["user"],
            session_id=session_id,
            session_family_id=session_family_id,
        )
        session_service.create_session(
            user_id=user.id,
            refresh_token=tokens.refresh_token,
            session_id=session_id,
            session_family_id=session_family_id,
            ip_address=http_request.client.host if http_request.client else None,
            user_agent=http_request.headers.get("user-agent"),
        )

        logger.info(f"User registered: {user.username} ({user.id})")

        return ResponseModel(
            data={
                "access_token": tokens.access_token,
                "refresh_token": tokens.refresh_token,
                "token_type": "bearer",
                "expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
                "session_id": session_id,
                "user": _build_user_payload(user),
            }
        )

    except AppException:
        raise
    except Exception as e:
        logger.error(f"Registration error: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message="Internal server error",
            status_code=500
        )

@router.post("/refresh", response_model=ResponseModel)
async def refresh_token(request: RefreshTokenRequest, http_request: Request):
    """Refresh the access token.

    Exchanges a valid refresh token for a new token pair.
    """
    try:
        # Verify the refresh token (also checks the blacklist)
        payload = await verify_token(request.refresh_token, token_type="refresh")
        if not payload:
            raise AppException(
                code=ErrorCode.UNAUTHORIZED,
                message="Invalid or revoked refresh token",
                status_code=401,
                details={"headers": {"WWW-Authenticate": "Bearer"}}
            )

        if not payload.sid:
            raise AppException(
                code=ErrorCode.UNAUTHORIZED,
                message="Refresh token session is missing",
                status_code=401,
                details={"headers": {"WWW-Authenticate": "Bearer"}}
            )

        user = await user_service.get_user_by_id(payload.sub)
        if not user or not user.is_active:
            if user:
                session_service.revoke_user_sessions(user.id, reason="inactive_user")
            raise AppException(
                code=ErrorCode.UNAUTHORIZED,
                message="User session is no longer active",
                status_code=401,
                details={"headers": {"WWW-Authenticate": "Bearer"}}
            )

        new_session_id = str(uuid.uuid4())
        tokens = create_token_pair(
            user_id=payload.sub,
            scopes=payload.scopes or ["user"],
            session_id=new_session_id,
            session_family_id=payload.sfid,
        )

        rotated_session = session_service.rotate_refresh_token(
            payload.sid,
            request.refresh_token,
            tokens.refresh_token,
            new_session_id=new_session_id,
            ip_address=http_request.client.host if http_request.client else None,
            user_agent=http_request.headers.get("user-agent"),
        )
        if not rotated_session:
            raise AppException(
                code=ErrorCode.UNAUTHORIZED,
                message="Invalid or replayed refresh token",
                status_code=401,
                details={"headers": {"WWW-Authenticate": "Bearer"}}
            )

        user_data = None
        if user:
            user_data = _build_user_payload(user)

        return ResponseModel(
            data={
                "access_token": tokens.access_token,
                "refresh_token": tokens.refresh_token,
                "token_type": "bearer",
                "expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
                "session_id": rotated_session.id,
                **({"user": user_data} if user_data else {}),
            }
        )

    except AppException:
        raise
    except Exception as e:
        logger.error(f"Token refresh error: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message="Internal server error",
            status_code=500
        )

@router.get("/me", response_model=ResponseModel)
async def get_me(current_user: UserAuth = Depends(get_current_user)):
    """Return the currently authenticated user.

    Requires a valid access token.
    """
    return ResponseModel(
        data={
            "id": current_user.id,
            "username": current_user.username,
            "email": current_user.email,
            "avatar_url": current_user.avatar_url,
            "is_active": current_user.is_active,
            "is_superuser": current_user.is_superuser,
            "permissions": current_user.permissions,
            "roles": current_user.roles,
        }
    )

@router.get("/session", response_model=ResponseModel)
async def get_current_session(
    http_request: Request,
    current_user: UserAuth = Depends(get_current_user),
):
    """Return the current session alongside the authenticated user."""
    payload = await _get_access_payload(http_request)
    session_data = None

    if payload and payload.sid:
        session_db = session_service.get_session(payload.sid)
        if session_db and session_db.user_id == current_user.id:
            session_data = _serialize_session(session_db, current_session_id=payload.sid)

    return ResponseModel(
        data={
            "authenticated": True,
            "session_id": payload.sid if payload else None,
            "user": {
                "id": current_user.id,
                "username": current_user.username,
                "email": current_user.email,
                "avatar_url": current_user.avatar_url,
                "is_active": current_user.is_active,
                "is_superuser": current_user.is_superuser,
                "permissions": current_user.permissions,
                "roles": current_user.roles,
            },
            "session": session_data,
        }
    )

@router.get("/sessions", response_model=ResponseModel)
async def list_sessions(
    http_request: Request,
    include_inactive: bool = False,
    current_user: UserAuth = Depends(get_current_user),
):
    """List the authenticated user's sessions."""
    payload = await _get_access_payload(http_request)
    current_session_id = payload.sid if payload else None
    sessions = session_service.list_user_sessions(
        current_user.id,
        include_inactive=include_inactive,
    )

    return ResponseModel(
        data={
            "items": [
                _serialize_session(session_db, current_session_id=current_session_id)
                for session_db in sessions
            ],
            "total": len(sessions),
        }
    )

@router.delete("/sessions/{session_id}", response_model=ResponseModel)
async def revoke_session(
    session_id: str,
    http_request: Request,
    current_user: UserAuth = Depends(get_current_user),
):
    """Revoke one of the authenticated user's sessions."""
    payload = await _get_access_payload(http_request)
    revoked = session_service.revoke_user_session(
        current_user.id,
        session_id,
        reason="user_revoke",
    )

    if not revoked:
        raise AppException(
            code=ErrorCode.NOT_FOUND,
            message="Session not found",
            status_code=404,
        )

    revoked_access = False
    if payload and payload.sid == session_id:
        access_token = _extract_bearer_token(http_request)
        if access_token:
            revoked_access = await token_blacklist_service.revoke_token(
                access_token,
                reason="user_revoke",
            )

    return ResponseModel(
        message="Session revoked successfully",
        data={
            "session_id": session_id,
            "revoked": True,
            "current": payload.sid == session_id if payload else False,
            "revoked_access": revoked_access,
        }
    )

@router.post("/logout", response_model=ResponseModel)
async def logout(
    request: RefreshTokenRequest,
    http_request: Request,
    current_user: UserAuth = Depends(get_current_user)
):
    """Log out the current user.

    Blacklists the current access token and refresh token so they can no
    longer be used. The client must also clear its locally stored tokens.
    """
    try:
        access_token = _extract_bearer_token(http_request)
        revoked_access = False
        if access_token:
            revoked_access = await token_blacklist_service.revoke_token(
                access_token,
                reason="logout",
            )

        # Revoke the refresh token, if one was provided
        revoked_refresh = False
        if request.refresh_token:
            refresh_payload = await verify_token(request.refresh_token, token_type="refresh")
            if refresh_payload and refresh_payload.sid:
                session_service.revoke_session(refresh_payload.sid, reason="logout")
            await token_blacklist_service.revoke_token(
                request.refresh_token,
                reason="logout"
            )
            revoked_refresh = True

        logger.info(f"User logged out: {current_user.username} ({current_user.id})")

        return ResponseModel(
            message="Logged out successfully",
            data={
                "user_id": current_user.id,
                "revoked": revoked_access or revoked_refresh,
                "revoked_access": revoked_access,
                "revoked_refresh": revoked_refresh,
            }
        )
    except Exception as e:
        logger.error(f"Logout error: {e}")
        # Return success even if revocation failed; the client should clear
        # its local tokens regardless.
        return ResponseModel(message="Logged out successfully")

class LogoutAllRequest(BaseModel):
    """Request body for logging out of all devices."""
    pass

@router.post("/logout-all", response_model=ResponseModel)
async def logout_all_devices(
    request: LogoutAllRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Log out from all devices.

    Revokes every token issued to the user, signing them out everywhere.
    """
    try:
        # Revoke all of the user's tokens
        await token_blacklist_service.revoke_all_user_tokens(
            current_user.id,
            reason="logout_all"
        )
        revoked_sessions = session_service.revoke_user_sessions(
            current_user.id,
            reason="logout_all",
        )

        logger.info(f"User logged out from all devices: {current_user.username} ({current_user.id})")

        return ResponseModel(
            message="Logged out from all devices successfully",
            data={
                "user_id": current_user.id,
                "all_devices": True,
                "revoked_sessions": revoked_sessions,
            }
        )
    except Exception as e:
        logger.error(f"Logout all error: {e}")
        raise AppException(
            code=ErrorCode.INTERNAL_ERROR,
            message="Failed to logout from all devices",
            status_code=500
        )

@router.post("/avatar", response_model=ResponseModel)
async def upload_avatar(
    file: UploadFile = File(...),
    current_user: UserAuth = Depends(get_current_user)
):
    """Upload a user avatar.

    Accepts JPG and PNG up to 5MB. The image is center-cropped to a square
    and resized to 256x256 pixels.
    """
    try:
        # Validate the content type
        allowed_types = {'image/jpeg', 'image/png', 'image/jpg'}
        if file.content_type not in allowed_types:
            raise AppException(
                message="Only JPG and PNG images are allowed",
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )

        # Read the file and enforce the size limit
        contents = await file.read()
        max_size = 5 * 1024 * 1024  # 5MB
        if len(contents) > max_size:
            raise AppException(
                message="File size exceeds 5MB limit",
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )

        # Process the image with PIL
        image = Image.open(io.BytesIO(contents))

        # Flatten transparency (RGBA/LA/P) onto a white RGB background
        if image.mode in ('RGBA', 'LA', 'P'):
            if image.mode == 'P':
                image = image.convert('RGBA')
            background = Image.new('RGB', image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[-1])
            image = background

        # Center-crop to a square
        width, height = image.size
        min_dim = min(width, height)
        left = (width - min_dim) // 2
        top = (height - min_dim) // 2
        right = left + min_dim
        bottom = top + min_dim
        image = image.crop((left, top, right, bottom))

        # Resize to 256x256
        image = image.resize((256, 256), Image.Resampling.LANCZOS)

        # Save as PNG
        output = io.BytesIO()
        image.save(output, format='PNG', optimize=True)
        output.seek(0)

        # Storage path: avatars/{user_id}/{timestamp}.png
        timestamp = int(datetime.now().timestamp())
        storage_path = f"avatars/{current_user.id}/{timestamp}.png"

        # Upload to storage
        avatar_url = storage_manager.save(storage_path, output.getvalue())

        # Update the user's avatar URL
        updated_user = await user_service.update_user(
            current_user.id,
            avatar_url=avatar_url
        )

        if not updated_user:
            raise AppException(
                message="Failed to update user avatar",
                code=ErrorCode.INTERNAL_ERROR,
                status_code=500
            )

        logger.info(f"Avatar uploaded for user {current_user.username}: {avatar_url}")

        return ResponseModel(
            data={
                "avatar_url": avatar_url,
                "user": {
                    "id": updated_user.id,
                    "username": updated_user.username,
                    "email": updated_user.email,
                    "avatar_url": updated_user.avatar_url,
                    "is_active": updated_user.is_active,
                    "is_superuser": updated_user.is_superuser,
                }
            }
        )

    except AppException:
        raise
    except Exception as e:
        logger.error(f"Avatar upload error: {e}")
        raise AppException(
            message=f"Failed to upload avatar: {str(e)}",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

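The center-crop in `upload_avatar` is plain box arithmetic: take the largest centered square, then resize. The same computation in isolation, with no PIL dependency (the function name is illustrative):

```python
def center_square_box(width: int, height: int) -> tuple[int, int, int, int]:
    """Return (left, top, right, bottom) for the largest centered square.

    This is the crop box PIL's Image.crop expects: left/top inclusive,
    right/bottom exclusive.
    """
    min_dim = min(width, height)
    left = (width - min_dim) // 2
    top = (height - min_dim) // 2
    return (left, top, left + min_dim, top + min_dim)


box = center_square_box(1024, 768)  # landscape: trims 128px off each side
```

Odd differences are handled by floor division, so the crop can be off-center by at most one pixel.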
# ===== Password Reset Models =====

class ForgotPasswordRequest(BaseModel):
    """Forgot-password request"""
    email: EmailStr = Field(..., description="User email")


class ResetPasswordRequest(BaseModel):
    """Reset-password request"""
    token: str = Field(..., description="Reset token")
    new_password: str = Field(..., min_length=6, description="New password")


class VerifyResetTokenRequest(BaseModel):
    """Verify-reset-token request"""
    token: str = Field(..., description="Reset token")

# Password reset token storage (Redis if available, in-memory otherwise)
_reset_tokens = {}  # In-memory fallback {token: {"user_id": str, "expires": float}}


async def _store_reset_token(token: str, user_id: str, expires_in: int = 3600) -> None:
    """Store a reset token in Redis, falling back to memory."""
    if REDIS_ENABLED:
        try:
            import redis.asyncio as aioredis
            from src.config.settings import REDIS_URL
            redis_client = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
            await redis_client.setex(f"pwd_reset:{token}", expires_in, user_id)
            await redis_client.close()
            return
        except Exception as e:
            logger.warning(f"Redis not available for reset token, using memory: {e}")

    # Fall back to memory
    _reset_tokens[token] = {
        "user_id": user_id,
        "expires": datetime.now().timestamp() + expires_in
    }


async def _get_reset_token_user_id(token: str) -> Optional[str]:
    """Resolve a reset token to a user_id, or None if invalid or expired."""
    if REDIS_ENABLED:
        try:
            import redis.asyncio as aioredis
            from src.config.settings import REDIS_URL
            redis_client = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
            user_id = await redis_client.get(f"pwd_reset:{token}")
            await redis_client.close()
            return user_id
        except Exception:
            pass

    # Fall back to memory
    token_data = _reset_tokens.get(token)
    if token_data:
        if token_data["expires"] > datetime.now().timestamp():
            return token_data["user_id"]
        else:
            # Expired; remove it
            del _reset_tokens[token]
    return None


async def _delete_reset_token(token: str) -> None:
    """Delete a reset token from Redis or the in-memory fallback."""
    if REDIS_ENABLED:
        try:
            import redis.asyncio as aioredis
            from src.config.settings import REDIS_URL
            redis_client = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
            await redis_client.delete(f"pwd_reset:{token}")
            await redis_client.close()
            return
        except Exception:
            pass

    # Fall back to memory
    if token in _reset_tokens:
        del _reset_tokens[token]


def _generate_reset_token() -> str:
    """Generate a cryptographically secure reset token."""
    import secrets
    return secrets.token_urlsafe(32)

@router.post("/forgot-password", response_model=ResponseModel)
async def forgot_password(request: ForgotPasswordRequest):
    """Request a password reset.

    Sends a password reset email. The token is valid for one hour.
    """
    try:
        # Look up the user
        user = await user_service.get_user_by_email(request.email)
        if not user:
            # Do not reveal whether the email exists; return success either way
            logger.info(f"Password reset requested for non-existent email: {request.email}")
            return ResponseModel(
                message="If the email exists, a reset link has been sent"
            )

        # Generate a reset token
        token = _generate_reset_token()

        # Store the token (valid for one hour)
        await _store_reset_token(token, user.id, expires_in=3600)

        # Send the reset email
        email_result = await email_service.send_password_reset(
            to_email=user.email,
            username=user.username,
            reset_token=token,
            expires_in=1
        )

        if email_result["success"]:
            logger.info(f"Password reset email sent to {user.email}")
            return ResponseModel(
                message="Password reset email sent"
            )
        else:
            # Email delivery failed; in development the token can be returned
            logger.warning(f"Failed to send reset email: {email_result.get('error')}")

            # Never return the token in production
            if NODE_ENV == "production":
                return ResponseModel(
                    message="Failed to send reset email. Please try again later."
                )
            else:
                # Development only: return the token to allow testing
                return ResponseModel(
                    data={
                        "message": "Password reset email sent",
                        "reset_token": token,  # dev only; never returned in production
                        "expires_in": 3600,
                        "email_error": email_result.get("error")
                    }
                )

    except Exception as e:
        logger.error(f"Forgot password error: {e}")
        raise AppException(
            message="Failed to process password reset request",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

@router.post("/verify-reset-token", response_model=ResponseModel)
async def verify_reset_token(request: VerifyResetTokenRequest):
    """Verify a password reset token.

    Checks that the token is still valid; used by the frontend reset page.
    """
    try:
        user_id = await _get_reset_token_user_id(request.token)

        if not user_id:
            raise AppException(
                message="Invalid or expired reset token",
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )

        # Fetch the user
        user = await user_service.get_user_by_id(user_id)
        if not user:
            raise AppException(
                message="User not found",
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )

        return ResponseModel(
            data={
                "valid": True,
                "user_id": user_id,
                "email": user.email
            }
        )

    except AppException:
        raise
    except Exception as e:
        logger.error(f"Verify reset token error: {e}")
        raise AppException(
            message="Failed to verify reset token",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

@router.post("/reset-password", response_model=ResponseModel)
async def reset_password(request: ResetPasswordRequest):
    """Reset the password.

    Sets a new password using a valid reset token, then revokes all of the
    user's tokens.
    """
    try:
        # Validate the token
        user_id = await _get_reset_token_user_id(request.token)

        if not user_id:
            raise AppException(
                message="Invalid or expired reset token",
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )

        # Update the password
        updated_user = await user_service.update_user(
            user_id,
            password=request.new_password
        )

        if not updated_user:
            raise AppException(
                message="User not found",
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )

        # Delete the used token
        await _delete_reset_token(request.token)

        # Revoke all of the user's tokens (force re-login)
        await token_blacklist_service.revoke_all_user_tokens(
            user_id,
            reason="password_reset"
        )

        # Invalidate all cached data for the user
        await user_service.invalidate_user_cache(user_id)

        logger.info(f"Password reset successful for user {updated_user.username}")

        return ResponseModel(
            message="Password reset successful",
            data={
                "user_id": updated_user.id,
                "email": updated_user.email
            }
        )

    except AppException:
        raise
    except Exception as e:
        logger.error(f"Reset password error: {e}")
        raise AppException(
            message="Failed to reset password",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
81
backend/src/api/canvas.py
Normal file
@@ -0,0 +1,81 @@
import logging
from fastapi import APIRouter, Query
from sqlmodel import Session
from src.config.database import engine
from src.models.entities import CanvasDB
from src.models.schemas import ResponseModel, CanvasState
from src.mappers import CanvasMapper
from src.utils.errors import BusinessException, ErrorCode

router = APIRouter(tags=["canvas"])
logger = logging.getLogger(__name__)


@router.get("/canvas", response_model=ResponseModel)
async def get_canvas_state(id: str = "default", projectId: str = Query(None)):
    """Get canvas state by ID.

    Returns an empty default state if the canvas is not found.

    Raises:
        BusinessException: Failed to fetch canvas
    """
    try:
        with Session(engine) as session:
            canvas_db = session.get(CanvasDB, id)

            if not canvas_db:
                # Not found: return an empty default state
                return ResponseModel(data=CanvasState(id=id, projectId=projectId))

            # Map to CanvasState
            canvas_state = CanvasState(
                id=canvas_db.id,
                projectId=canvas_db.project_id,
                nodes=canvas_db.nodes,
                connections=canvas_db.connections,
                groups=canvas_db.groups,
                history=canvas_db.history,
                historyIndex=canvas_db.history_index,
                updated_at=canvas_db.updated_at
            )
            return ResponseModel(data=canvas_state)
    except Exception as e:
        logger.error(f"Failed to fetch canvas {id}: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.CANVAS_NOT_FOUND,
            "Failed to fetch canvas state",
            {"canvas_id": id, "reason": str(e)}
        )


@router.post("/canvas", response_model=ResponseModel)
async def save_canvas_state(state: CanvasState):
    """Save canvas state.

    Raises:
        BusinessException: Failed to save canvas
    """
    try:
        with Session(engine) as session:
            # Use merge to handle both insert and update gracefully
            canvas_db = CanvasDB(
                id=state.id,
                project_id=state.projectId,
                nodes=state.nodes,
                connections=state.connections,
                groups=state.groups,
                history=state.history,
                history_index=state.historyIndex,
                updated_at=state.updated_at
            )

            session.merge(canvas_db)
            session.commit()

            return ResponseModel(data={"id": state.id, "updated": True})
    except Exception as e:
        logger.error(f"Failed to save canvas {state.id}: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.UNKNOWN_ERROR,
            "Failed to save canvas state",
            {"canvas_id": state.id, "reason": str(e)}
        )
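`session.merge` gives the save endpoint insert-or-update semantics keyed on the primary key. The equivalent behavior can be sketched at the SQL level with SQLite's upsert clause (table and column names here are illustrative, not the project's schema):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE canvas (id TEXT PRIMARY KEY, nodes TEXT)")


def save_canvas(canvas_id: str, nodes: str) -> None:
    # Insert a new row, or overwrite the existing row with the same id.
    conn.execute(
        "INSERT INTO canvas (id, nodes) VALUES (?, ?) "
        "ON CONFLICT(id) DO UPDATE SET nodes = excluded.nodes",
        (canvas_id, nodes),
    )
    conn.commit()


save_canvas("default", "[]")
save_canvas("default", '[{"id": "n1"}]')  # second save updates in place
```

Either way, repeated saves of the same canvas id are idempotent at the row level: one row per canvas, last write wins.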
428
backend/src/api/canvas_metadata.py
Normal file
@@ -0,0 +1,428 @@
import logging
from typing import Optional, List
from fastapi import APIRouter, Query, Depends
from sqlmodel import Session
from src.config.database import engine
from src.services.canvas_metadata_service import CanvasMetadataService
from src.models.schemas import (
    ResponseModel,
    CanvasMetadata,
    CreateGeneralCanvasRequest,
    UpdateCanvasMetadataRequest
)
from src.utils.errors import ResourceNotFoundException, BusinessException, ErrorCode

router = APIRouter(tags=["canvas_metadata"])
logger = logging.getLogger(__name__)


def get_session():
    """Dependency that yields a database session."""
    with Session(engine) as session:
        yield session


@router.get("/projects/{project_id}/canvases", response_model=ResponseModel)
async def list_project_canvases(
    project_id: str,
    canvas_type: Optional[str] = Query(None, description="Filter by type: general, asset, storyboard"),
    include_deleted: bool = Query(False, description="Include deleted canvases"),
    session: Session = Depends(get_session)
):
    """List all canvases for a project.

    Args:
        project_id: Project ID
        canvas_type: Optional filter by canvas type (general, asset, storyboard)
        include_deleted: Whether to include soft-deleted canvases

    Returns:
        List of canvas metadata

    Raises:
        BusinessException: Failed to list canvases
    """
    try:
        service = CanvasMetadataService(session)
        canvases = service.list_canvases(project_id, canvas_type, include_deleted)

        # Convert to dict format
        canvas_data = [
            CanvasMetadata(
                id=c.id,
                projectId=c.project_id,
                canvasType=c.canvas_type,
                relatedEntityType=c.related_entity_type,
                relatedEntityId=c.related_entity_id,
                name=c.name,
                description=c.description,
                orderIndex=c.order_index,
                isPinned=c.is_pinned,
                tags=c.tags,
                nodeCount=c.node_count,
                lastAccessedAt=c.last_accessed_at,
                accessCount=c.access_count,
                createdAt=c.created_at,
                updatedAt=c.updated_at,
                deletedAt=c.deleted_at
            ).model_dump(by_alias=True)
            for c in canvases
        ]

        return ResponseModel(data=canvas_data)
    except Exception as e:
        logger.error(f"Failed to list canvases for project {project_id}: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.UNKNOWN_ERROR,
            "Failed to list canvases",
            {"project_id": project_id, "reason": str(e)}
        )


@router.get("/canvases/{canvas_id}/metadata", response_model=ResponseModel)
async def get_canvas_metadata(
    canvas_id: str,
    session: Session = Depends(get_session)
):
    """Get canvas metadata by ID.

    Args:
        canvas_id: Canvas ID

    Returns:
        Canvas metadata

    Raises:
        ResourceNotFoundException: Canvas not found
        BusinessException: Failed to get canvas metadata
    """
    try:
        service = CanvasMetadataService(session)
        canvas = service.get_canvas(canvas_id)

        if not canvas:
            raise ResourceNotFoundException("canvas", canvas_id)

        # Update access statistics
        service.update_access_stats(canvas.id)

        # Convert to response format
        canvas_data = CanvasMetadata(
            id=canvas.id,
            projectId=canvas.project_id,
            canvasType=canvas.canvas_type,
            relatedEntityType=canvas.related_entity_type,
            relatedEntityId=canvas.related_entity_id,
            name=canvas.name,
            description=canvas.description,
            orderIndex=canvas.order_index,
            isPinned=canvas.is_pinned,
            tags=canvas.tags,
            nodeCount=canvas.node_count,
            lastAccessedAt=canvas.last_accessed_at,
            accessCount=canvas.access_count,
            createdAt=canvas.created_at,
            updatedAt=canvas.updated_at,
            deletedAt=canvas.deleted_at
        ).model_dump(by_alias=True)

        return ResponseModel(data=canvas_data)
    except ResourceNotFoundException:
        raise
    except Exception as e:
        logger.error(f"Failed to get canvas metadata {canvas_id}: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.CANVAS_NOT_FOUND,
            "Failed to get canvas metadata",
            {"canvas_id": canvas_id, "reason": str(e)}
        )


@router.get("/assets/{asset_id}/canvas", response_model=ResponseModel)
async def get_asset_canvas(
    asset_id: str,
    session: Session = Depends(get_session)
):
    """Get the canvas associated with an asset.

    Automatically creates the canvas if it doesn't exist.

    Args:
        asset_id: Asset ID

    Returns:
        Canvas metadata

    Raises:
        ResourceNotFoundException: Asset not found
        BusinessException: Failed to get asset canvas
    """
    try:
        service = CanvasMetadataService(session)
        canvas = service.get_or_create_asset_canvas(asset_id)

        # Convert to response format
        canvas_data = CanvasMetadata(
            id=canvas.id,
            projectId=canvas.project_id,
            canvasType=canvas.canvas_type,
            relatedEntityType=canvas.related_entity_type,
            relatedEntityId=canvas.related_entity_id,
            name=canvas.name,
            description=canvas.description,
            orderIndex=canvas.order_index,
            isPinned=canvas.is_pinned,
            tags=canvas.tags,
            nodeCount=canvas.node_count,
            lastAccessedAt=canvas.last_accessed_at,
            accessCount=canvas.access_count,
            createdAt=canvas.created_at,
            updatedAt=canvas.updated_at,
            deletedAt=canvas.deleted_at
        ).model_dump(by_alias=True)

        return ResponseModel(data=canvas_data)
    except ValueError as e:
|
||||||
|
raise ResourceNotFoundException("asset", asset_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get asset canvas {asset_id}: {e}", exc_info=True)
|
||||||
|
raise BusinessException(
|
||||||
|
ErrorCode.UNKNOWN_ERROR,
|
||||||
|
"Failed to get asset canvas",
|
||||||
|
{"asset_id": asset_id, "reason": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/storyboards/{storyboard_id}/canvas", response_model=ResponseModel)
|
||||||
|
async def get_storyboard_canvas(
|
||||||
|
storyboard_id: str,
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
""" Get canvas associated with a storyboard.
|
||||||
|
Automatically creates the canvas if it doesn't exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storyboard_id: Storyboard ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Canvas metadata
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ResourceNotFoundException: Storyboard not found
|
||||||
|
BusinessException: Failed to get storyboard canvas
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service = CanvasMetadataService(session)
|
||||||
|
canvas = service.get_or_create_storyboard_canvas(storyboard_id)
|
||||||
|
|
||||||
|
# 转换 to response format
|
||||||
|
canvas_data = CanvasMetadata(
|
||||||
|
id=canvas.id,
|
||||||
|
projectId=canvas.project_id,
|
||||||
|
canvasType=canvas.canvas_type,
|
||||||
|
relatedEntityType=canvas.related_entity_type,
|
||||||
|
relatedEntityId=canvas.related_entity_id,
|
||||||
|
name=canvas.name,
|
||||||
|
description=canvas.description,
|
||||||
|
orderIndex=canvas.order_index,
|
||||||
|
isPinned=canvas.is_pinned,
|
||||||
|
tags=canvas.tags,
|
||||||
|
nodeCount=canvas.node_count,
|
||||||
|
lastAccessedAt=canvas.last_accessed_at,
|
||||||
|
accessCount=canvas.access_count,
|
||||||
|
createdAt=canvas.created_at,
|
||||||
|
updatedAt=canvas.updated_at,
|
||||||
|
deletedAt=canvas.deleted_at
|
||||||
|
).model_dump(by_alias=True)
|
||||||
|
|
||||||
|
return ResponseModel(data=canvas_data)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ResourceNotFoundException("storyboard", storyboard_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get storyboard canvas {storyboard_id}: {e}", exc_info=True)
|
||||||
|
raise BusinessException(
|
||||||
|
ErrorCode.UNKNOWN_ERROR,
|
||||||
|
"Failed to get storyboard canvas",
|
||||||
|
{"storyboard_id": storyboard_id, "reason": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/projects/{project_id}/canvases", response_model=ResponseModel)
|
||||||
|
async def create_general_canvas(
|
||||||
|
project_id: str,
|
||||||
|
request: CreateGeneralCanvasRequest,
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
""" Create a new general canvas for a project.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project ID
|
||||||
|
request: Canvas creation request with name and optional description
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created canvas metadata
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: Failed to create canvas
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service = CanvasMetadataService(session)
|
||||||
|
canvas = service.create_general_canvas(
|
||||||
|
project_id=project_id,
|
||||||
|
name=request.name,
|
||||||
|
description=request.description
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换 to response format
|
||||||
|
canvas_data = CanvasMetadata(
|
||||||
|
id=canvas.id,
|
||||||
|
projectId=canvas.project_id,
|
||||||
|
canvasType=canvas.canvas_type,
|
||||||
|
relatedEntityType=canvas.related_entity_type,
|
||||||
|
relatedEntityId=canvas.related_entity_id,
|
||||||
|
name=canvas.name,
|
||||||
|
description=canvas.description,
|
||||||
|
orderIndex=canvas.order_index,
|
||||||
|
isPinned=canvas.is_pinned,
|
||||||
|
tags=canvas.tags,
|
||||||
|
nodeCount=canvas.node_count,
|
||||||
|
lastAccessedAt=canvas.last_accessed_at,
|
||||||
|
accessCount=canvas.access_count,
|
||||||
|
createdAt=canvas.created_at,
|
||||||
|
updatedAt=canvas.updated_at,
|
||||||
|
deletedAt=canvas.deleted_at
|
||||||
|
).model_dump(by_alias=True)
|
||||||
|
|
||||||
|
return ResponseModel(data=canvas_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create general canvas for project {project_id}: {e}", exc_info=True)
|
||||||
|
raise BusinessException(
|
||||||
|
ErrorCode.UNKNOWN_ERROR,
|
||||||
|
"Failed to create canvas",
|
||||||
|
{"project_id": project_id, "reason": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/canvases/{canvas_id}/metadata", response_model=ResponseModel)
|
||||||
|
async def update_canvas_metadata(
|
||||||
|
canvas_id: str,
|
||||||
|
request: UpdateCanvasMetadataRequest,
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
""" 更新 canvas metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
canvas_id: Canvas ID
|
||||||
|
request: Update request with optional fields
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated canvas metadata
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ResourceNotFoundException: Canvas not found
|
||||||
|
BusinessException: Failed to update canvas
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service = CanvasMetadataService(session)
|
||||||
|
|
||||||
|
# 转换 request to dict, excluding unset fields
|
||||||
|
updates = request.model_dump(exclude_unset=True, by_alias=False)
|
||||||
|
|
||||||
|
canvas = service.update_canvas(canvas_id, updates)
|
||||||
|
|
||||||
|
if not canvas:
|
||||||
|
raise ResourceNotFoundException("canvas", canvas_id)
|
||||||
|
|
||||||
|
# 转换 to response format
|
||||||
|
canvas_data = CanvasMetadata(
|
||||||
|
id=canvas.id,
|
||||||
|
projectId=canvas.project_id,
|
||||||
|
canvasType=canvas.canvas_type,
|
||||||
|
relatedEntityType=canvas.related_entity_type,
|
||||||
|
relatedEntityId=canvas.related_entity_id,
|
||||||
|
name=canvas.name,
|
||||||
|
description=canvas.description,
|
||||||
|
orderIndex=canvas.order_index,
|
||||||
|
isPinned=canvas.is_pinned,
|
||||||
|
tags=canvas.tags,
|
||||||
|
nodeCount=canvas.node_count,
|
||||||
|
lastAccessedAt=canvas.last_accessed_at,
|
||||||
|
accessCount=canvas.access_count,
|
||||||
|
createdAt=canvas.created_at,
|
||||||
|
updatedAt=canvas.updated_at,
|
||||||
|
deletedAt=canvas.deleted_at
|
||||||
|
).model_dump(by_alias=True)
|
||||||
|
|
||||||
|
return ResponseModel(data=canvas_data)
|
||||||
|
except ResourceNotFoundException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update canvas metadata {canvas_id}: {e}", exc_info=True)
|
||||||
|
raise BusinessException(
|
||||||
|
ErrorCode.UNKNOWN_ERROR,
|
||||||
|
"Failed to update canvas",
|
||||||
|
{"canvas_id": canvas_id, "reason": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/projects/{project_id}/canvases/reorder", response_model=ResponseModel)
|
||||||
|
async def reorder_canvases(
|
||||||
|
project_id: str,
|
||||||
|
canvas_orders: List[dict],
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
""" 批处理 update canvas order indices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_id: Project ID
|
||||||
|
canvas_orders: List of {id, order_index} objects
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success confirmation
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: Failed to reorder canvases
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service = CanvasMetadataService(session)
|
||||||
|
service.reorder_canvases(canvas_orders)
|
||||||
|
|
||||||
|
return ResponseModel(data={"updated": True})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to reorder canvases for project {project_id}: {e}", exc_info=True)
|
||||||
|
raise BusinessException(
|
||||||
|
ErrorCode.UNKNOWN_ERROR,
|
||||||
|
"Failed to reorder canvases",
|
||||||
|
{"project_id": project_id, "reason": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/canvases/{canvas_id}", response_model=ResponseModel)
|
||||||
|
async def delete_canvas(
|
||||||
|
canvas_id: str,
|
||||||
|
hard_delete: bool = Query(False, description="Permanently delete (true) or soft delete (false)"),
|
||||||
|
session: Session = Depends(get_session)
|
||||||
|
):
|
||||||
|
""" 删除 a canvas.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
canvas_id: Canvas ID
|
||||||
|
hard_delete: If true, permanently delete; if false, soft delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success confirmation
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: Failed to delete canvas
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
service = CanvasMetadataService(session)
|
||||||
|
service.delete_canvas(canvas_id, hard_delete)
|
||||||
|
|
||||||
|
return ResponseModel(data={"deleted": True})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to delete canvas {canvas_id}: {e}", exc_info=True)
|
||||||
|
raise BusinessException(
|
||||||
|
ErrorCode.UNKNOWN_ERROR,
|
||||||
|
"Failed to delete canvas",
|
||||||
|
{"canvas_id": canvas_id, "hard_delete": hard_delete, "reason": str(e)}
|
||||||
|
)
|
||||||
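The reorder endpoint above accepts a bare `List[dict]` of `{id, order_index}` objects with no schema validation. As a minimal sketch of what `CanvasMetadataService.reorder_canvases` presumably expects (the helper below is hypothetical, not part of this commit):

```python
def validate_canvas_orders(canvas_orders):
    """Check each entry has an id and an integer order_index; return entries sorted by index.

    Hypothetical helper illustrating the {id, order_index} payload shape;
    the actual service-side handling is not shown in this commit.
    """
    cleaned = []
    for entry in canvas_orders:
        if "id" not in entry or not isinstance(entry.get("order_index"), int):
            raise ValueError(f"Invalid canvas order entry: {entry!r}")
        cleaned.append({"id": entry["id"], "order_index": entry["order_index"]})
    return sorted(cleaned, key=lambda e: e["order_index"])


orders = [{"id": "c2", "order_index": 1}, {"id": "c1", "order_index": 0}]
print(validate_canvas_orders(orders))
```

Declaring a Pydantic request model instead of `List[dict]` would let FastAPI enforce this shape automatically.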
61
backend/src/api/chat.py
Normal file
@@ -0,0 +1,61 @@
import logging
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field, ConfigDict
from typing import List, Optional

from src.services.agent_engine import AgentScopeService
from src.utils.errors import BusinessException, ErrorCode
from src.auth.dependencies import get_current_user, UserAuth

router = APIRouter(prefix="/chat", tags=["chat"])
logger = logging.getLogger(__name__)


class Message(BaseModel):
    role: str
    content: str


class ChatCompletionRequest(BaseModel):
    messages: List[Message]
    model: Optional[str] = None
    stream: bool = True
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = Field(None, alias="maxTokens")

    model_config = ConfigDict(populate_by_name=True)


@router.post("/completions")
async def chat_completions(
    request: ChatCompletionRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """
    OpenAI-compatible chat completion endpoint, backed by AgentScope.
    Only streaming mode is supported.

    Raises:
        BusinessException: Chat completion failed
    """
    if not request.stream:
        raise BusinessException(
            ErrorCode.INVALID_PARAMETER,
            "Non-stream mode is not supported",
            {"field": "stream", "expected": True}
        )

    messages_payload = [m.model_dump() for m in request.messages]

    try:
        agent_service = AgentScopeService(user_id=current_user.id)
        return StreamingResponse(
            agent_service.stream_chat(messages_payload),
            media_type="text/event-stream"
        )
    except Exception as e:
        logger.error(f"AgentScope service error: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.GENERATION_FAILED,
            "Chat completion failed",
            {"reason": str(e)}
        )
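The completions endpoint streams its response as `text/event-stream`. A minimal stdlib sketch of how a client might split such a stream into `data:` payloads (the SSE framing here follows the general convention; the exact event format emitted by `stream_chat` is not shown in this commit):

```python
def parse_sse(stream_text: str):
    """Split a text/event-stream body into the payload of each `data:` line.

    Illustrative only: assumes events are separated by blank lines,
    as in standard server-sent events framing.
    """
    events = []
    for block in stream_text.split("\n\n"):
        for line in block.splitlines():
            if line.startswith("data:"):
                events.append(line[len("data:"):].strip())
    return events


chunks = "data: Hello\n\ndata: world\n\ndata: [DONE]\n\n"
print(parse_sse(chunks))  # ['Hello', 'world', '[DONE]']
```

In practice a client would feed this parser incrementally from an `httpx` streaming response rather than from a complete string.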
45
backend/src/api/config/__init__.py
Normal file
@@ -0,0 +1,45 @@
"""Config API routes - split into multiple modules for maintainability.

This package provides configuration management endpoints including:
- System configuration (/config/system)
- Provider and model configuration (/config/providers, /config/models, etc.)
- Style configuration (/config/styles)
- Health monitoring (/config/health)
- Configuration validation (/config/validate)
- Admin operations (/admin/*)
- Storage utilities (/storage/*)
"""

from fastapi import APIRouter

from . import system, providers, styles, health, validation, admin, storage

# Create main router
router = APIRouter()

# Include all sub-routers
router.include_router(system.router)
router.include_router(providers.router)
router.include_router(styles.router)
router.include_router(health.router)
router.include_router(validation.router)
router.include_router(admin.router)
router.include_router(storage.router)

# Export models for backward compatibility
from .models import (
    SignUrlRequest,
    SystemConfig,
    ModelRegistrationRequest,
    ProviderKeyField,
    ProviderConfigResponse,
)

__all__ = [
    "router",
    "SignUrlRequest",
    "SystemConfig",
    "ModelRegistrationRequest",
    "ProviderKeyField",
    "ProviderConfigResponse",
]
321
backend/src/api/config/admin.py
Normal file
@@ -0,0 +1,321 @@
"""Provider and model management endpoints (Admin Only)."""

import json
import logging
import os

from fastapi import APIRouter, Depends

from src.config.settings import REDIS_ENABLED
from src.models.schemas import ResponseModel
from src.auth.dependencies import require_admin
from src.auth.models import UserAuth
from src.services.cache_service import get_cache_service
from src.services.provider.registry import ModelRegistry, ModelType
from src.utils.errors import AppException, ErrorCode
from src.services.provider.health import health_monitor
from src.utils.service_loader import register_service

logger = logging.getLogger(__name__)
router = APIRouter()


@router.post("/admin/providers/{provider_id}/toggle", response_model=ResponseModel)
async def toggle_provider(
    provider_id: str,
    enabled: bool,
    current_user: UserAuth = Depends(require_admin),
):
    """Enable/disable Provider (Admin only)."""
    try:
        provider_config = ModelRegistry.get_provider_config(provider_id)
        if not provider_config:
            raise AppException(
                message=f"Provider not found: {provider_id}",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        provider_config["enabled"] = enabled

        src_dir = os.path.dirname(os.path.dirname(__file__))
        custom_providers_path = os.path.join(src_dir, "config", "services", "custom_providers.json")

        existing_configs = []
        if os.path.exists(custom_providers_path):
            try:
                with open(custom_providers_path, 'r', encoding='utf-8') as f:
                    existing_configs = json.load(f)
            except (json.JSONDecodeError, OSError) as e:
                logger.warning("Failed to load custom providers config %s: %s", custom_providers_path, e)

        updated = False
        for i, config in enumerate(existing_configs):
            if config.get("id") == provider_id:
                existing_configs[i] = {**config, **provider_config}
                updated = True
                break

        if not updated:
            existing_configs.append(provider_config)

        os.makedirs(os.path.dirname(custom_providers_path), exist_ok=True)
        with open(custom_providers_path, 'w', encoding='utf-8') as f:
            json.dump(existing_configs, f, indent=4, ensure_ascii=False)

        if enabled:
            register_service(provider_config)

        return ResponseModel(
            data={
                "provider_id": provider_id,
                "enabled": enabled,
                "message": f"Provider {'enabled' if enabled else 'disabled'} successfully",
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error toggling provider: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.delete("/admin/providers/{provider_id}", response_model=ResponseModel)
async def delete_provider(
    provider_id: str,
    current_user: UserAuth = Depends(require_admin),
):
    """Delete custom Provider (Admin only)."""
    try:
        src_dir = os.path.dirname(os.path.dirname(__file__))
        custom_providers_path = os.path.join(src_dir, "config", "services", "custom_providers.json")

        if not os.path.exists(custom_providers_path):
            raise AppException(
                message="Custom providers configuration not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        with open(custom_providers_path, 'r', encoding='utf-8') as f:
            existing_configs = json.load(f)

        initial_len = len(existing_configs)
        existing_configs = [c for c in existing_configs if c.get("id") != provider_id]

        if len(existing_configs) == initial_len:
            raise AppException(
                message="Provider not found or not a custom provider",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        with open(custom_providers_path, 'w', encoding='utf-8') as f:
            json.dump(existing_configs, f, indent=4, ensure_ascii=False)

        return ResponseModel(
            data={
                "provider_id": provider_id,
                "deleted": True,
                "message": "Provider deleted successfully",
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error deleting provider: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/admin/models/{model_id}/toggle", response_model=ResponseModel)
async def toggle_model(
    model_id: str,
    enabled: bool,
    current_user: UserAuth = Depends(require_admin),
):
    """Enable/disable model (Admin only)."""
    try:
        model_config = ModelRegistry.get_config(model_id)
        if not model_config:
            raise AppException(
                message=f"Model not found: {model_id}",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        model_config["enabled"] = enabled

        src_dir = os.path.dirname(os.path.dirname(__file__))
        custom_models_path = os.path.join(src_dir, "config", "services", "custom_models.json")

        existing_configs = []
        if os.path.exists(custom_models_path):
            try:
                with open(custom_models_path, 'r', encoding='utf-8') as f:
                    existing_configs = json.load(f)
            except (json.JSONDecodeError, OSError) as e:
                logger.warning("Failed to load custom models config %s: %s", custom_models_path, e)

        updated = False
        for i, config in enumerate(existing_configs):
            if config.get("id") == model_id:
                existing_configs[i] = {**config, "enabled": enabled}
                updated = True
                break

        if not updated:
            existing_configs.append({**model_config, "enabled": enabled})

        os.makedirs(os.path.dirname(custom_models_path), exist_ok=True)
        with open(custom_models_path, 'w', encoding='utf-8') as f:
            json.dump(existing_configs, f, indent=4, ensure_ascii=False)

        if enabled:
            register_service(model_config)

        return ResponseModel(
            data={
                "model_id": model_id,
                "enabled": enabled,
                "message": f"Model {'enabled' if enabled else 'disabled'} successfully",
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error toggling model: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/admin/models/{model_id}/default", response_model=ResponseModel)
async def set_default_model(
    model_id: str,
    current_user: UserAuth = Depends(require_admin),
):
    """Set default model (Admin only)."""
    try:
        model_config = ModelRegistry.get_config(model_id)
        if not model_config:
            raise AppException(
                message=f"Model not found: {model_id}",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        model_type_str = model_config.get("type")
        if not model_type_str:
            raise AppException(
                message="Model type not found",
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )

        try:
            model_type = ModelType(model_type_str.lower())
        except ValueError:
            raise AppException(
                message=f"Invalid model type: {model_type_str}",
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )

        ModelRegistry.set_default_by_id(model_type, model_id)

        config_key_map = {
            "image": "defaultImageModel",
            "video": "defaultVideoModel",
            "audio": "defaultAudioModel",
            "lyrics": "defaultLyricsModel",
            "music": "defaultMusicModel",
            "llm": "defaultLLMModel",
        }

        config_key = config_key_map.get(model_type_str.lower())
        if config_key:
            src_dir = os.path.dirname(os.path.dirname(__file__))
            config_path = os.path.join(src_dir, "config", "user_config.json")

            existing_data = {}
            if os.path.exists(config_path):
                with open(config_path, 'r', encoding='utf-8') as f:
                    existing_data = json.load(f)

            existing_data[config_key] = model_id

            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(existing_data, f, indent=2, ensure_ascii=False)

        if REDIS_ENABLED:
            cache = get_cache_service()
            await cache.delete("config:defaults")
            await cache.delete("config:system")

        return ResponseModel(
            data={
                "model_id": model_id,
                "model_type": model_type_str,
                "is_default": True,
                "message": f"Default {model_type_str} model set to {model_id}",
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error setting default model: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/admin/models/{model_id}/test", response_model=ResponseModel)
async def test_model(
    model_id: str,
    current_user: UserAuth = Depends(require_admin),
):
    """Test model connection (Admin only)."""
    try:
        service = ModelRegistry.get(model_id)
        if not service:
            raise AppException(
                message=f"Model service not found: {model_id}",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        result = await health_monitor.check_service_health(model_id, service)

        health_monitor.update_health(model_id, result)

        return ResponseModel(
            data={
                "model_id": model_id,
                "success": result.status.value == "healthy",
                "status": result.status.value,
                "latency_ms": result.latency_ms,
                "message": "Model test completed",
                "error": result.error,
            }
        )
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Error testing model: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
137
backend/src/api/config/health.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Health check and monitoring endpoints."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from src.models.schemas import ResponseModel
|
||||||
|
from src.services.provider.registry import ModelRegistry
|
||||||
|
from src.utils.errors import AppException, ErrorCode
|
||||||
|
from src.services.provider.health import health_monitor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config/health", response_model=ResponseModel)
|
||||||
|
async def get_health_status():
|
||||||
|
"""Get health status of all registered services."""
|
||||||
|
try:
|
||||||
|
summary = health_monitor.get_health_summary()
|
||||||
|
all_health = health_monitor.get_all_health()
|
||||||
|
|
||||||
|
services = {}
|
||||||
|
for service_id, health in all_health.items():
|
||||||
|
services[service_id] = {
|
||||||
|
"status": health.status.value,
|
||||||
|
"last_check": health.last_check.isoformat() if health.last_check else None,
|
||||||
|
"last_success": health.last_success.isoformat() if health.last_success else None,
|
||||||
|
"last_failure": health.last_failure.isoformat() if health.last_failure else None,
|
||||||
|
"consecutive_failures": health.consecutive_failures,
|
||||||
|
"consecutive_successes": health.consecutive_successes,
|
||||||
|
"total_checks": health.total_checks,
|
||||||
|
"total_failures": health.total_failures,
|
||||||
|
"success_rate": health.get_success_rate(),
|
||||||
|
"avg_latency_ms": health.avg_latency_ms,
|
||||||
|
"should_circuit_break": health.should_circuit_break()
|
||||||
|
}
|
||||||
|
|
||||||
|
return ResponseModel(data={
|
||||||
|
"summary": summary,
|
||||||
|
"services": services,
|
||||||
|
"unhealthy": health_monitor.get_unhealthy_services(),
|
||||||
|
"degraded": health_monitor.get_degraded_services()
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get health status: {e}")
|
||||||
|
raise AppException(
|
||||||
|
message=str(e),
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config/health/{service_id}", response_model=ResponseModel)
|
||||||
|
async def get_service_health(service_id: str):
|
||||||
|
"""Get detailed health status for a specific service."""
|
||||||
|
try:
|
||||||
|
health = health_monitor.get_health(service_id)
|
||||||
|
|
||||||
|
if not health:
|
||||||
|
raise AppException(
|
||||||
|
message=f"Service not found: {service_id}",
|
||||||
|
code=ErrorCode.NOT_FOUND,
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
|
||||||
|
history = [
|
            {
                "status": result.status.value,
                "latency_ms": result.latency_ms,
                "timestamp": result.timestamp.isoformat(),
                "error": result.error,
                "metadata": result.metadata
            }
            for result in health.history
        ]

        return ResponseModel(data={
            "service_id": service_id,
            "status": health.status.value,
            "last_check": health.last_check.isoformat() if health.last_check else None,
            "last_success": health.last_success.isoformat() if health.last_success else None,
            "last_failure": health.last_failure.isoformat() if health.last_failure else None,
            "consecutive_failures": health.consecutive_failures,
            "consecutive_successes": health.consecutive_successes,
            "total_checks": health.total_checks,
            "total_failures": health.total_failures,
            "success_rate": health.get_success_rate(),
            "avg_latency_ms": health.avg_latency_ms,
            "should_circuit_break": health.should_circuit_break(),
            "history": history
        })
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Failed to get service health: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/config/health/{service_id}/check", response_model=ResponseModel)
async def check_service_health(service_id: str):
    """Manually trigger a health check for a specific service."""
    try:
        service = ModelRegistry.get(service_id)
        if not service:
            raise AppException(
                message=f"Service not found: {service_id}",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        health_monitor.register_service(service_id)

        result = await health_monitor.check_service_health(service_id, service)

        health_monitor.update_health(service_id, result)

        return ResponseModel(data={
            "service_id": service_id,
            "status": result.status.value,
            "latency_ms": result.latency_ms,
            "timestamp": result.timestamp.isoformat(),
            "error": result.error
        })
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Failed to check service health: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
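The `should_circuit_break()` flag surfaced in the health payload above suggests a consecutive-failure circuit breaker. As a rough sketch only — the threshold, field names, and `record()` helper here are assumptions, not the actual `health_monitor` implementation — the bookkeeping could look like:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ServiceHealth:
    """Minimal stand-in for the health record the monitor tracks per service."""
    history: List[bool] = field(default_factory=list)  # True = successful check
    consecutive_failures: int = 0
    total_checks: int = 0
    total_failures: int = 0

    def record(self, success: bool) -> None:
        # Each check appends to history and updates the running counters.
        self.history.append(success)
        self.total_checks += 1
        if success:
            self.consecutive_failures = 0
        else:
            self.total_failures += 1
            self.consecutive_failures += 1

    def get_success_rate(self) -> float:
        if self.total_checks == 0:
            return 1.0
        return 1 - self.total_failures / self.total_checks

    def should_circuit_break(self, threshold: int = 3) -> bool:
        # Trip the breaker after N consecutive failed checks (threshold assumed).
        return self.consecutive_failures >= threshold


health = ServiceHealth()
for ok in (True, False, False, False):
    health.record(ok)
print(health.should_circuit_break())  # → True
print(health.get_success_rate())      # → 0.25
```

The key property is that a single success resets `consecutive_failures`, so only an unbroken run of failures opens the breaker.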
49
backend/src/api/config/models.py
Normal file
@@ -0,0 +1,49 @@
"""Pydantic models for config API."""

from typing import Optional, List, Dict, Any
from pydantic import BaseModel, Field


class SignUrlRequest(BaseModel):
    """Request to sign a URL."""
    url: str


class SystemConfig(BaseModel):
    """System configuration model."""
    default_image_model: Optional[str] = None
    default_video_model: Optional[str] = None
    default_audio_model: Optional[str] = None
    default_llm_model: Optional[str] = None
    default_style: Optional[str] = None
    default_resolution: Optional[str] = None
    default_ratio: Optional[str] = None


class ModelRegistrationRequest(BaseModel):
    """Request to register a new model."""
    model_id: str
    provider: str
    name: str
    type: str
    enabled: bool = True
    config: Dict[str, Any] = Field(default_factory=dict)


class ProviderKeyField(BaseModel):
    """Provider key field configuration."""
    name: str
    label: str
    placeholder: str
    required: bool
    type: Optional[str] = "text"  # 'password' or 'text'


class ProviderConfigResponse(BaseModel):
    """Provider configuration response."""
    id: str
    name: str
    icon: str
    description: str
    fields: List[ProviderKeyField]
    helpUrl: Optional[str] = None
410
backend/src/api/config/providers.py
Normal file
@@ -0,0 +1,410 @@
"""Provider and model configuration endpoints."""

import json
import logging
import os
from typing import Optional, List, Dict, Any

from fastapi import APIRouter
from pydantic import BaseModel, Field

from src.config.settings import REDIS_ENABLED
from src.models.schemas import ResponseModel
from src.services.cache_service import get_cache_service
from src.services.provider.registry import ModelRegistry, ModelType
from src.utils.errors import BusinessException, ErrorCode, ResourceNotFoundException, InvalidParameterException
from src.utils.service_loader import register_service

logger = logging.getLogger(__name__)
router = APIRouter()


class ModelRegistrationRequest(BaseModel):
    """Request to register a new model."""
    name: str
    model_id: str
    provider: str
    type: str
    family: Optional[str] = None
    capabilities: Optional[Dict[str, bool]] = None
    variants: Optional[Dict[str, str]] = None
    resolutions: Optional[Any] = None
    durations: Optional[Any] = None
    voices: Optional[List[Dict[str, str]]] = None
    args: Optional[List[Any]] = None


class ProviderKeyField(BaseModel):
    """Provider key field configuration."""
    name: str
    label: str
    placeholder: str
    required: bool
    type: Optional[str] = "text"


class ProviderConfigResponse(BaseModel):
    """Provider configuration response."""
    id: str
    name: str
    description: str
    fields: List[ProviderKeyField]
    helpUrl: Optional[str] = None


# Supported providers configuration with key field definitions
def _build_provider_configs() -> List[ProviderConfigResponse]:
    """Build provider configs from ModelRegistry metadata and registered models.

    This dynamically generates provider configurations from provider.json files,
    avoiding hardcoded configurations.
    """
    configs = []

    # Get all registered providers from ModelRegistry
    providers = ModelRegistry.list_providers()
    provider_ids = {p["id"] for p in providers}

    # Build config for each provider
    for provider_id in provider_ids:
        metadata = ModelRegistry.get_provider_metadata(provider_id) or {}

        # Get fields from metadata or use defaults
        fields_data = metadata.get("fields", [])
        fields = [ProviderKeyField(**f) for f in fields_data] if fields_data else [
            ProviderKeyField(name="apiKey", label="API Key", placeholder="sk-...", required=True, type="password")
        ]

        # Build config response
        config = ProviderConfigResponse(
            id=provider_id,
            name=metadata.get("name") or provider_id.capitalize(),
            description=metadata.get("description", ""),
            fields=fields,
            helpUrl=metadata.get("helpUrl") or metadata.get("dashboard_url")
        )
        configs.append(config)

    return configs


@router.get("/config/providers", response_model=ResponseModel)
async def get_providers():
    """Get list of supported providers and their capabilities (cached for 5 minutes)."""
    cache_key = "config:providers"

    if REDIS_ENABLED:
        cache = get_cache_service()
        cached_providers = await cache.get(cache_key)
        if cached_providers is not None:
            return ResponseModel(data=cached_providers)

    providers = ModelRegistry.list_providers()
    result = {"providers": providers}

    if REDIS_ENABLED:
        cache = get_cache_service()
        await cache.set(cache_key, result, ttl=300)

    return ResponseModel(data=result)


@router.post("/config/models", response_model=ResponseModel)
async def register_new_model(request: ModelRegistrationRequest):
    """Register a new model configuration.

    Raises:
        InvalidParameterException: Invalid model configuration
        ResourceNotFoundException: Template not found
        BusinessException: Failed to register model
    """
    try:
        provider = request.provider.lower()
        model_type = request.type.lower()
        family = request.family.lower() if request.family else "default"

        module_name = ""
        class_name = ""

        # Find template service from registry
        try:
            type_enum = None
            try:
                type_enum = ModelType(model_type)
            except ValueError:
                pass

            services = ModelRegistry.find_services(
                provider=provider,
                model_type=type_enum
            )

            if not services and not type_enum:
                all_services = ModelRegistry.find_services(provider=provider)
                services = [s for s in all_services if s.get('type') == model_type]

            if not services:
                raise ResourceNotFoundException("template", f"{provider}/{model_type}")

            template = services[0]
            if family and family != "default":
                for svc in services:
                    svc_id = svc.get('id', '').lower()
                    svc_class = (svc.get('class') or svc.get('class_name', '')).lower()
                    if family in svc_id or family in svc_class:
                        template = svc
                        break

            module_name = template.get('module')
            class_name = template.get('class') or template.get('class_name')

            if not module_name or not class_name:
                raise BusinessException(
                    ErrorCode.UNKNOWN_ERROR,
                    "Template service configuration is missing module or class",
                    {"template": template}
                )

        except (ResourceNotFoundException, BusinessException):
            raise
        except Exception as e:
            logger.error(f"Error finding template for {provider}/{model_type}: {e}", exc_info=True)
            raise BusinessException(
                ErrorCode.UNKNOWN_ERROR,
                "Failed to resolve service template",
                {"provider": provider, "type": model_type, "reason": str(e)}
            )

        service_id = f"{provider}-{request.model_id.replace('.', '-').replace('/', '-')}"

        config = {
            "id": service_id,
            "module": module_name,
            "class": class_name,
            "name": request.name,
            "args": [request.model_id],
            "type": model_type,
            "provider": provider,
            "enabled": True
        }

        if request.capabilities:
            config["capabilities"] = request.capabilities
        if request.variants:
            config["variants"] = request.variants
        if request.resolutions:
            config["resolutions"] = request.resolutions
        if request.durations:
            config["durations"] = request.durations
        if request.args:
            config["args"] = request.args

        src_dir = os.path.dirname(os.path.dirname(__file__))
        custom_config_path = os.path.join(src_dir, "config", "services", "custom_models.json")

        os.makedirs(os.path.dirname(custom_config_path), exist_ok=True)

        existing_configs = []
        if os.path.exists(custom_config_path):
            try:
                with open(custom_config_path, 'r', encoding='utf-8') as f:
                    existing_configs = json.load(f)
            except (json.JSONDecodeError, OSError) as e:
                logger.warning("Failed to load custom config %s: %s", custom_config_path, e)

        existing_configs = [c for c in existing_configs if c.get("id") != service_id]
        existing_configs.append(config)

        with open(custom_config_path, 'w', encoding='utf-8') as f:
            json.dump(existing_configs, f, indent=4, ensure_ascii=False)

        register_service(config)

        if REDIS_ENABLED:
            cache = get_cache_service()
            await cache.delete("config:models")
            await cache.delete("config:defaults")

        return ResponseModel(data={"status": "success", "model": config})

    except (ResourceNotFoundException, BusinessException):
        raise
    except Exception as e:
        logger.error(f"Failed to register model: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.UNKNOWN_ERROR,
            "Failed to register model",
            {"reason": str(e)}
        )


@router.get("/config/models", response_model=ResponseModel)
async def get_models_config():
    """Get all registered models configuration grouped by type (cached for 5 minutes)."""
    cache_key = "config:models"

    if REDIS_ENABLED:
        cache = get_cache_service()
        cached_models = await cache.get(cache_key)
        if cached_models is not None:
            return ResponseModel(data=cached_models)

    models = ModelRegistry.list_models()

    grouped = {
        "image": {},
        "video": {},
        "audio": {},
        "lyrics": {},
        "music": {},
        "llm": {}
    }

    for model_id, config_dict in models.items():
        model_config = config_dict.copy() if isinstance(config_dict, dict) else {}

        model_type_str = model_config.get("type")
        if not model_type_str or model_type_str not in grouped:
            continue

        try:
            model_type = ModelType(model_type_str.lower())
            default_id = ModelRegistry.get_default_id(model_type)
            model_config["is_default"] = (model_id == default_id)
        except (ValueError, KeyError):
            model_config["is_default"] = False

        model_config["id"] = model_id

        if "provider" not in model_config:
            if "/" in model_id:
                model_config["provider"] = model_id.split("/", 1)[0]

        if "model_key" not in model_config:
            if "/" in model_id:
                model_config["model_key"] = model_id.split("/", 1)[1]

        grouped[model_type_str][model_id] = model_config

    if REDIS_ENABLED:
        cache = get_cache_service()
        await cache.set(cache_key, grouped, ttl=300)

    return ResponseModel(data=grouped)


@router.get("/config/models/search", response_model=ResponseModel)
async def search_models(
    provider: Optional[str] = None,
    type: Optional[str] = None,
    enabled_only: bool = True
):
    """Search for models by criteria."""
    try:
        model_type = None
        if type:
            try:
                model_type = ModelType(type.lower())
            except ValueError:
                raise InvalidParameterException("type", f"Invalid model type: {type}")

        results = ModelRegistry.find_services(
            provider=provider,
            model_type=model_type,
            enabled_only=enabled_only
        )

        return ResponseModel(data={
            "count": len(results),
            "models": results
        })
    except InvalidParameterException:
        raise
    except Exception as e:
        logger.error(f"Failed to search models: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.UNKNOWN_ERROR,
            "Failed to search models",
            {"provider": provider, "type": type, "reason": str(e)}
        )


@router.get("/config/models/{model_id}", response_model=ResponseModel)
async def get_model_config(model_id: str):
    """Get detailed configuration for a specific model."""
    try:
        config = ModelRegistry.get_config(model_id)
        if not config:
            raise ResourceNotFoundException("model", model_id)

        return ResponseModel(data=config)
    except ResourceNotFoundException:
        raise
    except Exception as e:
        logger.error(f"Failed to get model config: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.UNKNOWN_ERROR,
            "Failed to get model configuration",
            {"model_id": model_id, "reason": str(e)}
        )


@router.get("/config/defaults", response_model=ResponseModel)
async def get_defaults():
    """Get default models for each type (cached for 5 minutes)."""
    cache_key = "config:defaults"

    if REDIS_ENABLED:
        cache = get_cache_service()
        cached_defaults = await cache.get(cache_key)
        if cached_defaults is not None:
            return ResponseModel(data=cached_defaults)

    try:
        defaults = {}
        for model_type in ModelType:
            default_id = ModelRegistry.get_default_id(model_type)
            if default_id:
                defaults[model_type.value] = {
                    "id": default_id,
                    "config": ModelRegistry.get_config(default_id)
                }

        if REDIS_ENABLED:
            cache = get_cache_service()
            await cache.set(cache_key, defaults, ttl=300)

        return ResponseModel(data=defaults)
    except Exception as e:
        logger.error(f"Failed to get defaults: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.UNKNOWN_ERROR,
            "Failed to get default models",
            {"reason": str(e)}
        )


@router.get("/config/provider-configs", response_model=ResponseModel)
async def get_provider_configs():
    """Get supported provider configurations for user API key management.

    Configurations are dynamically built from provider.json files,
    ensuring a single source of truth for provider metadata.
    """
    cache_key = "config:provider_configs"

    if REDIS_ENABLED:
        cache = get_cache_service()
        cached_configs = await cache.get(cache_key)
        if cached_configs is not None:
            return ResponseModel(data={"providers": cached_configs})

    # Dynamically build provider configs from registry metadata
    configs = _build_provider_configs()
    providers_data = [config.model_dump() for config in configs]

    if REDIS_ENABLED:
        cache = get_cache_service()
        await cache.set(cache_key, providers_data, ttl=300)

    return ResponseModel(data={"providers": providers_data})
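The read endpoints in this file all follow the same cache-aside pattern: try the cache, on a miss build the result, store it with `ttl=300`, and return it. A minimal sketch of that flow, using an in-memory stand-in for the Redis-backed cache service (the `InMemoryCache` class and the hard-coded provider list are illustrative assumptions, not the project's `cache_service`):

```python
import asyncio
import time
from typing import Any, Dict, Optional, Tuple


class InMemoryCache:
    """Illustrative async cache with per-key TTL, mimicking the get/set interface used above."""

    def __init__(self) -> None:
        self._store: Dict[str, Tuple[float, Any]] = {}

    async def get(self, key: str) -> Optional[Any]:
        entry = self._store.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if time.monotonic() > expires_at:
            del self._store[key]  # expired entries behave like misses
            return None
        return value

    async def set(self, key: str, value: Any, ttl: int) -> None:
        self._store[key] = (time.monotonic() + ttl, value)


cache = InMemoryCache()
build_calls = 0  # counts how often the expensive build path runs


async def get_providers_cached() -> Dict[str, Any]:
    global build_calls
    cached = await cache.get("config:providers")
    if cached is not None:
        return cached  # cache hit: skip the registry lookup entirely
    build_calls += 1
    result = {"providers": ["dashscope", "openai"]}  # stand-in for ModelRegistry.list_providers()
    await cache.set("config:providers", result, ttl=300)
    return result


async def main() -> None:
    first = await get_providers_cached()
    second = await get_providers_cached()
    print(first == second, build_calls)  # → True 1


asyncio.run(main())
```

The second call is served from the cache, so the build path runs once; the write endpoints above pair this with explicit `cache.delete(...)` calls to invalidate stale entries after mutations.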
24
backend/src/api/config/storage.py
Normal file
@@ -0,0 +1,24 @@
"""Storage-related endpoints."""

import logging

from fastapi import APIRouter
from pydantic import BaseModel

from src.models.schemas import ResponseModel
from src.utils.errors import ErrorCode
from src.services.storage_service import storage_manager

logger = logging.getLogger(__name__)
router = APIRouter()


class SignUrlRequest(BaseModel):
    """Request to sign a URL."""
    url: str


@router.post("/storage/sign-url", response_model=ResponseModel)
async def sign_url_endpoint(request: SignUrlRequest):
    """Re-sign an OSS URL."""
    new_url = storage_manager.sign_url(request.url)
    return ResponseModel(code=ErrorCode.SUCCESS, message="URL signed", data={"url": new_url})
208
backend/src/api/config/styles.py
Normal file
@@ -0,0 +1,208 @@
"""Style configuration endpoints."""

import json
import logging
import os
import uuid

from fastapi import APIRouter

from src.config.settings import REDIS_ENABLED
from src.models.schemas import ResponseModel
from src.services.cache_service import get_cache_service
from src.utils.errors import AppException, ErrorCode, ResourceNotFoundException, BusinessException

logger = logging.getLogger(__name__)
router = APIRouter()


@router.get("/config/styles", response_model=ResponseModel)
async def get_styles():
    """Get all style configurations (cached for 5 minutes)."""
    cache_key = "config:styles"

    if REDIS_ENABLED:
        cache = get_cache_service()
        cached_styles = await cache.get(cache_key)
        if cached_styles is not None:
            return ResponseModel(data=cached_styles)

    try:
        src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        styles_path = os.path.join(src_dir, "config", "styles.json")

        result = {"styles": []}
        if os.path.exists(styles_path):
            with open(styles_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                result = {"styles": data.get("styles", [])}

        if REDIS_ENABLED:
            cache = get_cache_service()
            await cache.set(cache_key, result, ttl=300)

        return ResponseModel(data=result)
    except Exception as e:
        logger.error(f"Failed to load styles: {e}")
        return ResponseModel(data={"styles": []})


@router.post("/config/styles", response_model=ResponseModel)
async def create_style(style: dict):
    """Create a new style configuration."""
    try:
        src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        styles_path = os.path.join(src_dir, "config", "styles.json")

        if not os.path.exists(styles_path):
            data = {"styles": []}
        else:
            with open(styles_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

        if 'id' not in style or not style['id']:
            style['id'] = str(uuid.uuid4())

        data["styles"].append(style)

        with open(styles_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        if REDIS_ENABLED:
            cache = get_cache_service()
            await cache.delete("config:styles")

        return ResponseModel(data=style)
    except Exception as e:
        logger.error(f"Failed to create style: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.put("/config/styles/{style_id}", response_model=ResponseModel)
async def update_style(style_id: str, style: dict):
    """Update an existing style configuration."""
    try:
        src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        styles_path = os.path.join(src_dir, "config", "styles.json")

        if not os.path.exists(styles_path):
            raise AppException(
                message="Styles configuration not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        with open(styles_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        styles = data.get("styles", [])
        updated = False
        for i, s in enumerate(styles):
            if s.get('id') == style_id:
                styles[i] = {**s, **style, "id": style_id}
                updated = True
                break

        if not updated:
            raise AppException(
                message="Style not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        data["styles"] = styles
        with open(styles_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        if REDIS_ENABLED:
            cache = get_cache_service()
            await cache.delete("config:styles")

        return ResponseModel(data=styles[i])
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Failed to update style: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.delete("/config/styles/{style_id}", response_model=ResponseModel)
async def delete_style(style_id: str):
    """Delete a style configuration."""
    try:
        src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        styles_path = os.path.join(src_dir, "config", "styles.json")

        if not os.path.exists(styles_path):
            raise ResourceNotFoundException("styles configuration", "styles.json")

        with open(styles_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        styles = data.get("styles", [])
        initial_len = len(styles)
        styles = [s for s in styles if s.get('id') != style_id]

        if len(styles) == initial_len:
            raise ResourceNotFoundException("style", style_id)

        data["styles"] = styles
        with open(styles_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        if REDIS_ENABLED:
            cache = get_cache_service()
            await cache.delete("config:styles")

        return ResponseModel(data={"status": "success"})
    except ResourceNotFoundException:
        raise
    except Exception as e:
        logger.error(f"Failed to delete style: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.UNKNOWN_ERROR,
            "Failed to delete style",
            {"style_id": style_id, "reason": str(e)}
        )


@router.get("/config/generation-options", response_model=ResponseModel)
async def get_generation_options():
    """Get generation options (cached for 5 minutes)."""
    cache_key = "config:generation_options"

    if REDIS_ENABLED:
        cache = get_cache_service()
        cached_options = await cache.get(cache_key)
        if cached_options is not None:
            return ResponseModel(data=cached_options)

    try:
        src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
        options_path = os.path.join(src_dir, "config", "generation_options.json")

        result = {}
        if os.path.exists(options_path):
            with open(options_path, 'r', encoding='utf-8') as f:
                result = json.load(f)

        if REDIS_ENABLED:
            cache = get_cache_service()
            await cache.set(cache_key, result, ttl=300)

        return ResponseModel(data=result)
    except Exception as e:
        logger.error(f"Failed to load generation options: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
131
backend/src/api/config/system.py
Normal file
131
backend/src/api/config/system.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""System configuration endpoints."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
|
|
||||||
|
from src.config.settings import REDIS_ENABLED
|
||||||
|
from src.models.schemas import ResponseModel
|
||||||
|
from src.services.cache_service import get_cache_service
|
||||||
|
from src.services.provider.registry import ModelRegistry, ModelType
|
||||||
|
from src.utils.errors import BusinessException, ErrorCode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
class SystemConfig(BaseModel):
|
||||||
|
"""System configuration model."""
|
||||||
|
default_image_model: Optional[str] = None
|
||||||
|
default_video_model: Optional[str] = None
|
||||||
|
default_audio_model: Optional[str] = None
|
||||||
|
default_llm_model: Optional[str] = None
|
||||||
|
default_style: Optional[str] = None
|
||||||
|
default_resolution: Optional[str] = None
|
||||||
|
default_ratio: Optional[str] = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config/system", response_model=ResponseModel)
|
||||||
|
async def get_system_config():
|
||||||
|
"""Get system configuration (cached for 5 minutes).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BusinessException: Failed to load system config
|
||||||
|
"""
|
||||||
|
cache_key = "config:system"
|
||||||
|
|
||||||
|
# Try cache first
|
||||||
|
if REDIS_ENABLED:
|
||||||
|
cache = get_cache_service()
|
||||||
|
cached_config = await cache.get(cache_key)
|
||||||
|
if cached_config is not None:
|
||||||
|
return ResponseModel(data=cached_config)
|
||||||
|
|
||||||
|
try:
|
||||||
|
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||||
|
config_path = os.path.join(src_dir, "config", "user_config.json")
|
||||||
|
|
||||||
|
config_data = {}
|
||||||
|
if os.path.exists(config_path):
|
||||||
|
with open(config_path, 'r', encoding='utf-8') as f:
|
||||||
|
config_data = json.load(f)
|
||||||
|
|
||||||
|
# Cache result for 5 minutes
|
||||||
|
if REDIS_ENABLED:
|
||||||
|
cache = get_cache_service()
|
||||||
|
await cache.set(cache_key, config_data, ttl=300)
|
||||||
|
|
||||||
|
return ResponseModel(data=config_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load system config: {e}", exc_info=True)
|
||||||
|
raise BusinessException(
|
||||||
|
ErrorCode.UNKNOWN_ERROR,
|
||||||
|
"Failed to load system configuration",
|
||||||
|
{"reason": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/config/system", response_model=ResponseModel)
async def update_system_config(config: SystemConfig):
    """Update system configuration.

    Raises:
        BusinessException: Failed to update system config
    """
    try:
        src_dir = os.path.dirname(os.path.dirname(__file__))
        config_path = os.path.join(src_dir, "config", "user_config.json")

        # Load existing to merge
        existing_data = {}
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                try:
                    existing_data = json.load(f)
                except json.JSONDecodeError as e:
                    logger.warning("Failed to parse existing config %s: %s", config_path, e)

        new_data = config.model_dump(exclude_unset=True)
        merged_data = {**existing_data, **new_data}

        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(merged_data, f, indent=2, ensure_ascii=False)

        # Update default models in registry if changed
        config_key_to_type = {
            'defaultImageModel': ModelType.IMAGE,
            'defaultVideoModel': ModelType.VIDEO,
            'defaultAudioModel': ModelType.AUDIO,
            'defaultLyricsModel': ModelType.LYRICS,
            'defaultMusicModel': ModelType.MUSIC,
            'defaultLLMModel': ModelType.LLM,
        }

        for config_key, model_type in config_key_to_type.items():
            if config_key in new_data:
                model_id = new_data[config_key]
                if model_id:
                    ModelRegistry.set_default_by_id(model_type, model_id)
                    logger.info(f"Updated default {model_type.value} model to: {model_id}")

        # Invalidate all related caches
        if REDIS_ENABLED:
            cache = get_cache_service()
            await cache.delete("config:system")
            await cache.delete("config:models")
            await cache.delete("config:defaults")

        return ResponseModel(data=merged_data)
    except Exception as e:
        logger.error(f"Failed to update system config: {e}", exc_info=True)
        raise BusinessException(
            ErrorCode.UNKNOWN_ERROR,
            "Failed to update system configuration",
            {"reason": str(e)}
        )
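The merge above is a shallow dict merge: top-level keys from the incoming payload overwrite the stored config, and nested objects are replaced wholesale rather than deep-merged. A minimal standalone sketch of that behavior (the sample keys and model IDs below are illustrative, not taken from the codebase):

```python
def merge_config(existing: dict, new: dict) -> dict:
    # Shallow merge: top-level keys in `new` win; nested dicts are
    # replaced wholesale, not merged key-by-key.
    return {**existing, **new}

existing = {"defaultImageModel": "dashscope/qwen-image",
            "ui": {"theme": "dark", "lang": "en"}}
new = {"defaultImageModel": "kling/v1-std", "ui": {"theme": "light"}}
merged = merge_config(existing, new)
# The whole "ui" object is replaced, so the old "lang" key is dropped.
```

This is why a client updating one nested field must send the complete nested object.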
79
backend/src/api/config/validation.py
Normal file
@@ -0,0 +1,79 @@
"""Configuration validation endpoints."""

import logging
import os
from typing import Dict, Any

from fastapi import APIRouter

from src.models.schemas import ResponseModel
from src.utils.errors import AppException, ErrorCode
from src.services.provider.validation import ConfigValidator, validate_all_configs

logger = logging.getLogger(__name__)
router = APIRouter()


@router.get("/config/validate", response_model=ResponseModel)
async def validate_configurations():
    """Validate all service configurations."""
    try:
        report = validate_all_configs()
        return ResponseModel(data=report)
    except Exception as e:
        logger.error(f"Failed to validate configurations: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/config/validate/service", response_model=ResponseModel)
async def validate_service_configuration(config: Dict[str, Any]):
    """Validate a single service configuration."""
    try:
        is_valid, errors = ConfigValidator.validate_config_deep(config)
        return ResponseModel(data={
            "valid": is_valid,
            "errors": errors,
            "config": config
        })
    except Exception as e:
        logger.error(f"Failed to validate service configuration: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.get("/config/validate/file/{filename}", response_model=ResponseModel)
async def validate_config_file(filename: str):
    """Validate a specific configuration file."""
    try:
        src_dir = os.path.dirname(os.path.dirname(__file__))
        config_path = os.path.join(src_dir, "config", "services", filename)

        if not os.path.exists(config_path):
            raise AppException(
                message=f"Configuration file not found: {filename}",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        is_valid, errors = ConfigValidator.validate_config_file(config_path)
        return ResponseModel(data={
            "filename": filename,
            "valid": is_valid,
            "errors": errors
        })
    except AppException:
        raise
    except Exception as e:
        logger.error(f"Failed to validate config file: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
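`ConfigValidator` itself is defined elsewhere in the backend; as a rough illustration of the file-validation pattern these endpoints delegate to (parse the JSON, then check required keys, returning `(is_valid, errors)`), one might sketch the following. The function body and the `REQUIRED_KEYS` schema are assumptions for illustration, not the project's actual validator:

```python
import json
from typing import List, Tuple

REQUIRED_KEYS = ("provider", "models")  # hypothetical schema, not the real one

def validate_config_file(path: str) -> Tuple[bool, List[str]]:
    """Return (is_valid, errors) for a JSON service config file."""
    errors: List[str] = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        return False, [f"unreadable config: {e}"]
    if not isinstance(data, dict):
        return False, ["config root must be a JSON object"]
    for key in REQUIRED_KEYS:
        if key not in data:
            errors.append(f"missing required key: {key}")
    return not errors, errors
```

Returning the error list (instead of raising on the first problem) is what lets the endpoint report every issue in a file at once.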
45
backend/src/api/generations/__init__.py
Normal file
@@ -0,0 +1,45 @@
"""
Generations Module - generation-related routes

Combines the routers of each submodule into a single unified router.

Submodules:
- script.py - script analysis endpoints
- tasks.py - task status endpoints
- image.py - image generation endpoints
- video.py - video generation endpoints
- audio.py - audio generation and export endpoints
- music.py - music generation endpoints
- batch.py - batch generation endpoints
"""
from fastapi import APIRouter

from .script import router as script_router
from .tasks import router as tasks_router
from .image import router as image_router
from .video import router as video_router
from .audio import router as audio_router
from .music import router as music_router
from .batch import router as batch_router

# Create the main router
router = APIRouter(tags=["generations"])

# Include all sub-routers
router.include_router(script_router)
router.include_router(tasks_router)
router.include_router(image_router)
router.include_router(video_router)
router.include_router(audio_router)
router.include_router(music_router)
router.include_router(batch_router)

# Re-export helper functions for use by other modules
from .helpers import resolve_service, ensure_url, get_available_styles_from_config

__all__ = [
    "router",
    "resolve_service",
    "ensure_url",
    "get_available_styles_from_config",
]
159
backend/src/api/generations/audio.py
Normal file
@@ -0,0 +1,159 @@
"""
Audio & Export Endpoints

Contains:
- POST /audio/generate - audio generation (legacy)
- POST /generations/audio - unified audio generation
- POST /export/project/{project_id} - project export
"""
import logging

from fastapi import APIRouter, Depends

from src.models.schemas import (
    ResponseModel,
    AudioGenerationRequest,
    GenerateAudioRequest,
    ExportProjectRequest
)
from src.services.provider.registry import ModelRegistry, ModelType
from src.services.task_manager import task_manager
from src.services.export_service import VideoExportService
from src.auth.dependencies import get_current_user, UserAuth

from .helpers import resolve_service, check_user_api_key
from src.utils.errors import ErrorCode, AppException

router = APIRouter()
logger = logging.getLogger(__name__)

# Initialize the export service
export_service = VideoExportService()


@router.post("/audio/generate", response_model=ResponseModel)
async def generate_audio(
    request: GenerateAudioRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Generate audio from text (legacy endpoint; internally routed through the unified task system)."""
    logger.info("Audio generation request (legacy): model=%s, text=%s...", request.model, request.text[:50])

    # Legacy compatibility: model may be empty, fall back to the default audio model
    model_name = request.model or ModelRegistry.get_default_id(ModelType.AUDIO)
    if not model_name:
        raise AppException(
            message="Audio service not configured",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    # Validate that the model resolves and obtain the service
    audio_service = resolve_service(model_name, ModelType.AUDIO)

    # Check whether the user has configured an API key
    # (provider_id is read from the resolved audio service)
    provider = getattr(audio_service, 'provider_id', None)
    if provider:
        check_user_api_key(current_user.id, provider)

    try:
        task_params = {
            "text": request.text,
            "voice": request.voice,
            "model": model_name,
            "project_id": request.project_id,
            "storyboard_id": request.storyboard_id,
            "extra_params": request.extra_params or {},
        }

        task = await task_manager.create_task(
            task_type="audio",
            model=model_name,
            params=task_params,
            user_id=current_user.id,
            project_id=request.project_id,
        )

        # Frontend compatibility: keep the audio_url field, initially empty for async tasks
        return ResponseModel(data={
            "audio_url": None,
            "task_id": task.id,
            "status": task.status
        })
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/generations/audio", response_model=ResponseModel)
async def generate_audio_unified(
    request: AudioGenerationRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Unified audio generation endpoint (recommended)."""
    logger.info("Audio generation request: model=%s, text=%s...", request.model, request.text[:50])

    audio_service = resolve_service(request.model, ModelType.AUDIO)
    if not audio_service:
        raise AppException(
            message="Audio service not configured",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    # Check whether the user has configured an API key
    # (provider_id is read from the resolved audio service)
    provider = getattr(audio_service, 'provider_id', None)
    if provider:
        check_user_api_key(current_user.id, provider)

    try:
        task_params = request.model_dump()
        # Get user_id from authenticated user
        user_id = current_user.id

        # Handle project_id from multiple sources
        # Priority 1: request.project_id (from request body, alias="projectId")
        # Priority 2: request.source_id if source == 'project' (from extra_params, alias="sourceId")
        project_id = request.project_id
        if not project_id and request.source == "project":
            project_id = request.source_id

        task = await task_manager.create_task(
            task_type="audio",
            model=request.model,
            params=task_params,
            user_id=user_id,
            project_id=project_id
        )
    except Exception as e:
        logger.error("Failed to create audio task: %s", e)
        raise AppException(
            message=f"Failed to create task: {e}",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    return ResponseModel(data={
        "task_id": task.id,
        "status": task.status
    })


@router.post("/export/project/{project_id}", response_model=ResponseModel)
async def export_project(project_id: str, request: ExportProjectRequest):
    """Export a project to video."""
    try:
        result = await export_service.export_project(project_id, format=request.format)
        return ResponseModel(data=result)
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
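The project_id fallback used by the unified endpoint (explicit body field first, then `source_id` only when the generation source is a project) can be distilled into a standalone sketch; the function below is illustrative and not part of the request model:

```python
from typing import Optional

def resolve_project_id(project_id: Optional[str],
                       source: Optional[str],
                       source_id: Optional[str]) -> Optional[str]:
    # Priority 1: explicit project_id from the request body
    if project_id:
        return project_id
    # Priority 2: source_id, but only when the generation source is a project
    if source == "project":
        return source_id
    return None
```

Keeping the fallback conditional on `source == "project"` prevents a canvas or storyboard `source_id` from being misused as a project ID.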
320
backend/src/api/generations/batch.py
Normal file
@@ -0,0 +1,320 @@
"""
Batch Generation Endpoints

Batch-related endpoints:
- POST /generations/batch - create generation tasks in bulk
"""
import logging
from typing import List, Dict, Any
from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field, ConfigDict

from src.models.schemas import (
    ResponseModel,
    ImageGenerationRequest,
    VideoGenerationRequest,
    AudioGenerationRequest,
    MusicGenerationRequest
)
from src.services.provider.registry import ModelRegistry, ModelType
from src.services.task_manager import task_manager
from src.auth.dependencies import get_current_user, UserAuth
from src.utils.errors import ErrorCode, AppException

from .helpers import resolve_service, ensure_url, check_user_api_key

router = APIRouter()
logger = logging.getLogger(__name__)


class BatchGenerationRequest(BaseModel):
    """Batch generation request."""
    items: List[Dict[str, Any]] = Field(..., description="List of generation tasks")

    model_config = ConfigDict(json_schema_extra={
        "example": {
            "items": [
                {
                    "type": "image",
                    "prompt": "A beautiful landscape",
                    "model": "wanx2.1-t2i-turbo"
                }
            ]
        }
    })


class BatchGenerationResponse(BaseModel):
    """Batch generation response."""
    task_ids: List[str] = Field(..., description="IDs of the created tasks")
    total: int = Field(..., description="Total number of tasks")
    created: int = Field(..., description="Number of tasks created successfully")
    failed: int = Field(..., description="Number of tasks that failed")
    errors: List[Dict[str, str]] = Field(default_factory=list, description="Error details")


@router.post("/generations/batch", response_model=ResponseModel)
async def generate_batch(
    request: BatchGenerationRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Create generation tasks in bulk.

    Supports creating multiple image, video, audio, and music generation tasks at once.

    Args:
        request: batch generation request containing the task list

    Returns:
        Batch generation response with the list of created task IDs
    """
    logger.info(f"Batch generation request: {len(request.items)} items from user {current_user.id}")

    if not request.items:
        raise AppException(
            message="No items provided for batch generation",
            code=ErrorCode.INVALID_PARAMETER,
            status_code=400
        )

    if len(request.items) > 20:
        raise AppException(
            message="Batch size exceeds maximum limit of 20",
            code=ErrorCode.INVALID_PARAMETER,
            status_code=400
        )

    task_ids = []
    errors = []

    for idx, item in enumerate(request.items):
        try:
            item_type = item.get("type", "image")

            if item_type == "image":
                task_id = await _create_image_task(item, current_user)
            elif item_type == "video":
                task_id = await _create_video_task(item, current_user)
            elif item_type == "audio":
                task_id = await _create_audio_task(item, current_user)
            elif item_type == "music":
                task_id = await _create_music_task(item, current_user)
            else:
                raise ValueError(f"Unsupported generation type: {item_type}")

            task_ids.append(task_id)
            logger.info(f"Created {item_type} task {task_id} for batch item {idx}")

        except Exception as e:
            logger.error(f"Failed to create task for batch item {idx}: {e}")
            errors.append({
                "index": str(idx),
                "error": str(e)
            })

    result = BatchGenerationResponse(
        task_ids=task_ids,
        total=len(request.items),
        created=len(task_ids),
        failed=len(errors),
        errors=errors
    )

    logger.info(f"Batch generation completed: {result.created}/{result.total} tasks created")

    return ResponseModel(data=result)


async def _create_image_task(item: Dict[str, Any], current_user: UserAuth) -> str:
    """Create an image generation task."""
    image_service = resolve_service(item.get("model"), ModelType.IMAGE)

    if not image_service:
        raise ValueError("Image service not configured")

    # Check the API key
    provider = getattr(image_service, 'provider_id', None)
    if provider:
        check_user_api_key(current_user.id, provider)

    # Resolve the image size
    aspect_ratio = item.get("aspect_ratio", "16:9")
    resolution = item.get("resolution", "1K")
    size = _resolve_image_size(image_service, aspect_ratio, resolution)

    # Build the parameters
    task_params = {
        "prompt": item.get("prompt", ""),
        "size": size,
        "aspect_ratio": aspect_ratio,
        "resolution": resolution,
        "image_inputs": [_sanitize_url(url) for url in item.get("image_inputs", []) if url],
        "extra_params": item.get("extra_params", {})
    }

    # Handle source-related fields
    if item.get("source"):
        task_params["source"] = item["source"]
    if item.get("source_id"):
        task_params["source_id"] = item["source_id"]
    if item.get("project_id"):
        task_params["project_id"] = item["project_id"]

    model_name = item.get("model") or getattr(image_service, "model_id", "default_image")

    task = await task_manager.create_task(
        task_type="image",
        model=model_name,
        params=task_params,
        user_id=current_user.id,
        project_id=item.get("project_id")
    )

    return task.id


async def _create_video_task(item: Dict[str, Any], current_user: UserAuth) -> str:
    """Create a video generation task."""
    video_service = resolve_service(item.get("model"), ModelType.VIDEO)

    if not video_service:
        raise ValueError("Video service not configured")

    # Check the API key
    provider = getattr(video_service, 'provider_id', None)
    if provider:
        check_user_api_key(current_user.id, provider)

    # Build the parameters
    task_params = {
        "prompt": item.get("prompt", ""),
        "duration": item.get("duration", 5),
        "resolution": item.get("resolution", "720p"),
        "image_inputs": [_sanitize_url(url) for url in item.get("image_inputs", []) if url],
        "extra_params": item.get("extra_params", {})
    }

    if item.get("source"):
        task_params["source"] = item["source"]
    if item.get("source_id"):
        task_params["source_id"] = item["source_id"]
    if item.get("project_id"):
        task_params["project_id"] = item["project_id"]

    model_name = item.get("model") or getattr(video_service, "model_id", "default_video")

    task = await task_manager.create_task(
        task_type="video",
        model=model_name,
        params=task_params,
        user_id=current_user.id,
        project_id=item.get("project_id")
    )

    return task.id


async def _create_audio_task(item: Dict[str, Any], current_user: UserAuth) -> str:
    """Create an audio generation task."""
    audio_service = resolve_service(item.get("model"), ModelType.AUDIO)

    if not audio_service:
        raise ValueError("Audio service not configured")

    # Check the API key
    provider = getattr(audio_service, 'provider_id', None)
    if provider:
        check_user_api_key(current_user.id, provider)

    task_params = {
        "text": item.get("text", item.get("prompt", "")),
        "voice": item.get("voice", "default"),
        "speed": item.get("speed", 1.0),
        "extra_params": item.get("extra_params", {})
    }

    if item.get("source"):
        task_params["source"] = item["source"]
    if item.get("source_id"):
        task_params["source_id"] = item["source_id"]

    model_name = item.get("model") or getattr(audio_service, "model_id", "default_audio")

    task = await task_manager.create_task(
        task_type="audio",
        model=model_name,
        params=task_params,
        user_id=current_user.id,
        project_id=item.get("project_id")
    )

    return task.id


async def _create_music_task(item: Dict[str, Any], current_user: UserAuth) -> str:
    """Create a music generation task."""
    music_service = resolve_service(item.get("model"), ModelType.MUSIC)

    if not music_service:
        raise ValueError("Music service not configured")

    # Check the API key
    provider = getattr(music_service, 'provider_id', None)
    if provider:
        check_user_api_key(current_user.id, provider)

    task_params = {
        "prompt": item.get("prompt", ""),
        "duration": item.get("duration", 30),
        "extra_params": item.get("extra_params", {})
    }

    if item.get("source"):
        task_params["source"] = item["source"]
    if item.get("source_id"):
        task_params["source_id"] = item["source_id"]

    model_name = item.get("model") or getattr(music_service, "model_id", "default_music")

    task = await task_manager.create_task(
        task_type="music",
        model=model_name,
        params=task_params,
        user_id=current_user.id,
        project_id=item.get("project_id")
    )

    return task.id


def _resolve_image_size(image_service, aspect_ratio: str, resolution: str) -> str:
    """Resolve the image size string."""
    model_config = getattr(image_service, "config", {})
    resolutions_config = model_config.get("resolutions", {})

    if resolutions_config and resolution in resolutions_config:
        ratio_map = resolutions_config[resolution]
        if isinstance(ratio_map, dict) and aspect_ratio in ratio_map:
            return ratio_map[aspect_ratio]

    # Default sizes
    defaults = {
        "1K": {
            "16:9": "1280*720",
            "9:16": "720*1280",
            "1:1": "1024*1024"
        },
        "2K": {
            "16:9": "2560*1440",
            "9:16": "1440*2560",
            "1:1": "2048*2048"
        }
    }
    return defaults.get(resolution, {}).get(aspect_ratio, "1024*1024")


def _sanitize_url(url: str) -> str:
    """Sanitize a URL."""
    if not url:
        return url
    return ensure_url(url)
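The fallback in `_resolve_image_size` is a two-level table lookup with a hard default: unknown resolutions or aspect ratios fall through to a square size. Isolated from the service config, the behavior looks like this (the table values mirror the defaults above):

```python
DEFAULT_SIZES = {
    "1K": {"16:9": "1280*720", "9:16": "720*1280", "1:1": "1024*1024"},
    "2K": {"16:9": "2560*1440", "9:16": "1440*2560", "1:1": "2048*2048"},
}

def resolve_size(resolution: str, aspect_ratio: str) -> str:
    # Chained .get() calls mean any missing level falls through
    # to the final "1024*1024" default instead of raising KeyError.
    return DEFAULT_SIZES.get(resolution, {}).get(aspect_ratio, "1024*1024")
```

The per-service `resolutions` config, when present, simply takes precedence over this table.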
396
backend/src/api/generations/helpers.py
Normal file
@@ -0,0 +1,396 @@
"""
Generation Helpers

Contains:
- resolve_service: service routing
- ensure_url: URL normalization
- get_available_styles_from_config: style configuration lookup
"""
import logging
import base64
import uuid
import os
import json
from functools import lru_cache
from typing import List, Optional

from src.services.provider.registry import ModelRegistry, ModelType
from src.services.storage_service import storage_manager
from src.config.settings import REDIS_ENABLED
from src.services.cache_service import get_cache_service
from src.utils.errors import ErrorCode, AppException
from src.services.user_api_key_service import user_api_key_service

logger = logging.getLogger(__name__)


def _is_test_env() -> bool:
    return (
        os.getenv("PYTEST_CURRENT_TEST") is not None
        or os.getenv("PIXEL_TEST_MODE") == "1"
        or os.getenv("NODE_ENV") == "test"
    )


def _ensure_model_registry_loaded() -> None:
    if ModelRegistry.list_models():
        return

    from src.utils.service_loader import load_services_from_config

    services_config_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "config",
        "services",
    )
    load_services_from_config(services_config_path)


def _reload_model_registry() -> None:
    from src.utils.service_loader import load_services_from_config

    services_config_path = os.path.join(
        os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
        "config",
        "services",
    )
    load_services_from_config(services_config_path)


@lru_cache(maxsize=128)
def resolve_service(
    model: str,
    model_type: ModelType
):
    """
    Route a composite model ID to the correct service.

    Routing logic:
    1. Validate that model is in 'provider/model_key' format
    2. Parse out provider and model_key
    3. Try a direct lookup by composite ID
    4. Search the provider for a matching model_key
    5. If nothing matches, raise AppException (404)

    Args:
        model: composite ID in the form "provider/model_key" (e.g. 'dashscope/qwen-image')
        model_type: model type (IMAGE, VIDEO, AUDIO, etc.)

    Returns:
        The service instance

    Raises:
        ValueError: when the model format is invalid
        AppException: when the model does not exist
    """
    _ensure_model_registry_loaded()

    # 1. Validate the format
    if not model or '/' not in model:
        raise ValueError(
            f"Model must be in format 'provider/model_key', got: '{model}'. "
            f"Example: 'dashscope/qwen-image'"
        )

    # 2. Parse provider and model_key
    parts = model.split('/', 1)
    if len(parts) != 2:
        raise ValueError(
            f"Model format invalid: '{model}'. "
            f"Must have exactly one '/' separator."
        )

    provider, model_key = parts
    if not provider or not model_key:
        raise ValueError(
            f"Model format invalid: '{model}'. "
            f"Both provider and model_key must be non-empty."
        )

    logger.info(f"Resolving service: model='{model}', type={model_type.value}")

    # 3. Option 1: direct lookup by composite ID
    service = ModelRegistry.get(model)
    if service:
        logger.info(f"Found service by composite ID: '{model}'")
        return service

    # 4. Option 2: search the provider for the model_key
    services = ModelRegistry.find_services(provider=provider, model_type=model_type)
    matching = [
        s for s in services
        if s.get('model_key') == model_key or s.get('id') == model
    ]

    if matching:
        service = ModelRegistry.get(matching[0].get('id'))
        logger.info(
            f"Found service by provider+model_key: "
            f"provider='{provider}', model_key='{model_key}'"
        )
        return service

    # 4.5 The registry may have been temporarily overridden by tests or at runtime;
    # force a single reload and retry
    _reload_model_registry()

    service = ModelRegistry.get(model)
    if service:
        logger.info(f"Found service by composite ID after registry reload: '{model}'")
        return service

    services = ModelRegistry.find_services(provider=provider, model_type=model_type)
    matching = [
        s for s in services
        if s.get('model_key') == model_key or s.get('id') == model
    ]
    if matching:
        service = ModelRegistry.get(matching[0].get('id'))
        logger.info(
            f"Found service by provider+model_key after registry reload: "
            f"provider='{provider}', model_key='{model_key}'"
        )
        return service

    # 5. Not found
    logger.error(f"Model '{model}' not found")
    raise AppException(
        message=f"Model '{model}' not found. Available models can be fetched from /api/v1/models",
        code=ErrorCode.NOT_FOUND,
        status_code=404
    )
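The format checks in steps 1-2 of `resolve_service` reduce to a single guarded split. Isolated as a standalone sketch (the helper name is illustrative):

```python
def parse_composite_id(model: str) -> tuple:
    """Split 'provider/model_key' into its two parts, rejecting malformed IDs."""
    if not model or '/' not in model:
        raise ValueError(f"Model must be in format 'provider/model_key', got: '{model}'")
    # maxsplit=1 keeps any further slashes inside the model_key
    provider, model_key = model.split('/', 1)
    if not provider or not model_key:
        raise ValueError(f"Both provider and model_key must be non-empty: '{model}'")
    return provider, model_key
```

Because `split('/', 1)` is used, a model_key may itself contain slashes; only the first separator is structural.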
def ensure_url(url_or_base64: str) -> str:
    """
    Ensure the input is a valid URL that external services can access.
    - If it's a relative path and using local storage, convert to Base64.
    - If it's a relative path and using OSS, upload to OSS.
    - If it's a Base64 string, return as is.
    - If it's a localhost URL, read the file from disk and convert to Base64 or upload to OSS.
    - If it's already a public URL, sign it if needed.
    """
    if not url_or_base64:
        return url_or_base64

    # Check if it's a relative path (starts with /files/ or /uploads/)
    if url_or_base64.startswith('/files/') or url_or_base64.startswith('/uploads/'):
        try:
            from src.config.settings import DATA_DIR, UPLOAD_DIR, STORAGE_TYPE

            logger.info(f"[ensure_url] Detected relative path: {url_or_base64}")

            # Determine the file path
            if url_or_base64.startswith('/files/'):
                rel_path = url_or_base64[7:]  # Remove '/files/'
                file_path = os.path.join(DATA_DIR, rel_path)
            else:  # /uploads/
                rel_path = url_or_base64[9:]  # Remove '/uploads/'
                file_path = os.path.join(UPLOAD_DIR, rel_path)

            # Check if the file exists
            if not os.path.exists(file_path):
                logger.error(f"[ensure_url] File not found: {file_path}")
                return url_or_base64

            # Read the file
            with open(file_path, 'rb') as f:
                file_data = f.read()

            # Determine extension and MIME type
            ext = os.path.splitext(file_path)[1].lstrip('.') or 'png'
            mime_type = f"image/{ext}"
            if ext == 'jpg':
                mime_type = "image/jpeg"

            # If using OSS, upload the file to OSS
            if STORAGE_TYPE == 'oss':
                filename = f"temp/{uuid.uuid4()}.{ext}"
                oss_url = storage_manager.save(filename, file_data)
                logger.info(f"[ensure_url] Uploaded to OSS: {oss_url}")
                return oss_url

            # If using local storage, convert to Base64
            else:
                base64_data = base64.b64encode(file_data).decode('utf-8')
                base64_url = f"data:{mime_type};base64,{base64_data}"
                logger.info(f"[ensure_url] Converted to Base64 (length: {len(base64_url)})")
                return base64_url

        except Exception as e:
            logger.error(f"[ensure_url] Failed to convert relative path: {e}")
            return url_or_base64

    # Check if it's a localhost URL: read from disk instead of HTTP to prevent SSRF
    if any(host in url_or_base64.lower() for host in ['localhost', '127.0.0.1', '0.0.0.0']):
        try:
            from src.config.settings import DATA_DIR, UPLOAD_DIR, STORAGE_TYPE
            from urllib.parse import urlparse

            logger.info(f"[ensure_url] Detected localhost URL, reading from disk: {url_or_base64}")

            # Extract the path component (e.g. /files/xxx or /uploads/xxx)
            parsed = urlparse(url_or_base64)
            path = parsed.path

            # Map the URL path to the local filesystem
            if path.startswith('/files/'):
                rel_path = path[7:]
                file_path = os.path.join(DATA_DIR, rel_path)
            elif path.startswith('/uploads/'):
                rel_path = path[9:]
                file_path = os.path.join(UPLOAD_DIR, rel_path)
            else:
                logger.warning(f"[ensure_url] Unsupported localhost path: {path}")
                return url_or_base64

            if not os.path.exists(file_path):
                logger.error(f"[ensure_url] File not found: {file_path}")
                return url_or_base64

            with open(file_path, 'rb') as f:
                file_data = f.read()

            ext = os.path.splitext(file_path)[1].lstrip('.') or 'png'
            mime_type = f"image/{ext}"
            if ext == 'jpg':
                mime_type = "image/jpeg"

            if STORAGE_TYPE == 'oss':
                filename = f"temp/{uuid.uuid4()}.{ext}"
                saved_url = storage_manager.save(filename, file_data)
                logger.info(f"[ensure_url] Uploaded to OSS: {saved_url}")
                return saved_url
            else:
                base64_data = base64.b64encode(file_data).decode('utf-8')
                base64_url = f"data:{mime_type};base64,{base64_data}"
                logger.info(f"[ensure_url] Converted to Base64 (length: {len(base64_url)})")
                return base64_url

        except Exception as e:
            logger.error(f"[ensure_url] Failed to convert localhost URL: {e}")
            return url_or_base64

    # Check if Base64 - return as is (already in the correct format)
    if url_or_base64.startswith("data:"):
        logger.info("[ensure_url] Already Base64 format, returning as is")
        return url_or_base64

    # If it's already a public URL, check if it belongs to our OSS bucket
    if url_or_base64.startswith(('http://', 'https://')):
        from src.config.settings import STORAGE_TYPE
        from src.utils.oss_utils import OSS_BUCKET

        # Only sign URLs from our own OSS bucket; external URLs should pass through as-is
        if STORAGE_TYPE == 'oss' and OSS_BUCKET and f"{OSS_BUCKET}." in url_or_base64:
            return storage_manager.sign_url(url_or_base64) or url_or_base64
        else:
            # External public URL (e.g. dashscope-result, other CDN), return as-is
            logger.info("[ensure_url] External public URL, returning as-is")
|
||||||
|
return url_or_base64
|
||||||
|
|
||||||
|
# Fallback: sign it if needed
|
||||||
|
return storage_manager.sign_url(url_or_base64) or url_or_base64
|
||||||
|
|
||||||
|
|
||||||
|
async def get_available_styles_from_config() -> List[dict]:
|
||||||
|
""" 获取 available style configurations from backend config file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of style dicts with full information (id, name, desc, type, etc.)
|
||||||
|
Example: [{'id': 'cyberpunk', 'name': '赛博朋克', 'desc': '高对比度霓虹灯光...', ...}, ...]
|
||||||
|
"""
|
||||||
|
cache_key = "config:styles_full"
|
||||||
|
|
||||||
|
# Cache first (5 minutes TTL)
|
||||||
|
if REDIS_ENABLED:
|
||||||
|
cache = get_cache_service()
|
||||||
|
cached_styles = await cache.get(cache_key)
|
||||||
|
if cached_styles is not None:
|
||||||
|
return cached_styles
|
||||||
|
|
||||||
|
# Get from config file
|
||||||
|
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||||
|
|
||||||
|
styles_list = []
|
||||||
|
if os.path.exists(styles_path):
|
||||||
|
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||||
|
data = json.load(f)
|
||||||
|
styles_list = data.get('styles', [])
|
||||||
|
|
||||||
|
# Fallback to common styles with descriptions if config doesn't exist
|
||||||
|
if not styles_list:
|
||||||
|
styles_list = [
|
||||||
|
{"id": "cyberpunk", "name": "赛博朋克", "desc": "高对比度霓虹灯光,未来主义建筑,机械改造"},
|
||||||
|
{"id": "ink", "name": "水墨风", "desc": "传统中国水墨渲染,黑白灰为主,意境深远"},
|
||||||
|
{"id": "pixar", "name": "皮克斯风格", "desc": "3D卡通渲染,色彩鲜艳,光影柔和,表情夸张"},
|
||||||
|
{"id": "anime", "name": "日漫", "desc": "典型日本动画风格,线条清晰,赛璐璐上色"},
|
||||||
|
{"id": "chinese-anime", "name": "国漫风", "desc": "中国现代动画风格,融合传统与现代元素"},
|
||||||
|
{"id": "realistic", "name": "写实", "desc": "电影级写实渲染,细节丰富,光照真实"},
|
||||||
|
{"id": "hand-drawn", "name": "手绘", "desc": "传统手绘质感,笔触明显,艺术感强"},
|
||||||
|
{"id": "watercolor", "name": "水彩画", "desc": "水彩晕染效果,色彩通透,艺术感强"},
|
||||||
|
{"id": "oil-painting", "name": "油画", "desc": "厚涂油画质感,笔触丰富,光影层次感强"},
|
||||||
|
{"id": "pixel-art", "name": "像素风", "desc": "复古8-bit/16-bit像素艺术,怀旧游戏风格"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# 缓存 for 5 minutes
|
||||||
|
if REDIS_ENABLED:
|
||||||
|
cache = get_cache_service()
|
||||||
|
await cache.set(cache_key, styles_list, ttl=300)
|
||||||
|
|
||||||
|
return styles_list
|
||||||
|
|
||||||
|
|
||||||
|
def check_user_api_key(user_id: str, provider: str) -> None:
|
||||||
|
"""
|
||||||
|
检查系统是否配置了指定 Provider 的 API Key。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID(保留参数兼容性)
|
||||||
|
provider: 服务商 ID (如 'dashscope', 'openai', 'kling' 等)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AppException: 如果系统未配置该 Provider 的 API Key
|
||||||
|
"""
|
||||||
|
if _is_test_env():
|
||||||
|
return
|
||||||
|
|
||||||
|
import os
|
||||||
|
env_map = {
|
||||||
|
"dashscope": "DASHSCOPE_API_KEY",
|
||||||
|
"openai": "OPENAI_API_KEY",
|
||||||
|
"kling": "KLING_ACCESS_KEY",
|
||||||
|
"midjourney": "MIDJOURNEY_API_KEY",
|
||||||
|
"modelscope": "MODELSCOPE_API_TOKEN",
|
||||||
|
"volcengine": "VOLCENGINE_API_KEY",
|
||||||
|
"minimax": "MINIMAX_API_KEY",
|
||||||
|
"google": "GOOGLE_API_KEY",
|
||||||
|
}
|
||||||
|
env_var = env_map.get(provider)
|
||||||
|
if env_var and os.getenv(env_var):
|
||||||
|
return
|
||||||
|
|
||||||
|
provider_names = {
|
||||||
|
"dashscope": "阿里云 DashScope",
|
||||||
|
"openai": "OpenAI",
|
||||||
|
"kling": "可灵 AI",
|
||||||
|
"midjourney": "Midjourney",
|
||||||
|
"modelscope": "ModelScope",
|
||||||
|
"volcengine": "火山引擎",
|
||||||
|
"minimax": "MiniMax",
|
||||||
|
"google": "Google",
|
||||||
|
}
|
||||||
|
provider_name = provider_names.get(provider, provider)
|
||||||
|
|
||||||
|
raise AppException(
|
||||||
|
message=f"未配置 {provider_name} 的 API Key",
|
||||||
|
code=ErrorCode.INVALID_PARAMETER,
|
||||||
|
status_code=400,
|
||||||
|
details={
|
||||||
|
"error": "API_KEY_NOT_CONFIGURED",
|
||||||
|
"message": f"尚未配置 {provider_name} 的 API Key,无法提交生成任务。",
|
||||||
|
"provider": provider,
|
||||||
|
"provider_name": provider_name,
|
||||||
|
"solution": "请联系管理员在 .env 中配置对应的 API Key。",
|
||||||
|
}
|
||||||
|
)
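The provider gate above reduces to a small pure lookup: map the provider ID to an environment variable and test that it is set and non-empty. A minimal self-contained sketch of that rule (the name `has_provider_key` and the trimmed map are illustrative, not part of the codebase):

```python
import os

# Subset of the provider -> environment-variable map used above (illustrative)
ENV_MAP = {
    "dashscope": "DASHSCOPE_API_KEY",
    "openai": "OPENAI_API_KEY",
    "kling": "KLING_ACCESS_KEY",
}

def has_provider_key(provider: str) -> bool:
    """True if the provider maps to an environment variable that is set and non-empty."""
    env_var = ENV_MAP.get(provider)
    return bool(env_var and os.getenv(env_var))

# Deterministic demo environment
os.environ["OPENAI_API_KEY"] = "sk-test"
os.environ.pop("KLING_ACCESS_KEY", None)

print(has_provider_key("openai"))   # True
print(has_provider_key("kling"))    # False (variable not set)
print(has_provider_key("unknown"))  # False (provider not in the map)
```

In the real `check_user_api_key`, a `False` result raises an `AppException` with an `API_KEY_NOT_CONFIGURED` detail rather than returning a boolean.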
208
backend/src/api/generations/image.py
Normal file
@@ -0,0 +1,208 @@
"""
Image Generation Endpoints

Endpoints related to image generation:
- POST /generations/image - generate an image
- POST /image/upscale - upscale an image
"""
import logging

from fastapi import APIRouter, Depends

from src.models.schemas import (
    ResponseModel,
    ImageGenerationRequest,
    UpscaleImageRequest, UpscaleImageResponse
)
from src.services.provider.registry import ModelRegistry, ModelType
from src.services.task_manager import task_manager
from src.auth.dependencies import get_current_user, UserAuth

from .helpers import resolve_service, ensure_url, check_user_api_key
from src.utils.errors import ErrorCode, AppException

router = APIRouter()
logger = logging.getLogger(__name__)


@router.post("/generations/image", response_model=ResponseModel)
async def generate_image(
    request: ImageGenerationRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Generic image generation endpoint.

    Supports generation for any context (storyboard, asset, etc.).
    """
    logger.info(f"Image generation request: model={request.model}, prompt={request.prompt[:50]}...")

    # 1. Resolve the model via the unified routing logic
    image_service = resolve_service(request.model, ModelType.IMAGE)

    if not image_service:
        raise AppException(
            message="Image service not configured",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    # 1.5 Check that an API key is configured for this provider
    # (provider_id is taken from the resolved image_service)
    provider = getattr(image_service, 'provider_id', None)
    if provider:
        check_user_api_key(current_user.id, provider)

    # 2. Resolve size from aspect_ratio and resolution (BEFORE creating the task)
    aspect_ratio = request.aspect_ratio
    size = None

    # Debug logging
    logger.info("[ImageGeneration] Request received:")
    logger.info(f"  - Prompt: {request.prompt[:100]}...")
    logger.info(f"  - Model: {request.model}")
    logger.info(f"  - Aspect Ratio: {aspect_ratio}")
    logger.info(f"  - Resolution: {request.resolution}")
    logger.info(f"  - Image inputs count: {len(request.image_inputs or [])}")
    if request.image_inputs:
        for i, url in enumerate(request.image_inputs):
            logger.info(f"  - Image input {i}: {url[:100]}...")

    if aspect_ratio:
        # Look up the resolution in the model config
        model_config = getattr(image_service, "config", {})
        resolutions_config = model_config.get("resolutions", {})

        # Use the provided resolution level (e.g. "1K", "2K", "4K") or default to "1K"
        res_level = request.resolution or "1K"

        if resolutions_config and res_level in resolutions_config and isinstance(resolutions_config[res_level], dict):
            ratio_map = resolutions_config[res_level]
            if aspect_ratio in ratio_map:
                size = ratio_map[aspect_ratio]
                logger.info(f"  - Resolved size from config: {size} (resolution={res_level}, ratio={aspect_ratio})")

        if not size:
            defaults = {
                "1K": {
                    "16:9": "1280*720",
                    "9:16": "720*1280",
                    "1:1": "1024*1024",
                    "4:3": "1280*960",
                    "3:4": "960*1280",
                    "2.35:1": "1280*544"
                },
                "2K": {
                    "16:9": "2560*1440",
                    "9:16": "1440*2560",
                    "1:1": "2048*2048",
                    "4:3": "2560*1920",
                    "3:4": "1920*2560"
                },
                "4K": {
                    "16:9": "3840*2160",
                    "9:16": "2160*3840",
                    "1:1": "4096*4096",
                    "4:3": "3840*2880",
                    "3:4": "2880*3840"
                }
            }
            size = defaults.get(res_level, {}).get(aspect_ratio)
            if size:
                logger.info(f"  - Resolved size from defaults: {size} (resolution={res_level}, ratio={aspect_ratio})")

    # 3. Create the task in the DB with the resolved params
    try:
        model_name = request.model or getattr(image_service, "model_id", "default_image")

        # Build the params dict with the resolved size
        task_params = request.model_dump()
        task_params["size"] = size  # Override with resolved size (provider expects 'size')

        # lora_strength is already in extra_params (passed by the frontend); no extra handling needed.
        # TaskManager extracts it from extra_params and merges it into the top-level params.

        # Sanitize image URLs (convert relative paths, localhost URLs, Base64)
        logger.info(f"[ImageGeneration] Before URL sanitization - image_inputs: {task_params.get('image_inputs')}")

        for key in ["image_inputs"]:
            if key in task_params and isinstance(task_params[key], list):
                original_urls = task_params[key]
                task_params[key] = [ensure_url(url) for url in task_params[key] if url]
                logger.info(f"[ImageGeneration] Sanitized {key}: {original_urls} -> {task_params[key]}")

        # Use the unified task manager.
        # Get user_id from the authenticated user.
        user_id = current_user.id

        # Handle project_id from multiple sources:
        # Priority 1: request.project_id (from the request body, alias="projectId")
        # Priority 2: request.source_id if source == 'project' (from extra_params, alias="sourceId")
        project_id = request.project_id
        if not project_id and request.source == 'project':
            project_id = request.source_id

        task = await task_manager.create_task(
            task_type="image",
            model=model_name,
            params=task_params,
            user_id=user_id,
            project_id=project_id
        )
    except Exception as e:
        logger.error(f"Failed to create task: {e}")
        raise AppException(
            message=f"Failed to create task: {e}",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    # Log the context if provided
    if request.source and request.source_id:
        logger.info(f"Image generation started for {request.source}:{request.source_id}, task_id: {task.id}")

    # The unified task manager handles execution automatically
    return ResponseModel(data={
        "task_id": task.id,
        "status": task.status
    })


@router.post("/image/upscale", response_model=ResponseModel)
async def upscale_image(request: UpscaleImageRequest):
    """Upscale image resolution."""
    upscale_service = ModelRegistry.get_default(ModelType.UPSCALE)
    if not upscale_service:
        raise AppException(
            message="Default upscale service not configured",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    try:
        # Prepare kwargs
        upscale_kwargs = {
            "image_url": request.image_url,
            "rate": request.rate
        }
        if request.extra_params:
            upscale_kwargs.update(request.extra_params)

        upscaled_url = await upscale_service.upscale_image(**upscale_kwargs)

        if not upscaled_url:
            raise AppException(
                message="Upscaling failed or returned no URL",
                code=ErrorCode.INTERNAL_ERROR,
                status_code=500
            )

        return ResponseModel(data=UpscaleImageResponse(
            original_url=request.image_url,
            upscaled_url=upscaled_url
        ))
    except AppException:
        # Don't re-wrap the "no URL" error above
        raise
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
82
backend/src/api/generations/music.py
Normal file
@@ -0,0 +1,82 @@
"""
Music Generation Endpoints

Endpoints related to music generation:
- POST /generations/music - generate music
"""
import logging

from fastapi import APIRouter, Depends

from src.models.schemas import ResponseModel, MusicGenerationRequest
from src.services.provider.registry import ModelType
from src.services.task_manager import task_manager
from src.auth.dependencies import get_current_user, UserAuth

from .helpers import resolve_service, check_user_api_key
from src.utils.errors import ErrorCode, AppException

router = APIRouter()
logger = logging.getLogger(__name__)


@router.post("/generations/music", response_model=ResponseModel)
async def generate_music(
    request: MusicGenerationRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Unified music/lyrics generation endpoint."""
    logger.info("Music unified generation request: mode=%s, model=%s", request.generation_mode, request.model)

    # 1) Resolve the model by mode
    model_type = ModelType.LYRICS if request.generation_mode == "lyrics" else ModelType.MUSIC
    music_service = resolve_service(request.model, model_type)

    if not music_service:
        raise AppException(
            message=f"{request.generation_mode} service not configured",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    # 1.5) Check that an API key is configured for this provider
    # (provider_id is taken from music_service, now known to exist)
    provider = getattr(music_service, 'provider_id', None)
    if provider:
        check_user_api_key(current_user.id, provider)

    # 2) Unified async task for both lyrics and music modes
    try:
        model_name = request.model or getattr(music_service, "model_id", "default_music")
        task_params = request.model_dump()

        # Handle project_id from multiple sources:
        # Priority 1: request.project_id (from the request body, alias="projectId")
        # Priority 2: request.source_id if source == 'project' (from extra_params, alias="sourceId")
        project_id = request.project_id
        if not project_id and request.source == "project":
            project_id = request.source_id

        # Get user_id from the authenticated user
        user_id = current_user.id

        task = await task_manager.create_task(
            task_type="music",
            model=model_name,
            params=task_params,
            user_id=user_id,
            project_id=project_id
        )
    except Exception as e:
        logger.error("Failed to create unified music task: %s", e)
        raise AppException(
            message=f"Failed to create task: {e}",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    return ResponseModel(data={
        "task_id": task.id,
        "status": task.status,
        "mode": request.generation_mode
    })
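The mode routing above is a single conditional: only `generation_mode == "lyrics"` selects the lyrics model type; every other mode falls through to music. A minimal sketch with a simplified stand-in enum (the real `ModelType` lives in `src.services.provider.registry`):

```python
from enum import Enum

class ModelType(Enum):
    # Simplified stand-in for src.services.provider.registry.ModelType
    MUSIC = "music"
    LYRICS = "lyrics"

def pick_model_type(generation_mode: str) -> ModelType:
    """Mirror of the routing in generate_music: only 'lyrics' selects LYRICS."""
    return ModelType.LYRICS if generation_mode == "lyrics" else ModelType.MUSIC

print(pick_model_type("lyrics"))  # ModelType.LYRICS
print(pick_model_type("music"))   # ModelType.MUSIC
print(pick_model_type("other"))   # ModelType.MUSIC (any non-'lyrics' mode falls through)
```

Because the fall-through is unconditional, an unrecognized `generation_mode` is treated as a music request rather than rejected; validation of the mode is left to the request schema.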
407
backend/src/api/generations/script.py
Normal file
@@ -0,0 +1,407 @@
"""
Script Analysis Endpoints

All endpoints related to LLM script analysis:
- /script/analyze - analyze novel text
- /prompt/optimize - optimize a prompt
- /style/recommend - recommend a style
- /script/summary - summarize a novel
- /script/characters - extract characters
- /script/scenes - extract scenes
- /script/props - extract props
- /script/storyboards - split into storyboards
- /script/split - split into chapters
"""
import logging
import uuid

from fastapi import APIRouter, BackgroundTasks, Depends

from src.auth.dependencies import get_current_user
from src.auth.models import UserAuth
from src.models.schemas import (
    ResponseModel,
    ScriptProcessRequest, ScriptResponse,
    PromptOptimizationRequest, PromptOptimizationResponse,
    StyleRecommendationRequest, StyleRecommendationResponse,
    NovelSummaryRequest, NovelSummaryResponse,
    CharacterExtractionRequest, CharacterExtractionResponse,
    SceneExtractionRequest, SceneExtractionResponse,
    PropExtractionRequest, PropExtractionResponse,
    StoryboardSplitRequest, StoryboardSplitResponse,
    ChapterSplitRequest, ChapterSplitResponse,
    CharacterAsset, SceneAsset, PropAsset
)
from src.services.script import script_service
from src.services.project_service import project_manager

from .helpers import get_available_styles_from_config
from src.utils.errors import ErrorCode, AppException

router = APIRouter()
logger = logging.getLogger(__name__)


@router.post("/script/analyze", response_model=ResponseModel)
async def analyze_script(
    request: ScriptProcessRequest,
    background_tasks: BackgroundTasks,
    current_user: UserAuth = Depends(get_current_user)
):
    """Analyze novel text and generate a script."""
    try:
        # Handle skip_storyboard priority (explicit param vs extra_params)
        skip_sb = request.skip_storyboard
        extra = request.extra_params or {}
        if 'skip_storyboard' in extra:
            skip_sb = extra.pop('skip_storyboard')

        response_data = await script_service.analyze_novel(
            novel_text=request.novel_text,
            project_id=request.project_id,
            model_name=request.model,
            max_input_tokens=request.max_input_tokens,
            skip_storyboard=skip_sb,
            user_id=current_user.id,
            **extra
        )

        # Automatically save extracted assets to the project
        if request.project_id:
            # Fetch existing project assets for deduplication
            existing_assets_map = {}
            try:
                current_project = project_manager.get_project(request.project_id)
                if current_project and current_project.assets:
                    for asset in current_project.assets:
                        existing_assets_map[(asset.type, asset.name)] = asset
            except Exception as e:
                logger.warning(f"Failed to fetch project for deduplication: {e}")

            # Helper to create and add an asset
            def save_asset(item: dict, asset_type: str):
                name = item.get('name')
                if not name:
                    return

                # Check whether the asset already exists
                existing_asset = existing_assets_map.get((asset_type, name))

                if existing_asset:
                    logger.info(f"Asset {name} ({asset_type}) already exists, updating.")
                    asset_id = existing_asset.id
                else:
                    asset_id = str(uuid.uuid4())

                # Create the specific asset model
                try:
                    asset = None
                    if asset_type == 'character':
                        asset = CharacterAsset(
                            id=asset_id,
                            name=item.get('name'),
                            desc=item.get('desc', ''),
                            tags=item.get('tags', []),
                            age=item.get('age'),
                            role=item.get('role'),
                            appearance=item.get('appearance'),
                            image_prompt=item.get('image_prompt')
                        )
                    elif asset_type == 'scene':
                        asset = SceneAsset(
                            id=asset_id,
                            name=item.get('name'),
                            desc=item.get('desc', ''),
                            tags=item.get('tags', []),
                            location=item.get('location'),
                            time_of_day=item.get('time_of_day'),
                            atmosphere=item.get('atmosphere'),
                            image_prompt=item.get('image_prompt')
                        )
                    elif asset_type == 'prop':
                        asset = PropAsset(
                            id=asset_id,
                            name=item.get('name'),
                            desc=item.get('desc', ''),
                            tags=item.get('tags', []),
                            usage=item.get('usage'),
                            image_prompt=item.get('image_prompt')
                        )
                    else:
                        return

                    if existing_asset:
                        project_manager.update_asset(request.project_id, asset_id, asset)
                    else:
                        project_manager.add_asset(request.project_id, asset)

                except Exception as e:
                    logger.error(f"Failed to save asset {item.get('name')}: {e}")

            # Save characters
            for char in (response_data.characters or []):
                save_asset(char.model_dump(), 'character')

            # Save scenes
            for scene in (response_data.scenes or []):
                save_asset(scene.model_dump(), 'scene')

            # Save props
            for prop in (response_data.props or []):
                save_asset(prop.model_dump(), 'prop')

            # Save the summary if available
            if response_data.summary:
                try:
                    project_manager.update_project(request.project_id, {"description": response_data.summary})
                    logger.info("Updated project description with generated summary.")
                except Exception as e:
                    logger.error(f"Failed to update project description: {e}")

        return ResponseModel(data=response_data)

    except Exception as e:
        logger.error(f"Script analysis failed: {str(e)}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/prompt/optimize", response_model=ResponseModel)
async def optimize_prompt_endpoint(
    request: PromptOptimizationRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Optimize a prompt for image or video generation."""
    try:
        optimized_prompt = await script_service.optimize_prompt(
            prompt=request.prompt,
            target_type=request.target_type,
            template=request.template,
            model_name=request.model,
            provider=request.provider,
            language=request.language or "Chinese",
            user_id=current_user.id,
            **(request.extra_params or {})
        )
        return ResponseModel(data=PromptOptimizationResponse(
            original_prompt=request.prompt,
            optimized_prompt=optimized_prompt,
            target_type=request.target_type
        ))
    except Exception as e:
        logger.error(f"Prompt optimization failed: {str(e)}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/style/recommend", response_model=ResponseModel)
async def recommend_style(
    request: StyleRecommendationRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Recommend an art style based on novel text.

    The backend automatically fetches the available styles, with full
    descriptions, from its config.
    """
    try:
        # Automatically fetch styles with full information from the backend config
        styles_full = await get_available_styles_from_config()
        logger.info(f"Auto-fetched {len(styles_full)} styles from backend config")

        # Format style information for the AI: include name and description
        available_styles = [
            f"{style['name']} - {style.get('desc', '')}"
            for style in styles_full
            if style.get('name')
        ]

        logger.info(f"Formatted {len(available_styles)} styles with descriptions for AI")

        result = await script_service.recommend_style(
            novel_text=request.novel_text,
            available_styles=available_styles,
            model_name=request.model,
            provider=request.provider,
            user_id=current_user.id,
            **(request.extra_params or {})
        )
        return ResponseModel(data=StyleRecommendationResponse(**result))
    except Exception as e:
        logger.error(f"Style recommendation failed: {str(e)}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/script/summary", response_model=ResponseModel)
async def summarize_novel_endpoint(
    request: NovelSummaryRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Summarize novel text into a concise overview (without characters)."""
    try:
        result = await script_service.summarize_novel(
            novel_text=request.novel_text,
            language=request.language,
            model_name=request.model,
            provider=request.provider,
            global_summary=request.global_summary,
            user_id=current_user.id,
            **(request.extra_params or {})
        )
        return ResponseModel(data=NovelSummaryResponse(**result))
    except Exception as e:
        logger.error(f"Novel summary failed: {str(e)}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/script/characters", response_model=ResponseModel)
async def extract_characters_endpoint(
    request: CharacterExtractionRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Extract characters from novel text."""
    try:
        result = await script_service.extract_characters(
            novel_text=request.novel_text,
            language=request.language,
            model_name=request.model,
            provider=request.provider,
            global_summary=request.global_summary,
            known_characters=request.known_characters,
            user_id=current_user.id,
            **(request.extra_params or {})
        )
        return ResponseModel(data=CharacterExtractionResponse(**result))
    except Exception as e:
        logger.error(f"Character extraction failed: {str(e)}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/script/scenes", response_model=ResponseModel)
async def extract_scenes_endpoint(
    request: SceneExtractionRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Extract scenes from novel text."""
    try:
        result = await script_service.extract_scenes(
            novel_text=request.novel_text,
            language=request.language,
            model_name=request.model,
            global_summary=request.global_summary,
            known_scenes=request.known_scenes,
            user_id=current_user.id,
            **(request.extra_params or {})
        )
        return ResponseModel(data=SceneExtractionResponse(**result))
    except Exception as e:
        logger.error(f"Scene extraction failed: {str(e)}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/script/props", response_model=ResponseModel)
async def extract_props_endpoint(
    request: PropExtractionRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Extract props from novel text."""
    try:
        result = await script_service.extract_props(
            novel_text=request.novel_text,
            language=request.language,
            model_name=request.model,
            global_summary=request.global_summary,
            known_props=request.known_props,
            user_id=current_user.id,
            **(request.extra_params or {})
        )
        return ResponseModel(data=PropExtractionResponse(**result))
    except Exception as e:
        logger.error(f"Prop extraction failed: {str(e)}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/script/storyboards", response_model=ResponseModel)
async def split_storyboards_endpoint(
    request: StoryboardSplitRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Split novel text into storyboards (shots)."""
    try:
        result = await script_service.split_storyboards(
            novel_text=request.novel_text,
            project_id=request.project_id,
            model_name=request.model,
            provider=request.provider,
            language=request.language,
            known_characters=request.known_characters,
            known_scenes=request.known_scenes,
            known_props=request.known_props,
            user_id=current_user.id,
            **(request.extra_params or {})
        )
        return ResponseModel(data=StoryboardSplitResponse(**result))
    except Exception as e:
        logger.error(f"Storyboard splitting failed: {str(e)}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/script/split", response_model=ResponseModel)
async def split_chapters_endpoint(
    request: ChapterSplitRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Split novel text into chapters using a regex or an agent."""
    try:
        # Check whether agent-based splitting is requested via extra_params
        extra = request.extra_params or {}
        use_agent = extra.get("use_agent", False)

        chapters = await script_service.split_chapters(
            novel_text=request.novel_text,
            regex_pattern=request.regex_pattern,
            use_agent=use_agent,
            model_name=request.model,
            language=extra.get("language", "Chinese"),
            user_id=current_user.id
        )
        return ResponseModel(data=ChapterSplitResponse(
            chapters=chapters,
            total_chapters=len(chapters)
        ))
    except Exception as e:
        logger.error(f"Chapter split failed: {str(e)}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
|
||||||
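The `/script/split` endpoint delegates regex-based chapter splitting to `script_service.split_chapters`. A minimal standalone sketch of that kind of splitter is shown below; the function name, the default heading pattern, and the preamble handling are illustrative assumptions, not the service's actual implementation:

```python
import re


def split_chapters_by_regex(novel_text, pattern=r"^第.+?章.*$"):
    """Split novel text into chapters wherever a line matches `pattern`.

    Text before the first matching heading, if any, becomes its own
    leading chapter. Returns an empty list for blank input.
    """
    heading_re = re.compile(pattern, re.MULTILINE)
    starts = [m.start() for m in heading_re.finditer(novel_text)]
    if not starts:
        return [novel_text] if novel_text.strip() else []
    chapters = []
    if novel_text[:starts[0]].strip():
        chapters.append(novel_text[:starts[0]].strip())
    for i, start in enumerate(starts):
        end = starts[i + 1] if i + 1 < len(starts) else len(novel_text)
        chapters.append(novel_text[start:end].strip())
    return chapters
```

The agent-based path (`use_agent=True`) would replace this heuristic with an LLM call, which is why the endpoint routes both options through the service layer.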
212 backend/src/api/generations/tasks.py Normal file
@@ -0,0 +1,212 @@
"""
Task Status Endpoints

Task-management endpoints:
- GET /tasks - list tasks
- GET /tasks/{task_id} - get task status
- POST /tasks/{task_id}/cancel - cancel a task
"""
import logging
import asyncio

from fastapi import APIRouter, Query, Request
from typing import Optional

from src.models.schemas import ResponseModel, PaginationParams
from src.utils.pagination import Paginator
from src.services.provider.registry import ModelRegistry, ModelType
from src.services.provider.base import TaskStatus
from src.services.task_manager import task_manager
from src.services.storage_service import storage_manager
from src.config.settings import OSS_BUCKET, OSS_REGION
from src.utils.errors import TaskNotFoundException, AppException, ErrorCode

router = APIRouter()
logger = logging.getLogger(__name__)


@router.get("/tasks", response_model=ResponseModel)
async def list_tasks(
    request: Request,
    type: Optional[str] = Query(None, description="Filter by task type"),
    page: int = Query(1, ge=1, description="Page number, starting at 1"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page, max 100"),
):
    """List tasks with optional filtering.

    Args:
        type: Task type filter
        page: Page number
        page_size: Items per page

    Returns:
        Paginated task list
    """
    # Compute the offset
    offset = (page - 1) * page_size

    # Fetch the task list
    tasks = task_manager.list_tasks(type=type, limit=page_size, offset=offset)

    # Total count (simplified; a real implementation should query the total)
    total = len(tasks)

    # Build the paginator
    paginator = Paginator(
        items=tasks,
        total=total,
        page=page,
        page_size=page_size
    )

    return paginator.to_response(request)


@router.get("/tasks/{task_id}", response_model=ResponseModel)
async def get_task_status(task_id: str):
    """
    Check status of a task.
    """
    # 1. Check local TaskManager
    task = await task_manager.get_task(task_id)

    if task:
        # If task is active (pending/processing), check provider status
        active_statuses = ["pending", "processing", "queued", "running", "submitted"]
        if task.status.lower() in active_statuses and task.provider_task_id:
            service = ModelRegistry.get(task.model)
            # Fall back to the default service if none is registered for the model name
            if not service:
                if task.type == "image":
                    service = ModelRegistry.get_default(ModelType.IMAGE)
                elif task.type == "video":
                    service = ModelRegistry.get_default(ModelType.VIDEO)
                elif task.type == "audio":
                    service = ModelRegistry.get_default(ModelType.AUDIO)
                elif task.type == "music":
                    service = ModelRegistry.get_default(ModelType.MUSIC)

            if service and hasattr(service, 'check_status'):
                try:
                    result = await service.check_status(task.provider_task_id, user_id=task.user_id)

                    # Map the TaskStatus enum value to the local status string
                    new_status = result.status.value.lower()
                    if new_status == "succeeded":
                        new_status = "success"

                    # If succeeded, process results (OSS upload)
                    final_results = {}
                    if result.status == TaskStatus.SUCCEEDED and result.results:
                        processed_urls = []
                        for idx, item in enumerate(result.results):
                            if item.url:
                                # Determine extension
                                ext = "png"
                                if task.type == "video":
                                    ext = "mp4"
                                elif task.type in ("audio", "music"):
                                    ext = "mp3"

                                key = f"generated/{task.type}s/{task.id}_{idx}.{ext}"

                                # Save to storage (local or OSS)
                                oss_url = await asyncio.to_thread(storage_manager.save_from_url, item.url, key)
                                processed_urls.append(oss_url or item.url)

                        final_results = {"urls": processed_urls}

                    # Update the task in the DB
                    update_data = {"status": new_status}
                    if final_results:
                        update_data["result"] = final_results
                    if result.error:
                        update_data["error"] = str(result.error)

                    task = task_manager.update_task(task.id, **update_data)

                except Exception as e:
                    logger.error(f"Failed to check provider status for task {task.id}: {e}")
                    # Don't fail the request, just return current state

        # Retry OSS upload if succeeded but urls are not OSS (resilience)
        elif task.status.lower() in ("succeeded", "success") and task.result and "urls" in task.result:
            try:
                urls = task.result["urls"]
                updated_urls = []
                needs_update = False

                for idx, url in enumerate(urls):
                    # Check if it looks like an OSS url
                    is_oss = False
                    if OSS_BUCKET and OSS_REGION and f"{OSS_BUCKET}.{OSS_REGION}" in url:
                        is_oss = True

                    if not is_oss:
                        # Attempt upload
                        ext = "png"
                        if task.type == "video":
                            ext = "mp4"
                        elif task.type in ("audio", "music"):
                            ext = "mp3"
                        key = f"generated/{task.type}s/{task.id}_{idx}.{ext}"

                        new_url = await asyncio.to_thread(storage_manager.save_from_url, url, key)
                        if new_url and new_url != url:
                            updated_urls.append(new_url)
                            needs_update = True
                        else:
                            updated_urls.append(url)
                    else:
                        # It is already on OSS; re-sign it if possible
                        signed_url = storage_manager.sign_url(url)
                        updated_urls.append(signed_url if signed_url else url)

                if needs_update:
                    task = task_manager.update_task(task.id, result={"urls": updated_urls})

            except Exception as e:
                logger.error(f"Failed to retry OSS upload for task {task.id}: {e}")

        return ResponseModel(data=task)

    raise AppException(
        message="Task not found",
        code=ErrorCode.NOT_FOUND,
        status_code=404
    )


@router.post("/tasks/{task_id}/cancel", response_model=ResponseModel)
async def cancel_task(task_id: str):
    """
    Cancel a pending/processing task.

    Only marks the local task as CANCELLED; this does not guarantee the
    downstream provider actually aborts the computation, but it is sufficient
    for most async generation scenarios (the frontend stops polling and no
    result will be written back).
    """
    try:
        cancelled = await task_manager.cancel_task(task_id)
    except TaskNotFoundException:
        raise AppException(
            message="Task not found",
            code=ErrorCode.NOT_FOUND,
            status_code=404
        )
    except Exception as e:
        logger.error(f"Failed to cancel task {task_id}: {e}")
        raise AppException(
            message="Failed to cancel task",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    if not cancelled:
        # A finished task cannot be cancelled
        raise AppException(
            message="Task already finished and cannot be cancelled",
            code=ErrorCode.INVALID_PARAMETER,
            status_code=400
        )

    # Return a minimal result; the frontend can call GET /tasks/{task_id} for the latest details
    return ResponseModel(data={"task_id": task_id, "cancelled": True})
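The status check above normalizes provider statuses ("succeeded" becomes the local "success") and derives a storage key from the task type. That mapping can be sketched as standalone helpers; the function names here are illustrative and are not part of the module:

```python
def normalize_status(provider_status: str) -> str:
    """Map a provider's status string to the local task status."""
    status = provider_status.lower()
    return "success" if status == "succeeded" else status


def result_key(task_type: str, task_id: str, index: int) -> str:
    """Build the storage key used when persisting a generated file,
    picking the file extension from the task type."""
    ext = "png"  # images are the default
    if task_type == "video":
        ext = "mp4"
    elif task_type in ("audio", "music"):
        ext = "mp3"
    return f"generated/{task_type}s/{task_id}_{index}.{ext}"
```

Keeping this mapping in one place would also let the endpoint and the OSS-retry branch share it instead of duplicating the extension logic.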
136 backend/src/api/generations/video.py Normal file
@@ -0,0 +1,136 @@
"""
Video Generation Endpoints

Video-generation endpoints:
- POST /generations/video - generate a video
"""
import logging

from fastapi import APIRouter, Depends

from src.models.schemas import ResponseModel, VideoGenerationRequest
from src.services.provider.registry import ModelType
from src.services.task_manager import task_manager
from src.auth.dependencies import get_current_user, UserAuth

from .helpers import resolve_service, ensure_url, check_user_api_key
from src.utils.errors import ErrorCode, AppException

router = APIRouter()
logger = logging.getLogger(__name__)


@router.post("/generations/video", response_model=ResponseModel)
async def generate_video(
    request: VideoGenerationRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Generic video generation endpoint."""
    # 0. Sanitize URLs early (convert Base64 -> URL)
    try:
        if request.image_inputs:
            request.image_inputs = [ensure_url(url) for url in request.image_inputs if url]
        if request.video_inputs:
            request.video_inputs = [ensure_url(url) for url in request.video_inputs if url]
        if request.audio_inputs:
            request.audio_inputs = [ensure_url(url) for url in request.audio_inputs if url]

        # Also check extra_params for common media keys
        if request.extra_params:
            for key in ["image_url", "video_url", "input_image", "input_video", "audio_url"]:
                if key in request.extra_params and isinstance(request.extra_params[key], str):
                    request.extra_params[key] = ensure_url(request.extra_params[key])
    except Exception as e:
        logger.error(f"Failed to sanitize inputs: {e}")

    # 1. Resolve the model via the unified routing logic
    video_service = resolve_service(request.model, ModelType.VIDEO)

    if not video_service:
        raise AppException(
            message="Video service not configured",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    # 1.5 Check whether the user has configured an API key
    # Get the provider_id from video_service
    provider = getattr(video_service, 'provider_id', None)
    if provider:
        check_user_api_key(current_user.id, provider)

    # 2. Resolve size from aspect_ratio (BEFORE creating task)
    aspect_ratio = request.aspect_ratio
    size = None

    if aspect_ratio:
        # Look up resolution in model config
        model_config = getattr(video_service, "config", {}) or {}
        if isinstance(model_config, dict):
            resolutions_config = model_config.get("resolutions") or {}
        else:
            resolutions_config = {}

        # Use provided resolution level (e.g. "720P", "1080P") or default to "720P"
        res_level = request.resolution or "720P"

        if resolutions_config and res_level in resolutions_config and isinstance(resolutions_config[res_level], dict):
            ratio_map = resolutions_config[res_level]
            if aspect_ratio in ratio_map:
                size = ratio_map[aspect_ratio]

        if not size:
            defaults = {
                "16:9": "1280*720",
                "9:16": "720*1280",
                "1:1": "1280*1280",
                "4:3": "1280*960",
                "3:4": "960*1280"
            }
            size = defaults.get(aspect_ratio)

    # 3. Create the task in the DB with resolved params
    try:
        model_name = request.model or getattr(video_service, "model_id", "default_video")

        # Build the params dict with the resolved size
        task_params = request.model_dump()
        if size:
            task_params["size"] = size  # Override with resolved size

        # Use the unified task manager
        # Get user_id from the authenticated user
        user_id = current_user.id

        # Handle project_id from multiple sources
        # Priority 1: request.project_id (from request body, alias="projectId")
        # Priority 2: request.source_id if source == 'project' (from extra_params, alias="sourceId")
        project_id = request.project_id
        if not project_id and request.source == 'project':
            project_id = request.source_id

        task = await task_manager.create_task(
            task_type="video",
            model=model_name,
            params=task_params,
            user_id=user_id,
            project_id=project_id
        )
    except Exception as e:
        logger.error(f"Failed to create task: {e}")
        raise AppException(
            message=f"Failed to create task: {e}",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )

    # Log context if provided
    if request.source and request.source_id:
        logger.info(f"Video generation started for {request.source}:{request.source_id}, task_id: {task.id}")

    # The unified task manager handles execution automatically
    return ResponseModel(data={
        "task_id": task.id,
        "status": task.status
    })
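Step 2 above resolves a concrete pixel size from the requested aspect ratio and resolution level, preferring the model's own `resolutions` config and falling back to built-in defaults. The same logic can be expressed as a pure function, which makes it easy to unit-test in isolation (the function name is illustrative; the lookup order and default table mirror the endpoint):

```python
DEFAULT_SIZES = {
    "16:9": "1280*720",
    "9:16": "720*1280",
    "1:1": "1280*1280",
    "4:3": "1280*960",
    "3:4": "960*1280",
}


def resolve_size(aspect_ratio, resolution=None, resolutions_config=None):
    """Resolve a concrete "W*H" size string for a video request.

    Tries the model's per-resolution ratio map first (e.g.
    {"1080P": {"9:16": "1080*1920"}}), then falls back to DEFAULT_SIZES.
    Returns None when no aspect ratio is given or no mapping exists.
    """
    if not aspect_ratio:
        return None
    res_level = resolution or "720P"
    ratio_map = (resolutions_config or {}).get(res_level)
    if isinstance(ratio_map, dict) and aspect_ratio in ratio_map:
        return ratio_map[aspect_ratio]
    return DEFAULT_SIZES.get(aspect_ratio)
```

Resolving the size before the task row is created means the stored `params` already contain the value the provider will receive, so retries and audits see the same request.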
350 backend/src/api/health.py Normal file
@@ -0,0 +1,350 @@
"""
Health-check and monitoring endpoints

Expose system health status, performance metrics, and diagnostic information
"""
import logging
import time
from datetime import datetime
from typing import Dict, Any
from fastapi import APIRouter, Response
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST

from src.models.schemas import ResponseModel
from src.services.task_manager import task_manager
from src.services.provider.registry import ModelRegistry
from src.services.provider.health import health_monitor
from src.config.database import engine, get_pool_status
from sqlmodel import Session, text
from src.utils.errors import ErrorCode, AppException

router = APIRouter(tags=["health"])
logger = logging.getLogger(__name__)

# Application start time
_start_time = time.time()


@router.get("/health", response_model=ResponseModel)
async def health_check():
    """Basic health check

    Returns:
        Basic health status information
    """
    return ResponseModel(data={
        "status": "healthy",
        "service": "Pixel Backend",
        "timestamp": datetime.now().isoformat(),
        "uptime_seconds": int(time.time() - _start_time)
    })


@router.get("/health/detailed", response_model=ResponseModel)
async def detailed_health_check():
    """Detailed health check

    Checks the health of all critical components

    Returns:
        Detailed health status, including per-component state
    """
    health_status = {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "uptime_seconds": int(time.time() - _start_time),
        "components": {}
    }

    # Check the database connection
    try:
        start = time.time()
        with Session(engine) as session:
            session.exec(text("SELECT 1"))
        latency_ms = (time.time() - start) * 1000

        # Get connection pool status
        pool_status = get_pool_status()

        health_status["components"]["database"] = {
            "status": "healthy",
            "message": "Database connection successful",
            "latency_ms": round(latency_ms, 2),
            "pool": pool_status
        }
    except Exception as e:
        health_status["status"] = "unhealthy"
        health_status["components"]["database"] = {
            "status": "unhealthy",
            "message": f"Database connection failed: {str(e)}"
        }

    # Check the Redis connection
    from src.config.settings import REDIS_ENABLED
    if REDIS_ENABLED:
        try:
            from src.services.cache_service import get_cache_service
            cache = get_cache_service()

            if cache._connected:
                start = time.time()
                await cache._redis.ping()
                latency_ms = (time.time() - start) * 1000

                # Get Redis info
                info = await cache._redis.info()
                health_status["components"]["redis"] = {
                    "status": "healthy",
                    "message": "Redis connection successful",
                    "latency_ms": round(latency_ms, 2),
                    "version": info.get("redis_version"),
                    "connected_clients": info.get("connected_clients"),
                    "used_memory_human": info.get("used_memory_human")
                }
            else:
                health_status["status"] = "degraded"
                health_status["components"]["redis"] = {
                    "status": "unhealthy",
                    "message": "Redis not connected"
                }
        except Exception as e:
            health_status["status"] = "degraded"
            health_status["components"]["redis"] = {
                "status": "unhealthy",
                "message": f"Redis connection failed: {str(e)}"
            }
    else:
        health_status["components"]["redis"] = {
            "status": "disabled",
            "message": "Redis is disabled in configuration"
        }

    # Check the task manager
    try:
        stats = task_manager.get_stats()
        health_status["components"]["task_manager"] = {
            "status": "healthy",
            "stats": stats
        }
    except Exception as e:
        health_status["status"] = "degraded"
        health_status["components"]["task_manager"] = {
            "status": "unhealthy",
            "message": f"Task manager error: {str(e)}"
        }

    # Check the model registry
    try:
        models = ModelRegistry.list_models()
        health_status["components"]["model_registry"] = {
            "status": "healthy",
            "total_models": len(models)
        }
    except Exception as e:
        health_status["status"] = "degraded"
        health_status["components"]["model_registry"] = {
            "status": "unhealthy",
            "message": f"Model registry error: {str(e)}"
        }

    # Check AI service health
    try:
        service_health = health_monitor.get_health_summary()
        health_status["components"]["ai_services"] = {
            "status": "healthy" if service_health["healthy"] == service_health["total"] else "degraded",
            "summary": service_health
        }
    except Exception as e:
        health_status["status"] = "degraded"
        health_status["components"]["ai_services"] = {
            "status": "unknown",
            "message": f"Health monitor error: {str(e)}"
        }

    return ResponseModel(data=health_status)


@router.get("/health/live", response_model=ResponseModel)
async def liveness_probe():
    """Kubernetes liveness probe

    Simply checks that the application is running

    Returns:
        200 OK if alive
    """
    return ResponseModel(data={"status": "alive"})


@router.get("/health/ready", response_model=ResponseModel)
async def readiness_probe():
    """Kubernetes readiness probe

    Checks whether the application is ready to receive traffic

    Returns:
        200 OK if ready, 503 if not ready
    """
    # Check critical components
    ready = True
    components = {}

    # Check the database
    try:
        with Session(engine) as session:
            session.exec(text("SELECT 1"))
        components["database"] = "ready"
    except Exception as e:
        ready = False
        components["database"] = f"not ready: {str(e)}"

    # Check Redis (if enabled)
    from src.config.settings import REDIS_ENABLED
    if REDIS_ENABLED:
        try:
            from src.services.cache_service import get_cache_service
            cache = get_cache_service()

            if cache._connected:
                await cache._redis.ping()
                components["redis"] = "ready"
            else:
                ready = False
                components["redis"] = "not ready: not connected"
        except Exception as e:
            ready = False
            components["redis"] = f"not ready: {str(e)}"
    else:
        components["redis"] = "disabled"

    # Check the task manager
    try:
        task_manager.get_stats()
        components["task_manager"] = "ready"
    except Exception as e:
        ready = False
        components["task_manager"] = f"not ready: {str(e)}"

    if ready:
        return ResponseModel(data={"status": "ready", "components": components})
    else:
        raise AppException(
            message="Service not ready",
            code=ErrorCode.SERVICE_UNAVAILABLE,
            status_code=503,
            details={"status": "not ready", "components": components}
        )


@router.get("/metrics")
async def prometheus_metrics():
    """Prometheus metrics endpoint

    Exports monitoring metrics in Prometheus format

    Returns:
        Metrics in Prometheus exposition format
    """
    # Update system and database metrics before returning
    from src.middlewares.metrics import update_system_metrics, update_database_metrics
    update_system_metrics()
    update_database_metrics()

    return Response(
        content=generate_latest(),
        media_type=CONTENT_TYPE_LATEST
    )


@router.get("/metrics/tasks", response_model=ResponseModel)
async def task_metrics():
    """Task manager metrics

    Returns:
        Detailed task manager statistics
    """
    try:
        stats = task_manager.get_stats()
        return ResponseModel(data=stats)
    except Exception as e:
        logger.error(f"Failed to get task metrics: {e}", exc_info=True)
        return ResponseModel(
            code=500,
            message="Failed to get task metrics",
            data={"error": str(e)}
        )


@router.get("/metrics/models", response_model=ResponseModel)
async def model_metrics():
    """Model service metrics

    Returns:
        Health status and statistics for all registered models
    """
    try:
        # Get all models
        models = ModelRegistry.list_models()

        # Get health status
        health_summary = health_monitor.get_health_summary()
        all_health = health_monitor.get_all_health()

        # Build the response
        model_stats = []
        for model_id, config in models.items():
            health = all_health.get(model_id)
            model_stat = {
                "id": model_id,
                "name": config.get("name"),
                "type": config.get("type"),
                "provider": config.get("provider"),
                "enabled": config.get("enabled", True)
            }

            if health:
                model_stat["health"] = {
                    "status": health.status.value,
                    "success_rate": health.get_success_rate(),
                    "avg_latency_ms": health.avg_latency_ms,
                    "total_checks": health.total_checks,
                    "total_failures": health.total_failures
                }

            model_stats.append(model_stat)

        return ResponseModel(data={
            "summary": health_summary,
            "models": model_stats
        })
    except Exception as e:
        logger.error(f"Failed to get model metrics: {e}", exc_info=True)
        return ResponseModel(
            code=500,
            message="Failed to get model metrics",
            data={"error": str(e)}
        )


@router.get("/debug/info", response_model=ResponseModel)
async def debug_info():
    """Debug information

    Provides system configuration and runtime information (development use only)

    Returns:
        System debug information
    """
    import sys
    import platform
    from src.config.settings import NODE_ENV, STORAGE_TYPE, REDIS_ENABLED

    return ResponseModel(data={
        "environment": NODE_ENV,
        "python_version": sys.version,
        "platform": platform.platform(),
        "storage_type": STORAGE_TYPE,
        "redis_enabled": REDIS_ENABLED,
        "uptime_seconds": int(time.time() - _start_time),
        "start_time": datetime.fromtimestamp(_start_time).isoformat()
    })
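The detailed health check treats the database as critical (a failure makes the whole service "unhealthy") while Redis, the task manager, the model registry, and AI services only degrade it. That roll-up can be sketched as a small pure function; the module sets these flags inline as it checks each component, so this helper is illustrative, not part of the file:

```python
# Components whose failure makes the service unhealthy rather than degraded
# (an assumption mirroring how the detailed check treats the database).
CRITICAL_COMPONENTS = {"database"}


def aggregate_status(components: dict) -> str:
    """Roll per-component statuses up into an overall service status.

    "healthy" and "disabled" components are ignored; a failed critical
    component returns "unhealthy" immediately, any other failure
    downgrades the result to "degraded".
    """
    overall = "healthy"
    for name, status in components.items():
        if status in ("healthy", "disabled"):
            continue
        if name in CRITICAL_COMPONENTS:
            return "unhealthy"
        overall = "degraded"
    return overall
```

Centralizing the roll-up like this would also avoid a later non-critical check overwriting an earlier "unhealthy" verdict, which the inline flag assignments can do.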
29 backend/src/api/projects/__init__.py Normal file
@@ -0,0 +1,29 @@
"""
Projects API Package

Project API modules:
- core: core project CRUD
- episodes: episode management
- assets: asset management
- storyboards: storyboard management
"""

from fastapi import APIRouter

from .core import router as core_router
from .episodes import router as episodes_router
from .assets import router as assets_router
from .storyboards import router as storyboards_router

# Create the main router
router = APIRouter(tags=["projects"])

# Include all sub-routers
router.include_router(core_router)
router.include_router(episodes_router)
router.include_router(assets_router)
router.include_router(storyboards_router)

__all__ = [
    "router",
]
305 backend/src/api/projects/assets.py Normal file
@@ -0,0 +1,305 @@
"""
Projects API - Assets Module

Asset-management API routes.
"""

import logging
import uuid
from typing import Optional, List

from fastapi import APIRouter, Depends, Query

from pydantic import TypeAdapter

from src.models.schemas import (
    ResponseModel,
    Asset, CreateAssetRequest, UpdateAssetRequest,
    CharacterAsset, SceneAsset, PropAsset, OtherAsset,
    GenerationRecord,
)
from src.auth.dependencies import get_current_user
from src.auth.models import UserAuth
from src.services.project_service import project_manager
from src.services.asset_deduplication_service import deduplicate_project_assets as deduplicate_assets_service
from src.utils.errors import ErrorCode, AppException

logger = logging.getLogger(__name__)

router = APIRouter(tags=["projects-assets"])


def _check_project_access(project_id: str, user_id: str):
    """Check whether the user is allowed to access the project"""
    project = project_manager.get_project(project_id, user_id=user_id)
    return project


# --- Asset Management ---

@router.post("/projects/{project_id}/assets", response_model=ResponseModel)
async def create_asset(
    project_id: str,
    request: dict,
    current_user: UserAuth = Depends(get_current_user)
):
    """
    Create a new asset.
    Accepts a dictionary, validates based on 'type' field.
    """
    try:
        # Check project access
        if not _check_project_access(project_id, current_user.id):
            raise AppException(
                message="Project not found or access denied",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        asset_type = request.get('type', 'other')
        if 'id' not in request:
            request['id'] = str(uuid.uuid4())

        asset = None
        if asset_type == 'character':
            asset = TypeAdapter(CharacterAsset).validate_python(request)
        elif asset_type == 'scene':
            asset = TypeAdapter(SceneAsset).validate_python(request)
        elif asset_type == 'prop':
            asset = TypeAdapter(PropAsset).validate_python(request)
        else:
            asset = TypeAdapter(OtherAsset).validate_python(request)

        updated_project = project_manager.add_asset(project_id, asset)
        if not updated_project:
            raise AppException(
                message="Project not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        return ResponseModel(data=updated_project)
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.put("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
async def update_asset(project_id: str, asset_id: str, request: UpdateAssetRequest):
    """Update an asset"""
    try:
        project = project_manager.get_project(project_id)
        if not project:
            raise AppException(
                message="Project not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        existing_asset = next((a for a in project.assets if a.id == asset_id), None)
        if not existing_asset:
            raise AppException(
                message="Asset not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        # Merge updates
        update_data = request.model_dump(exclude_unset=True)
        updated_asset = existing_asset.model_copy(update=update_data)

        result = project_manager.update_asset(project_id, asset_id, updated_asset)
        return ResponseModel(data=result)
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.delete("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
async def delete_asset(project_id: str, asset_id: str):
    """Delete an asset"""
    try:
        result = project_manager.delete_asset(project_id, asset_id)
        if not result:
            raise AppException(
                message="Project not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )
        return ResponseModel(data=result)
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/projects/{project_id}/assets/deduplicate", response_model=ResponseModel)
async def deduplicate_project_assets(
    project_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """
    Deduplicate all assets in the project using LLM; update storyboard references.
    """
    try:
        updated_project = await deduplicate_assets_service(project_id, current_user.id)
|
||||||
|
return ResponseModel(data=updated_project)
|
||||||
|
except ValueError as e:
|
||||||
|
if "not found" in str(e).lower():
|
||||||
|
raise AppException(
|
||||||
|
message=str(e),
|
||||||
|
code=ErrorCode.NOT_FOUND,
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
raise AppException(
|
||||||
|
message=str(e),
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Asset deduplication failed: %s", e, exc_info=True)
|
||||||
|
raise AppException(
|
||||||
|
message=str(e),
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Asset Management ---
|
||||||
|
@router.get("/projects/{project_id}/assets", response_model=ResponseModel)
|
||||||
|
async def list_project_assets(
|
||||||
|
project_id: str,
|
||||||
|
asset_type: Optional[str] = Query(None, description="Filter by asset type"),
|
||||||
|
search_query: Optional[str] = Query(None, description="Search by name or description"),
|
||||||
|
limit: int = 50,
|
||||||
|
offset: int = 0
|
||||||
|
):
|
||||||
|
"""List assets for a project with optional filtering and pagination"""
|
||||||
|
try:
|
||||||
|
result = project_manager.list_assets(project_id, asset_type, search_query, limit, offset)
|
||||||
|
if result is None:
|
||||||
|
raise AppException(
|
||||||
|
message="Project not found",
|
||||||
|
code=ErrorCode.NOT_FOUND,
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
return ResponseModel(data=result)
|
||||||
|
except Exception as e:
|
||||||
|
raise AppException(
|
||||||
|
message=str(e),
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post("/projects/{project_id}/assets", response_model=ResponseModel)
|
||||||
|
async def create_asset(project_id: str, request: CreateAssetRequest):
|
||||||
|
"""Create a new asset"""
|
||||||
|
try:
|
||||||
|
asset_data = request.model_dump(by_alias=True)
|
||||||
|
asset_data['id'] = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Validate and convert to Asset Union
|
||||||
|
new_asset = TypeAdapter(Asset).validate_python(asset_data)
|
||||||
|
|
||||||
|
updated_project = project_manager.add_asset(project_id, new_asset)
|
||||||
|
if not updated_project:
|
||||||
|
raise AppException(
|
||||||
|
message="Project not found",
|
||||||
|
code=ErrorCode.NOT_FOUND,
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
return ResponseModel(data=updated_project)
|
||||||
|
except Exception as e:
|
||||||
|
raise AppException(
|
||||||
|
message=str(e),
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post("/projects/{project_id}/assets/batch", response_model=ResponseModel)
|
||||||
|
async def batch_create_assets(project_id: str, request: List[CreateAssetRequest]):
|
||||||
|
"""Batch create assets"""
|
||||||
|
try:
|
||||||
|
new_assets = []
|
||||||
|
for asset_req in request:
|
||||||
|
asset_data = asset_req.model_dump(by_alias=True)
|
||||||
|
asset_data['id'] = str(uuid.uuid4())
|
||||||
|
# Validate and convert to Asset Union
|
||||||
|
new_asset = TypeAdapter(Asset).validate_python(asset_data)
|
||||||
|
new_assets.append(new_asset)
|
||||||
|
|
||||||
|
updated_project = project_manager.batch_add_assets(project_id, new_assets)
|
||||||
|
if not updated_project:
|
||||||
|
raise AppException(
|
||||||
|
message="Project not found",
|
||||||
|
code=ErrorCode.NOT_FOUND,
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
return ResponseModel(data=updated_project)
|
||||||
|
except Exception as e:
|
||||||
|
raise AppException(
|
||||||
|
message=str(e),
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.put("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
|
||||||
|
async def update_asset(project_id: str, asset_id: str, request: UpdateAssetRequest):
|
||||||
|
"""Update an asset"""
|
||||||
|
try:
|
||||||
|
project = project_manager.get_project(project_id)
|
||||||
|
if not project:
|
||||||
|
raise AppException(
|
||||||
|
message="Project not found",
|
||||||
|
code=ErrorCode.NOT_FOUND,
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
|
||||||
|
existing_asset = next((a for a in project.assets if a.id == asset_id), None)
|
||||||
|
if not existing_asset:
|
||||||
|
raise AppException(
|
||||||
|
message="Asset not found",
|
||||||
|
code=ErrorCode.NOT_FOUND,
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
|
||||||
|
update_data = request.model_dump(exclude_unset=True)
|
||||||
|
|
||||||
|
# Manually validate generations if present to ensure they are model instances not dicts
|
||||||
|
if 'generations' in update_data and update_data['generations']:
|
||||||
|
update_data['generations'] = TypeAdapter(List[GenerationRecord]).validate_python(update_data['generations'])
|
||||||
|
|
||||||
|
updated_asset = existing_asset.model_copy(update=update_data)
|
||||||
|
|
||||||
|
updated_project = project_manager.update_asset(project_id, asset_id, updated_asset)
|
||||||
|
return ResponseModel(data=updated_project)
|
||||||
|
except Exception as e:
|
||||||
|
raise AppException(
|
||||||
|
message=str(e),
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.delete("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
|
||||||
|
async def delete_asset(project_id: str, asset_id: str):
|
||||||
|
"""Delete an asset"""
|
||||||
|
try:
|
||||||
|
updated_project = project_manager.delete_asset(project_id, asset_id)
|
||||||
|
if not updated_project:
|
||||||
|
raise AppException(
|
||||||
|
message="Project not found",
|
||||||
|
code=ErrorCode.NOT_FOUND,
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
return ResponseModel(data=updated_project)
|
||||||
|
except Exception as e:
|
||||||
|
raise AppException(
|
||||||
|
message=str(e),
|
||||||
|
code=ErrorCode.INTERNAL_ERROR,
|
||||||
|
status_code=500
|
||||||
|
)
|
||||||
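The dict-based `create_asset` endpoint above dispatches validation on the payload's `type` field, falling back to a generic asset type and assigning an id when one is missing. A minimal stand-alone sketch of that dispatch pattern, using plain dicts instead of the project's Pydantic models (`validate_asset` and its field names are illustrative, not the project's actual API):

```python
import uuid

VALID_TYPES = {"character", "scene", "prop", "other"}

def validate_asset(payload: dict) -> dict:
    """Hypothetical stand-in for the TypeAdapter branching in create_asset."""
    asset = dict(payload)
    # Assign an id if the client did not provide one
    asset.setdefault("id", str(uuid.uuid4()))
    # Unknown or missing types fall back to the generic 'other' bucket
    if asset.get("type") not in VALID_TYPES:
        asset["type"] = "other"
    if not asset.get("name"):
        raise ValueError("asset requires a non-empty 'name'")
    return asset
```

In the real handler each branch validates against a distinct model (`CharacterAsset`, `SceneAsset`, ...), so a `character` payload with scene-only fields would be rejected rather than silently accepted.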
367
backend/src/api/projects/core.py
Normal file
@@ -0,0 +1,367 @@
"""
Projects API - Core Module

Contains the core project CRUD API routes.
"""

import logging
import uuid
from typing import Optional

from fastapi import APIRouter, BackgroundTasks, Query, Request, Depends
from pydantic import TypeAdapter

from src.models.schemas import (
    ResponseModel,
    CreateProjectRequest, UpdateProjectRequest,
    InitializeProjectRequest,
)
from src.auth.dependencies import get_current_user
from src.auth.models import UserAuth
from src.services.project_service import project_manager
from src.services.project_initialization_service import run_initialization_pipeline
from src.services.script_project_initialization_service import run_script_project_initialization
from src.services.script_to_canvas_service import convert_script_project_to_canvas
from src.utils.errors import ErrorCode, AppException

logger = logging.getLogger(__name__)

router = APIRouter(tags=["projects"])


def _check_project_access(project_id: str, user_id: str):
    """Check whether the user has access to the project."""
    project = project_manager.get_project(project_id, user_id=user_id)
    return project


def _progress_callback(project_id: str):
    def _cb(step: str, percentage: int, message: str, details: Optional[dict] = None):
        try:
            project_manager.update_project(
                project_id,
                {
                    "progress": {
                        "current_step": step,
                        "percentage": percentage,
                        "message": message,
                        "details": details or {},
                    }
                },
            )
        except Exception as e:
            logger.error("Failed to update progress: %s", e)
    return _cb


# --- Project Management ---
async def _create_initializing_project(
    request: InitializeProjectRequest,
    current_user: UserAuth,
    background_tasks: BackgroundTasks,
    project_type: str,
):
    project = project_manager.create_project(
        name=request.name,
        description="Initializing from novel...",
        type=project_type,
        chapters=[],
        assets=[],
        status="initializing",
        user_id=current_user.id
    )
    progress_cb = _progress_callback(project.id)
    if project_type == "script":
        background_tasks.add_task(
            run_script_project_initialization,
            project.id,
            request.novel_text,
            progress_cb,
            current_user.id,
            request.model,
            request.provider,
        )
    else:
        background_tasks.add_task(
            run_initialization_pipeline,
            project.id,
            request.novel_text,
            request.style,
            progress_cb,
            current_user.id,
        )
    return project


@router.post("/projects/initialize-from-novel", response_model=ResponseModel)
@router.post("/projects/init-from-novel", response_model=ResponseModel)
@router.post("/projects/initialize-canvas-from-novel", response_model=ResponseModel)
async def initialize_canvas_project_from_novel(
    request: InitializeProjectRequest,
    background_tasks: BackgroundTasks,
    current_user: UserAuth = Depends(get_current_user)
):
    """
    Initialize a project from novel text using the multi-agent pipeline (background task).

    Steps: 1. Create the project immediately (draft). 2. Trigger a background task for full initialization.
    """
    try:
        project = await _create_initializing_project(request, current_user, background_tasks, "canvas")
        return ResponseModel(data=project)
    except Exception as e:
        logger.error("Error initializing project: %s", e, exc_info=True)
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/projects/initialize-script-from-novel", response_model=ResponseModel)
async def initialize_script_project_from_novel(
    request: InitializeProjectRequest,
    background_tasks: BackgroundTasks,
    current_user: UserAuth = Depends(get_current_user)
):
    """Initialize a script project from novel text using the multi-agent adaptation pipeline."""
    try:
        project = await _create_initializing_project(request, current_user, background_tasks, "script")
        return ResponseModel(data=project)
    except Exception as e:
        logger.error("Error initializing script project: %s", e, exc_info=True)
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/projects", response_model=ResponseModel)
async def create_project(
    request: CreateProjectRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Create a new project"""
    try:
        project = project_manager.create_project(
            name=request.name,
            description=request.description,
            type=request.type,
            chapters=request.chapters,
            assets=request.assets,
            user_id=current_user.id
        )
        return ResponseModel(data=project)
    except Exception as e:
        logger.error("Error creating project: %s", e, exc_info=True)
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.get("/projects", response_model=ResponseModel)
async def list_projects(
    request: Request,
    page: int = Query(1, ge=1, description="Page number, starting from 1"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page, max 100"),
    sort: Optional[str] = Query(None, description="Sort field, format: field:asc or field:desc"),
    filter: Optional[str] = Query(None, description="Filter conditions, JSON format"),
    current_user: UserAuth = Depends(get_current_user)
):
    """
    List all projects with pagination support.

    Args:
        page: Page number, starting from 1
        page_size: Items per page (1-100, default: 20)
        sort: Sort field, format: "field:asc" or "field:desc"
        filter: Filter conditions, JSON format

    Returns:
        A paginated list of projects
    """
    try:
        from src.models.schemas import PaginationParams
        from src.utils.pagination import Paginator

        # Build pagination parameters
        pagination_params = PaginationParams(
            page=page,
            page_size=page_size,
            sort=sort,
            filter=filter
        )

        # Compute offset and limit
        offset = pagination_params.get_offset()
        limit = pagination_params.get_limit()

        # Only return projects belonging to the current user
        user_id = current_user.id if current_user else None

        # Fetch the project list (filtered by user)
        projects = project_manager.list_projects(limit=limit, offset=offset, user_id=user_id)

        # Fetch the total count (filtered by user)
        total = project_manager.count_projects(user_id=user_id)

        # Build the paginator
        paginator = Paginator(
            items=projects,
            total=total,
            page=page,
            page_size=page_size
        )

        # Return the paginated response
        return paginator.to_response(request)

    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.get("/projects/{project_id}", response_model=ResponseModel)
async def get_project(
    project_id: str,
    include_assets: bool = Query(default=True, description="Whether to include assets in the response"),
    include_referenced_assets: bool = Query(default=False, description="Whether to include only referenced assets (if include_assets is False)"),
    current_user: UserAuth = Depends(get_current_user)
):
    """Get project details"""
    project = project_manager.get_project(
        project_id,
        include_assets=include_assets,
        include_referenced_assets=include_referenced_assets,
        user_id=current_user.id
    )
    if not project:
        raise AppException(
            message="Project not found or access denied",
            code=ErrorCode.NOT_FOUND,
            status_code=404
        )
    return ResponseModel(data=project)


@router.put("/projects/{project_id}", response_model=ResponseModel)
async def update_project(
    project_id: str,
    request: UpdateProjectRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Update a project"""
    try:
        # First check that the project exists and belongs to the user
        existing = project_manager.get_project(project_id, user_id=current_user.id)
        if not existing:
            raise AppException(
                message="Project not found or access denied",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        update_data = request.model_dump(exclude_unset=True)
        updated_project = project_manager.update_project(project_id, update_data)
        if not updated_project:
            raise AppException(
                message="Project not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )
        return ResponseModel(data=updated_project)
    except AppException:
        raise
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.delete("/projects/{project_id}", response_model=ResponseModel)
async def delete_project(
    project_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """Delete a project"""
    try:
        # First check that the project exists and belongs to the user
        existing = project_manager.get_project(project_id, user_id=current_user.id)
        if not existing:
            raise AppException(
                message="Project not found or access denied",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        project_manager.delete_project(project_id)
        return ResponseModel(data={"id": project_id, "deleted": True})
    except AppException:
        raise
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/projects/{project_id}/convert-to-canvas", response_model=ResponseModel)
async def convert_project_to_canvas(
    project_id: str,
    background_tasks: BackgroundTasks,
    current_user: UserAuth = Depends(get_current_user)
):
    """Convert a script project into a standalone canvas project."""
    try:
        source_project = project_manager.get_project(project_id, user_id=current_user.id)
        if not source_project:
            raise AppException(
                message="Project not found or access denied",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )
        if source_project.type != "script":
            raise AppException(
                message="Only script projects can be converted to canvas projects",
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )

        chapters = [
            {
                "id": episode.id,
                "title": episode.title,
                "order": episode.order,
                "content": episode.content,
                "summary": episode.desc,
                "status": episode.status,
            }
            for episode in (source_project.episodes or [])
        ]

        target_project = project_manager.create_project(
            name=f"{source_project.name} - Canvas",
            description="Generating canvas from the script project...",
            type="canvas",
            chapters=chapters,
            assets=[],
            status="initializing",
            user_id=current_user.id
        )
        progress_cb = _progress_callback(target_project.id)
        background_tasks.add_task(
            convert_script_project_to_canvas,
            source_project.id,
            target_project.id,
            progress_cb,
            current_user.id,
        )
        return ResponseModel(data=target_project)
    except AppException:
        raise
    except Exception as e:
        logger.error("Error converting project to canvas: %s", e, exc_info=True)
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
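`list_projects` converts the client-facing `page`/`page_size` parameters into a database `offset`/`limit` pair via `PaginationParams.get_offset()` and `get_limit()`. A minimal sketch of that mapping, with the same bounds the `Query` constraints enforce (the function names here are illustrative, not the project's actual `PaginationParams` API):

```python
def get_offset(page: int, page_size: int) -> int:
    """Map 1-based page/page_size to a 0-based row offset."""
    # Mirrors the Query constraints: page >= 1, 1 <= page_size <= 100
    if page < 1 or not (1 <= page_size <= 100):
        raise ValueError("invalid pagination parameters")
    return (page - 1) * page_size

def get_limit(page_size: int) -> int:
    """The row limit is simply the page size."""
    return page_size
```

So page 3 with the default page size of 20 fetches rows 40-59; the separate `count_projects` call lets the `Paginator` report the total page count alongside the current slice.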
179
backend/src/api/projects/episodes.py
Normal file
@@ -0,0 +1,179 @@
"""
Projects API - Episodes Module

Contains the episode management API routes.
"""

import logging
import uuid

from fastapi import APIRouter, Depends

from src.models.schemas import (
    ResponseModel,
    Episode, CreateEpisodeRequest, UpdateEpisodeRequest,
)
from src.auth.dependencies import get_current_user
from src.auth.models import UserAuth
from src.services.project_service import project_manager
from src.utils.errors import ErrorCode, AppException

logger = logging.getLogger(__name__)

router = APIRouter(tags=["projects-episodes"])


def _check_project_access(project_id: str, user_id: str):
    """Check whether the user has access to the project."""
    project = project_manager.get_project(project_id, user_id=user_id)
    return project


# --- Episode Management ---
@router.post("/projects/{project_id}/episodes", response_model=ResponseModel)
async def create_episode(
    project_id: str,
    request: CreateEpisodeRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Create a new episode"""
    try:
        # Check project access
        if not _check_project_access(project_id, current_user.id):
            raise AppException(
                message="Project not found or access denied",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        episode_id = str(uuid.uuid4())
        new_episode = Episode(
            id=episode_id,
            title=request.title,
            order=request.order,
            desc=request.desc,
            status=request.status
        )
        updated_project = project_manager.add_episode(project_id, new_episode)
        if not updated_project:
            raise AppException(
                message="Project not found or operation failed",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )
        return ResponseModel(data=updated_project)
    except AppException:
        raise
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.put("/projects/{project_id}/episodes/{episode_id}", response_model=ResponseModel)
async def update_episode(
    project_id: str,
    episode_id: str,
    request: UpdateEpisodeRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """Update an episode"""
    try:
        project = project_manager.get_project(project_id, user_id=current_user.id)
        if not project:
            raise AppException(
                message="Project not found or access denied",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        existing_ep = next((ep for ep in project.episodes if ep.id == episode_id), None)
        if not existing_ep:
            raise AppException(
                message="Episode not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        update_data = request.model_dump(exclude_unset=True, by_alias=True)
        updated_ep = existing_ep.model_copy(update=update_data)

        updated_project = project_manager.update_episode(project_id, episode_id, updated_ep)
        return ResponseModel(data=updated_project)
    except AppException:
        raise
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.delete("/projects/{project_id}/episodes/{episode_id}", response_model=ResponseModel)
async def delete_episode(
    project_id: str,
    episode_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """Delete an episode"""
    try:
        project = project_manager.get_project(project_id, user_id=current_user.id)
        if not project:
            raise AppException(
                message="Project not found or access denied",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        updated_project = project_manager.delete_episode(project_id, episode_id)
        if not updated_project:
            raise AppException(
                message="Project not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )
        return ResponseModel(data=updated_project)
    except AppException:
        raise
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/projects/{project_id}/episodes/{episode_id}/analyze", response_model=ResponseModel)
async def analyze_episode(
    project_id: str,
    episode_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """
    Analyze a specific episode: extract the summary, characters/scenes/props, and storyboards.
    """
    try:
        from src.services.episode_analysis_service import analyze_episode as analyze_episode_service
        final_project = await analyze_episode_service(project_id, episode_id, current_user.id)
        return ResponseModel(data=final_project)
    except ValueError as e:
        msg = str(e)
        if "not found" in msg.lower():
            raise AppException(
                message=msg,
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )
        if "empty" in msg.lower():
            raise AppException(
                message=msg,
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )
        raise AppException(
            message=msg,
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
    except AppException:
        raise
    except Exception as e:
        logger.error("Episode analysis failed: %s", e, exc_info=True)
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
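The `update_episode` handler above uses the `model_dump(exclude_unset=True)` + `model_copy(update=...)` idiom: only fields the client explicitly sent overwrite the stored record, and the original object is never mutated. A plain-dict sketch of that merge semantics (no Pydantic dependency; `merge_update` is an illustrative name, not a project helper):

```python
def merge_update(existing: dict, update: dict) -> dict:
    """Return a copy of `existing` with only the explicitly-set fields replaced.

    `update` plays the role of model_dump(exclude_unset=True): it contains
    only the fields the client actually provided, so omitted fields keep
    their stored values. The original record is left untouched, mirroring
    model_copy(update=...).
    """
    merged = dict(existing)
    merged.update(update)
    return merged
```

This is why a `PUT` with `{"status": "done"}` does not blank out the episode's title: an omitted field simply never appears in the update payload.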
113
backend/src/api/projects/storyboards.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""
|
||||||
|
Projects API - Storyboards Module
|
||||||
|
|
||||||
|
包含分镜管理相关的 API 路由。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
|
from src.models.schemas import (
|
||||||
|
ResponseModel,
|
||||||
|
Storyboard, CreateStoryboardRequest, UpdateStoryboardRequest,
|
||||||
|
GenerationRecord,
|
||||||
|
)
|
||||||
|
from src.utils.errors import ErrorCode, AppException
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(tags=["projects-storyboards"])
|
||||||
|
|
||||||
|
|
||||||
|
# --- Storyboard Management ---
|
||||||
|
@router.post("/projects/{project_id}/storyboards", response_model=ResponseModel)
|
||||||
|
async def create_storyboard(project_id: str, request: CreateStoryboardRequest):
|
||||||
|
"""Create a new storyboard"""
|
||||||
|
try:
|
||||||
|
from src.services.project_service import project_manager
|
||||||
|
|
||||||
|
sb_id = str(uuid.uuid4())
|
||||||
|
sb_data = request.model_dump(by_alias=True)
|
||||||
|
sb_data['id'] = sb_id
|
||||||
|
|
||||||
|
new_sb = Storyboard(**sb_data)
|
||||||
|
|
||||||
|
updated_project = project_manager.add_storyboard(project_id, new_sb)
|
||||||
|
if not updated_project:
|
||||||
|
raise AppException(
|
||||||
|
message="Project not found",
|
||||||
|
code=ErrorCode.NOT_FOUND,
|
||||||
|
status_code=404
|
||||||
|
)
|
||||||
|
return ResponseModel(data=updated_project)
|
||||||
|
except Exception as e:
|
||||||
|
raise AppException(
|
||||||
|
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.put("/projects/{project_id}/storyboards/{storyboard_id}", response_model=ResponseModel)
async def update_storyboard(project_id: str, storyboard_id: str, request: UpdateStoryboardRequest):
    """Update a storyboard"""
    try:
        from src.services.project_service import project_manager

        project = project_manager.get_project(project_id)
        if not project:
            raise AppException(
                message="Project not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        existing_sb = next((sb for sb in project.storyboards if sb.id == storyboard_id), None)
        if not existing_sb:
            raise AppException(
                message="Storyboard not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        update_data = request.model_dump(exclude_unset=True, by_alias=True)

        # Manually validate generations if present to ensure they are model instances, not dicts
        if 'generations' in update_data and update_data['generations']:
            update_data['generations'] = TypeAdapter(List[GenerationRecord]).validate_python(update_data['generations'])

        updated_sb = existing_sb.model_copy(update=update_data)

        updated_project = project_manager.update_storyboard(project_id, storyboard_id, updated_sb)
        return ResponseModel(data=updated_project)
    except AppException:
        # Let the 404s above propagate instead of being remapped to 500
        raise
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.delete("/projects/{project_id}/storyboards/{storyboard_id}", response_model=ResponseModel)
async def delete_storyboard(project_id: str, storyboard_id: str):
    """Delete a storyboard"""
    try:
        from src.services.project_service import project_manager

        updated_project = project_manager.delete_storyboard(project_id, storyboard_id)
        if not updated_project:
            raise AppException(
                message="Project not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )
        return ResponseModel(data=updated_project)
    except AppException:
        raise
    except Exception as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
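The partial-update pattern used in `update_storyboard` above — `model_dump(exclude_unset=True)` followed by `model_copy(update=...)` — can be sketched in isolation. The `Storyboard` and request models below are stand-ins for illustration, not the project's actual schemas:

```python
from typing import Optional
from pydantic import BaseModel


class Storyboard(BaseModel):
    # Stand-in model; the real schema lives in src.models
    id: str
    title: str
    prompt: Optional[str] = None


class UpdateStoryboardRequest(BaseModel):
    title: Optional[str] = None
    prompt: Optional[str] = None


existing = Storyboard(id="sb-1", title="Opening shot", prompt="wide angle")

# exclude_unset=True keeps only fields the client actually sent,
# so fields that were omitted do not overwrite stored values with None.
request = UpdateStoryboardRequest(title="Opening scene")
update_data = request.model_dump(exclude_unset=True)

updated = existing.model_copy(update=update_data)
print(updated.title)   # "Opening scene"
print(updated.prompt)  # "wide angle" (preserved)
```

This is why the endpoint distinguishes "field absent" from "field explicitly set to null" when merging a PUT body into the stored record.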
240	backend/src/api/prompt_templates.py	Normal file
@@ -0,0 +1,240 @@
"""
Prompt Template API - prompt template endpoints
"""
import logging
from typing import Optional
from fastapi import APIRouter, Depends, Query

from src.models.schemas import ResponseModel
from src.models.prompt_template import (
    PromptTemplateCreate,
    PromptTemplateUpdate,
    PromptTemplateCategory
)
from src.services.prompt_template_service import PromptTemplateService
from src.auth.dependencies import get_current_user
from src.auth.models import UserAuth
from src.utils.errors import AppException, ErrorCode

router = APIRouter(tags=["prompt-templates"])
logger = logging.getLogger(__name__)


@router.post("/prompt-templates/init", response_model=ResponseModel)
async def init_templates(current_user: UserAuth = Depends(get_current_user)):
    """Initialize the default system templates (admin only)"""
    # An admin permission check could be added here
    PromptTemplateService.init_default_templates()
    return ResponseModel(message="Templates initialized successfully")


@router.get("/prompt-templates", response_model=ResponseModel)
async def list_templates(
    category: Optional[str] = Query(None, description="Filter by category"),
    target_type: Optional[str] = Query(None, description="Filter by target type: image/video/audio/music"),
    search: Optional[str] = Query(None, description="Search keyword"),
    favorites_only: bool = Query(False, description="Only show favorites"),
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page"),
    current_user: UserAuth = Depends(get_current_user)
):
    """List prompt templates"""
    try:
        result = PromptTemplateService.list_templates(
            user_id=current_user.id,
            category=category,
            target_type=target_type,
            search=search,
            favorites_only=favorites_only,
            page=page,
            page_size=page_size
        )
        return ResponseModel(data=result)
    except Exception as e:
        logger.error(f"Error listing templates: {e}")
        raise AppException(
            message="Failed to list templates",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.get("/prompt-templates/categories", response_model=ResponseModel)
async def get_categories(
    current_user: UserAuth = Depends(get_current_user)
):
    """List template categories"""
    categories = PromptTemplateService.get_categories()
    return ResponseModel(data=categories)


@router.post("/prompt-templates", response_model=ResponseModel)
async def create_template(
    data: PromptTemplateCreate,
    current_user: UserAuth = Depends(get_current_user)
):
    """Create a new template"""
    try:
        template = PromptTemplateService.create_template(
            user_id=current_user.id,
            data=data
        )
        return ResponseModel(
            message="Template created successfully",
            data=template
        )
    except Exception as e:
        logger.error(f"Error creating template: {e}")
        raise AppException(
            message="Failed to create template",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.get("/prompt-templates/{template_id}", response_model=ResponseModel)
async def get_template(
    template_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """Get a single template"""
    template = PromptTemplateService.get_template(
        template_id=template_id,
        user_id=current_user.id
    )

    if not template:
        raise AppException(
            message="Template not found",
            code=ErrorCode.NOT_FOUND,
            status_code=404
        )

    return ResponseModel(data=template)


@router.put("/prompt-templates/{template_id}", response_model=ResponseModel)
async def update_template(
    template_id: str,
    data: PromptTemplateUpdate,
    current_user: UserAuth = Depends(get_current_user)
):
    """Update a template"""
    try:
        template = PromptTemplateService.update_template(
            template_id=template_id,
            user_id=current_user.id,
            data=data
        )

        if not template:
            raise AppException(
                message="Template not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        return ResponseModel(
            message="Template updated successfully",
            data=template
        )
    except AppException:
        # Let the 404 above propagate instead of being remapped to 500
        raise
    except PermissionError as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.PERMISSION_DENIED,
            status_code=403
        )
    except Exception as e:
        logger.error(f"Error updating template: {e}")
        raise AppException(
            message="Failed to update template",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.delete("/prompt-templates/{template_id}", response_model=ResponseModel)
async def delete_template(
    template_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """Delete a template"""
    try:
        success = PromptTemplateService.delete_template(
            template_id=template_id,
            user_id=current_user.id
        )

        if not success:
            raise AppException(
                message="Template not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        return ResponseModel(message="Template deleted successfully")
    except AppException:
        raise
    except PermissionError as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.PERMISSION_DENIED,
            status_code=403
        )
    except Exception as e:
        logger.error(f"Error deleting template: {e}")
        raise AppException(
            message="Failed to delete template",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/prompt-templates/{template_id}/favorite", response_model=ResponseModel)
async def toggle_favorite(
    template_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """Toggle a template's favorite status"""
    try:
        is_favorite = PromptTemplateService.toggle_favorite(
            template_id=template_id,
            user_id=current_user.id
        )

        return ResponseModel(
            message="Favorite toggled successfully",
            data={"is_favorite": is_favorite}
        )
    except Exception as e:
        logger.error(f"Error toggling favorite: {e}")
        raise AppException(
            message="Failed to toggle favorite",
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/prompt-templates/{template_id}/apply", response_model=ResponseModel)
async def apply_template(
    template_id: str,
    user_prompt: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """Apply a template to a prompt"""
    result = PromptTemplateService.apply_template(
        template_id=template_id,
        user_prompt=user_prompt
    )

    if result is None:
        raise AppException(
            message="Template not found",
            code=ErrorCode.NOT_FOUND,
            status_code=404
        )

    return ResponseModel(
        data={
            "original": user_prompt,
            "enhanced": result
        }
    )
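`PromptTemplateService.apply_template` is defined elsewhere in the repo; the endpoint only shows that it returns an "enhanced" prompt or `None` for an unknown template. A minimal dependency-free sketch of what such enhancement could look like — assuming a simple `{prompt}` placeholder convention, which is a guess, not the service's documented behavior:

```python
# Hypothetical in-memory template store; the real service reads templates
# from the database via SQLModel.
TEMPLATES = {
    "tpl-cinematic": "cinematic lighting, 35mm film still, {prompt}, high detail",
}


def apply_template(template_id: str, user_prompt: str):
    """Return the enhanced prompt, or None when the template is unknown."""
    template = TEMPLATES.get(template_id)
    if template is None:
        return None  # the endpoint maps this to a 404 AppException
    return template.replace("{prompt}", user_prompt)


print(apply_template("tpl-cinematic", "a cat on a roof"))
print(apply_template("missing", "a cat"))  # None
```

The endpoint then echoes both the `original` and `enhanced` strings back so the client can preview the substitution before generating.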
195	backend/src/api/skills.py	Normal file
@@ -0,0 +1,195 @@
"""Agent Skills Management API

Provides endpoints for managing agent skills.
"""
import logging
import os
from pathlib import Path
from typing import List, Optional

import yaml
from fastapi import APIRouter
from pydantic import BaseModel

from src.services.agent_engine import AgentScopeService
from src.services.agent_engine.toolkit import ToolkitFactory
from src.models.schemas import ResponseModel
from src.utils.errors import ErrorCode, AppException

router = APIRouter(tags=["skills"])
logger = logging.getLogger(__name__)


class SkillInfo(BaseModel):
    """Skill information model"""
    name: str
    description: str
    path: str


class RegisterSkillRequest(BaseModel):
    """Request model for registering a skill"""
    skill_dir: str


class SkillPromptResponse(BaseModel):
    """Response model for skill prompt"""
    prompt: Optional[str]
    skills_count: int


@router.get("/skills", response_model=ResponseModel)
async def list_skills():
    """List all available agent skills.

    Returns:
        List of skill information
    """
    skills = []

    try:
        # Use the new ToolkitFactory
        all_skills = ToolkitFactory.list_skills()

        for skill_path in all_skills:
            # skill_path format: "domain/skill_name"
            domain, skill_name = skill_path.split('/')

            # Read SKILL.md for the description
            # Fixed path: services/agent_engine/skills (was services/agents/skills)
            skill_dir = Path(__file__).parent.parent / "services" / "agent_engine" / "skills" / domain / skill_name
            skill_md = skill_dir / "SKILL.md"

            description = "No description available"
            if skill_md.exists():
                try:
                    with open(skill_md, 'r', encoding='utf-8') as f:
                        content = f.read()

                    # Extract description from YAML front matter
                    if content.startswith('---'):
                        parts = content.split('---', 2)
                        if len(parts) >= 3:
                            try:
                                metadata = yaml.safe_load(parts[1])
                                description = metadata.get('description', description)
                            except yaml.YAMLError:
                                pass
                except Exception as e:
                    logger.error(f"Failed to read skill {skill_path}: {e}")

            # Append outside the read/parse block so skills without a
            # readable SKILL.md are still listed with the default description
            skills.append(SkillInfo(
                name=skill_path,  # use the full path as the name
                description=description,
                path=str(skill_dir)
            ))

    except Exception as e:
        logger.error(f"Failed to list skills: {e}")

    return ResponseModel(data=skills)


@router.post("/skills/register", response_model=ResponseModel)
async def register_skill(request: RegisterSkillRequest):
    """Register a new agent skill.

    Args:
        request: Skill registration request

    Returns:
        Success message

    Raises:
        AppException: If skill registration fails
    """
    try:
        # Validate skill directory
        if not os.path.exists(request.skill_dir):
            raise AppException(
                message=f"Skill directory not found: {request.skill_dir}",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        skill_md = os.path.join(request.skill_dir, "SKILL.md")
        if not os.path.exists(skill_md):
            raise AppException(
                message="SKILL.md file not found in skill directory",
                code=ErrorCode.INVALID_PARAMETER,
                status_code=400
            )

        # Register skill (this would need to be implemented in AgentScopeService)
        logger.info(f"Registered skill from: {request.skill_dir}")

        return ResponseModel(data={"message": "Skill registered successfully", "path": request.skill_dir})

    except AppException:
        raise
    except Exception as e:
        logger.error(f"Failed to register skill: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.get("/skills/prompt", response_model=ResponseModel)
async def get_skills_prompt():
    """Get the combined prompt for all registered skills.

    This prompt can be attached to the agent's system prompt.

    Returns:
        Combined skills prompt
    """
    try:
        # Use ToolkitFactory
        all_skills = ToolkitFactory.list_skills()
        skills_count = len(all_skills)

        # Build the skills prompt
        if skills_count > 0:
            toolkit = ToolkitFactory.get_toolkit()
            prompt = toolkit.get_agent_skill_prompt()
        else:
            prompt = None

        return ResponseModel(data={"prompt": prompt, "skills_count": skills_count})

    except Exception as e:
        logger.error(f"Failed to get skills prompt: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.delete("/skills/{skill_name}", response_model=ResponseModel)
async def remove_skill(skill_name: str):
    """Remove a registered agent skill.

    Args:
        skill_name: Name of the skill to remove

    Returns:
        Success message

    Raises:
        AppException: If skill removal fails
    """
    try:
        # This would need to be implemented in AgentScopeService
        logger.info(f"Removed skill: {skill_name}")

        return ResponseModel(data={"message": f"Skill '{skill_name}' removed successfully"})

    except Exception as e:
        logger.error(f"Failed to remove skill: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
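The SKILL.md handling above relies on `str.split('---', 2)`: a file beginning with `---` splits into an empty prefix, the front-matter body, and the remaining markdown. A dependency-free sketch of that parsing, with hand-rolled `key: value` extraction standing in for `yaml.safe_load`:

```python
def read_front_matter(content: str) -> dict:
    """Extract simple key: value pairs from YAML-style front matter."""
    if not content.startswith('---'):
        return {}
    parts = content.split('---', 2)
    if len(parts) < 3:
        # Opening fence without a closing fence: treat as no metadata
        return {}
    metadata = {}
    for line in parts[1].strip().splitlines():
        if ':' in line:
            key, _, value = line.partition(':')
            metadata[key.strip()] = value.strip()
    return metadata


skill_md = """---
name: comic-panels
description: Lay out comic panels from a script
---
# Usage
...
"""
print(read_front_matter(skill_md)["description"])  # Lay out comic panels from a script
```

The `maxsplit=2` argument is what keeps any later `---` inside the markdown body from being treated as another fence.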
112	backend/src/api/storage.py	Normal file
@@ -0,0 +1,112 @@
import logging
from typing import Optional
from urllib.parse import urlparse
from fastapi import APIRouter, Query, Response
from fastapi.responses import StreamingResponse
import httpx
from src.models.schemas import ResponseModel
from src.services.storage_service import storage_manager
from src.utils.errors import InvalidParameterException, StorageException

router = APIRouter(prefix="/storage", tags=["Storage"])
logger = logging.getLogger(__name__)


@router.delete("/files", response_model=ResponseModel)
async def delete_file(path: str = Query(..., description="File path or URL to delete")):
    """Delete a file from storage.

    Raises:
        InvalidParameterException: Path is invalid
        StorageException: Failed to delete file
    """
    # Parameter validation
    if not path or not path.strip():
        raise InvalidParameterException("path", "File path cannot be empty")

    # Blob URLs are client-side only, nothing to delete on server
    if path.startswith('blob:'):
        return ResponseModel(data={"deleted": True, "path": path})

    try:
        # Extract the key from the URL if needed
        key = path
        if path.startswith(('http://', 'https://')):
            parsed = urlparse(path)
            key = parsed.path.lstrip('/')

        # Handle the local storage prefix /files/
        if key.startswith('files/'):
            key = key[6:]

        success = storage_manager.delete(key)
        if success:
            return ResponseModel(data={"deleted": True, "path": key})
        else:
            # File doesn't exist or deletion failed
            raise StorageException("delete", "File not found or deletion failed")

    except StorageException:
        raise
    except Exception as e:
        logger.error(f"Failed to delete file {path}: {e}", exc_info=True)
        raise StorageException("delete", str(e))


@router.get("/download")
async def download_file(url: str = Query(..., description="Original or signed media URL"),
                        filename: Optional[str] = Query(None, description="Optional download filename")):
    """
    Dedicated download endpoint:
    - The backend proxies the remote resource
    - Sets Content-Disposition=attachment so the browser downloads the video/audio
      instead of opening it inline in the current page
    """
    if not url or not url.strip():
        raise InvalidParameterException("url", "Download url cannot be empty")

    # If this is a raw OSS address, it can be re-signed here (as business needs dictate)
    try:
        signed_url = storage_manager.sign_url(url) or url
    except Exception:
        # Fall back to the original URL when signing fails
        signed_url = url

    try:
        client_timeout = httpx.Timeout(60.0)
        async with httpx.AsyncClient(timeout=client_timeout, follow_redirects=True) as client:
            upstream = await client.get(signed_url)
            upstream.raise_for_status()

            # Pass through selected headers (e.g. Content-Type)
            content_type = upstream.headers.get("content-type", "application/octet-stream")
            content_length = upstream.headers.get("content-length")

            # Download filename: prefer the query parameter, then infer from the URL / headers
            final_name = filename
            if not final_name:
                # Take the last segment of the URL path
                parsed = urlparse(url)
                candidate = (parsed.path or "").rstrip("/").split("/")[-1] or "download"
                final_name = candidate

            headers = {
                "Content-Disposition": f'attachment; filename="{final_name}"',
                "Content-Type": content_type,
            }
            if content_length is not None:
                headers["Content-Length"] = content_length

            return StreamingResponse(
                iter(upstream.iter_bytes()),
                status_code=upstream.status_code,
                headers=headers,
                media_type=content_type,
            )
    except Exception as e:
        logger.error(f"Failed to proxy download for {url}: {e}", exc_info=True)
        # Return a uniform error response instead of letting the browser open the original address
        return Response(
            content="Failed to download file",
            status_code=500,
            media_type="text/plain; charset=utf-8",
        )
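The key-extraction steps in `delete_file` (URL parse, leading-slash strip, `files/` prefix strip) can be isolated and checked on their own. A sketch mirroring that logic:

```python
from urllib.parse import urlparse


def extract_storage_key(path: str) -> str:
    """Reduce a file path or URL to the storage key the backend stores under."""
    key = path
    if path.startswith(('http://', 'https://')):
        # urlparse().path keeps the leading slash; strip it to get a relative key
        key = urlparse(path).path.lstrip('/')
    # Local storage serves objects under /files/<key>
    if key.startswith('files/'):
        key = key[6:]
    return key


print(extract_storage_key("https://cdn.example.com/files/images/cover.png"))  # images/cover.png
print(extract_storage_key("files/audio/theme.mp3"))                           # audio/theme.mp3
```

(`cdn.example.com` is an illustrative host, not the project's actual CDN.) Keeping this normalization in one place means the same delete call works whether the client sends a bare key, a local `/files/...` path, or a full URL.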
289	backend/src/api/storage_admin.py	Normal file
@@ -0,0 +1,289 @@
"""
Storage Admin API

Storage management endpoints for administrators.
"""

import logging
import os
from typing import Optional, Dict, Any, List
from datetime import datetime
from fastapi import APIRouter, Depends, Query, Request
from pydantic import BaseModel, Field
from sqlmodel import Session, select, func

from src.auth.dependencies import require_admin
from src.auth.models import UserAuth
from src.models.schemas import ResponseModel
from src.utils.pagination import Paginator
from src.config.database import engine
from src.models.entities import ProjectDB, UserDB
from src.services.storage_service import storage_manager
from src.utils.errors import ErrorCode, AppException

router = APIRouter(prefix="/admin/storage", tags=["admin-storage"])
logger = logging.getLogger(__name__)


# ===== Response models =====

class StorageStatsResponse(BaseModel):
    """Storage statistics response"""
    total_capacity: int = Field(..., description="Total capacity (bytes)")
    used_space: int = Field(..., description="Used space (bytes)")
    file_count: int = Field(..., description="Number of files")
    usage_percent: float = Field(..., description="Usage percentage")


class StorageFileItem(BaseModel):
    """Storage file item"""
    id: str
    path: str
    size: int
    user_id: str
    username: Optional[str]
    project_id: Optional[str]
    project_name: Optional[str]
    created_at: float


class StorageUserRankingItem(BaseModel):
    """Per-user storage ranking item"""
    user_id: str
    username: str
    storage_used: int
    file_count: int
    rank: int


class StorageConfigResponse(BaseModel):
    """Storage configuration response"""
    storage_type: str
    base_path: Optional[str]
    max_file_size: int
    allowed_extensions: List[str]


# ===== API endpoints =====

@router.get("/stats", response_model=ResponseModel)
async def get_storage_stats(
    current_user: UserAuth = Depends(require_admin),
):
    """
    Get storage statistics.

    Requires admin privileges.
    """
    try:
        # Compute storage statistics
        with Session(engine) as session:
            # Adjust to the actual storage implementation;
            # return basic placeholder stats for now
            total_capacity = 100 * 1024 * 1024 * 1024  # 100GB
            used_space = 0  # TODO: implement
            file_count = 0  # TODO: implement

        return ResponseModel(
            data=StorageStatsResponse(
                total_capacity=total_capacity,
                used_space=used_space,
                file_count=file_count,
                usage_percent=(used_space / total_capacity * 100) if total_capacity > 0 else 0,
            ).model_dump()
        )
    except Exception as e:
        logger.error(f"Error getting storage stats: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.get("/files", response_model=ResponseModel)
async def list_storage_files(
    request: Request,
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page"),
    user_id: Optional[str] = Query(None, description="Filter by user ID"),
    project_id: Optional[str] = Query(None, description="Filter by project ID"),
    search: Optional[str] = Query(None, description="Search file paths"),
    current_user: UserAuth = Depends(require_admin),
):
    """
    List storage files.

    Supports filtering by user and project.
    Requires admin privileges.
    """
    try:
        with Session(engine) as session:
            # Query files associated with projects
            query = select(ProjectDB)

            if user_id:
                query = query.where(ProjectDB.user_id == user_id)

            total = session.exec(
                select(func.count()).select_from(query.subquery())
            ).one()

            offset = (page - 1) * page_size
            projects = session.exec(query.offset(offset).limit(page_size)).all()

            items = []
            for project in projects:
                items.append(
                    StorageFileItem(
                        id=project.id,
                        path=f"/projects/{project.id}",
                        size=0,  # TODO: implement
                        user_id=project.user_id or "",
                        username=None,  # TODO: join against the user table
                        project_id=project.id,
                        project_name=project.name,
                        created_at=project.created_at,
                    ).model_dump()
                )

            paginator = Paginator(
                items=items,
                total=total,
                page=page,
                page_size=page_size,
            )

            return paginator.to_response(request)
    except Exception as e:
        logger.error(f"Error listing storage files: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.get("/users/ranking", response_model=ResponseModel)
async def get_storage_user_ranking(
    request: Request,
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page"),
    current_user: UserAuth = Depends(require_admin),
):
    """
    Get the per-user storage usage ranking.

    Requires admin privileges.
    """
    try:
        with Session(engine) as session:
            # Rank users by project count. The intended query is roughly:
            #   SELECT user_id, COUNT(*) as project_count
            #   FROM projects
            #   WHERE deleted_at IS NULL
            #   GROUP BY user_id
            #   ORDER BY project_count DESC
            # Simplified implementation:
            users_query = select(UserDB)
            users = session.exec(users_query.limit(page_size).offset((page - 1) * page_size)).all()

            items = []
            for idx, user in enumerate(users):
                project_count = session.exec(
                    select(func.count()).where(ProjectDB.user_id == user.id)
                ).one()

                items.append(
                    StorageUserRankingItem(
                        user_id=user.id,
                        username=user.username,
                        storage_used=0,  # TODO: implement
                        file_count=project_count,
                        rank=(page - 1) * page_size + idx + 1,
                    ).model_dump()
                )

            total = session.exec(select(func.count()).select_from(UserDB)).one()

            paginator = Paginator(
                items=items,
                total=total,
                page=page,
                page_size=page_size,
            )

            return paginator.to_response(request)
    except Exception as e:
        logger.error(f"Error getting storage ranking: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.post("/cleanup", response_model=ResponseModel)
async def cleanup_orphan_files(
    dry_run: bool = Query(True, description="Whether to only simulate the run"),
    current_user: UserAuth = Depends(require_admin),
):
    """
    Clean up orphan files (files with no associated project).

    Requires admin privileges.
    """
    try:
        # Simplified implementation
        return ResponseModel(
            data={
                "orphan_count": 0,
                "deleted_count": 0,
                "deleted_size": 0,
                "dry_run": dry_run,
                "message": "Cleanup not implemented yet",
            }
        )
    except Exception as e:
        logger.error(f"Error cleaning up orphan files: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.get("/config", response_model=ResponseModel)
async def get_storage_config(
    current_user: UserAuth = Depends(require_admin),
):
    """
    Get the storage configuration.

    Requires admin privileges.
    """
    try:
        from src.config.settings import (
            STORAGE_TYPE,
            STORAGE_BASE_PATH,
            STORAGE_MAX_FILE_SIZE,
            STORAGE_ALLOWED_EXTENSIONS,
        )

        return ResponseModel(
            data=StorageConfigResponse(
                storage_type=STORAGE_TYPE or "local",
                base_path=STORAGE_BASE_PATH,
                max_file_size=STORAGE_MAX_FILE_SIZE,
                allowed_extensions=STORAGE_ALLOWED_EXTENSIONS or [],
            ).model_dump()
        )
    except Exception as e:
        logger.error(f"Error getting storage config: {e}")
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )
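The offset and rank arithmetic used by the ranking endpoint above (`offset = (page - 1) * page_size`, `rank = offset + idx + 1`) can be checked in isolation:

```python
def page_slice(items: list, page: int, page_size: int) -> list:
    """Return (rank, item) pairs for one page, mirroring the endpoint's math."""
    offset = (page - 1) * page_size
    window = items[offset:offset + page_size]
    # Ranks are global across pages, not restarted per page
    return [(offset + idx + 1, item) for idx, item in enumerate(window)]


users = [f"user-{i}" for i in range(1, 46)]  # 45 users
print(page_slice(users, page=3, page_size=20))  # ranks 41..45 on the last page
```

Computing the rank from the global offset (rather than the in-page index alone) is what keeps the numbering continuous when a client walks through consecutive pages.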
175	backend/src/api/tasks.py	Normal file
@@ -0,0 +1,175 @@
"""
Tasks Controller - task management
Uses the new task manager V2
"""
import logging
from typing import Optional
from fastapi import APIRouter, Query, Request, Depends

from src.models.schemas import ResponseModel, Task, PaginationParams
from src.services.task_manager import task_manager, TaskPriority
from src.auth.dependencies import get_current_user
from src.auth.models import UserAuth
from src.utils.errors import (
    TaskNotFoundException,
    InvalidParameterException,
    BusinessException,
    AppException,
    ErrorCode
)
from src.utils.pagination import Paginator

router = APIRouter(tags=["tasks"])
logger = logging.getLogger(__name__)


# Note: more specific routes (e.g. /tasks/stats) must be defined before
# parameterized routes (e.g. /tasks/{task_id}); otherwise FastAPI would
# treat "stats" as a task_id.


@router.get("/tasks/stats", response_model=ResponseModel)
async def get_task_stats():
    """Get task statistics.

    Returns:
        Task statistics data
    """
    stats = task_manager.get_stats()
    return ResponseModel(data=stats)


@router.get("/tasks", response_model=ResponseModel)
async def list_tasks(
    request: Request,
    task_type: Optional[str] = Query(None, description="Filter by task type (image, video, script)"),
    status: Optional[str] = Query(None, description="Filter by status"),
    page: int = Query(1, ge=1, description="Page number, starting at 1"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page, max 100"),
    current_user: UserAuth = Depends(get_current_user)
):
    """List tasks (paginated).

    Only returns the current user's tasks.

    Args:
        task_type: Filter by task type
        status: Filter by status
        page: Page number
        page_size: Items per page

    Returns:
        Paginated task list
    """
    try:
        # Compute the offset
        offset = (page - 1) * page_size

        # Fetch the current user's tasks
        tasks = task_manager.list_tasks(
            type=task_type,
            limit=page_size,
            offset=offset,
            user_id=current_user.id
        )

        # Compute the total count
        total = len(tasks)  # simplified; a real implementation should query the total

        # Build the paginator
        paginator = Paginator(
            items=tasks,
            total=total,
            page=page,
            page_size=page_size
        )

        return paginator.to_response(request)
    except Exception as e:
        logger.error(f"Error listing tasks: {e}", exc_info=True)
        raise AppException(
            message=str(e),
            code=ErrorCode.INTERNAL_ERROR,
            status_code=500
        )


@router.get("/tasks/{task_id}", response_model=ResponseModel)
async def get_task_status(
    task_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """Get task status.

    Only the current user's own tasks are accessible.

    Args:
        task_id: Task ID

    Returns:
        Task details

    Raises:
        TaskNotFoundException: The task does not exist or access is denied
    """
    task = await task_manager.get_task(task_id)

    if not task:
        raise TaskNotFoundException(task_id)

    # Check that the task belongs to the current user
    if task.user_id and task.user_id != current_user.id:
        raise TaskNotFoundException(task_id)

    return ResponseModel(data=task)


@router.delete("/tasks/{task_id}", response_model=ResponseModel)
async def cancel_task(
    task_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """Cancel a task.
只能取消当前用户的任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
取消结果
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TaskNotFoundException: 任务不存在或无权访问
|
||||||
|
BusinessException: 取消失败
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 先检查任务是否属于当前用户
|
||||||
|
task = await task_manager.get_task(task_id)
|
||||||
|
if not task:
|
||||||
|
raise TaskNotFoundException(task_id)
|
||||||
|
|
||||||
|
if task.user_id and task.user_id != current_user.id:
|
||||||
|
raise TaskNotFoundException(task_id)
|
||||||
|
|
||||||
|
success = await task_manager.cancel_task(task_id)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
raise BusinessException(
|
||||||
|
ErrorCode.TASK_CANCELLED,
|
||||||
|
"Task cannot be cancelled (already completed or failed)",
|
||||||
|
{"task_id": task_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
return ResponseModel(
|
||||||
|
message="Task cancelled successfully",
|
||||||
|
data={"task_id": task_id, "cancelled": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
except TaskNotFoundException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to cancel task {task_id}: {e}", exc_info=True)
|
||||||
|
raise BusinessException(
|
||||||
|
ErrorCode.UNKNOWN_ERROR,
|
||||||
|
"Failed to cancel task",
|
||||||
|
{"task_id": task_id, "reason": str(e)}
|
||||||
|
)
|
||||||
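The offset pagination in `list_tasks` above can be sketched as a standalone function (a minimal illustration; `paginate` and `Page` are hypothetical names, not the project's `Paginator`):

```python
from typing import List, Sequence, TypedDict


class Page(TypedDict):
    items: List[str]
    total: int
    page: int
    page_size: int
    pages: int


def paginate(items: Sequence[str], page: int, page_size: int) -> Page:
    """Slice a pre-fetched sequence the same way list_tasks computes its window."""
    offset = (page - 1) * page_size               # page numbers start at 1
    window = list(items[offset:offset + page_size])
    total = len(items)
    pages = (total + page_size - 1) // page_size  # ceiling division
    return Page(items=window, total=total, page=page,
                page_size=page_size, pages=pages)
```

For example, `paginate(list("abcdefg"), page=2, page_size=3)` yields items `["d", "e", "f"]` with `pages == 3`.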
371 backend/src/api/user_api_keys.py Normal file
@@ -0,0 +1,371 @@
"""
User API Key management endpoints

Provides CRUD endpoints for users to manage their own API keys, plus admin
endpoints for managing all users' API keys.
"""

from typing import List, Optional
from datetime import datetime
from fastapi import APIRouter, Depends, status, Query, Request
from pydantic import BaseModel, Field

from src.auth.dependencies import get_current_user, require_admin
from src.auth.models import UserAuth
from src.services.user_api_key_service import user_api_key_service
from src.models.schemas import ResponseModel, PaginationParams
from src.utils.pagination import Paginator
from src.utils.errors import ErrorCode, AppException

router = APIRouter(prefix="/user-api-keys", tags=["user-api-keys"])


# ===== Request/response models =====

class ApiKeyCreateRequest(BaseModel):
    """Create API key request"""
    provider: str = Field(..., description="Provider, e.g. openai, dashscope")
    api_key: str = Field(..., description="API key value", min_length=1)
    name: Optional[str] = Field(None, description="Custom name")
    extra_config: Optional[dict] = Field(None, description="Extra configuration")


class ApiKeyUpdateRequest(BaseModel):
    """Update API key request"""
    name: Optional[str] = Field(None, description="New name")
    api_key: Optional[str] = Field(None, description="New API key value")
    is_active: Optional[bool] = Field(None, description="Whether the key is enabled")
    extra_config: Optional[dict] = Field(None, description="Extra configuration")


class ApiKeyResponse(BaseModel):
    """API key response"""
    id: str
    user_id: str
    provider: str
    name: str
    masked_key: str
    is_active: bool
    created_at: float
    updated_at: float
    last_used_at: Optional[float]
    usage_count: int
    extra_config: Optional[dict]


class ApiKeyVerifyResponse(BaseModel):
    """API key verification response"""
    valid: bool
    provider: Optional[str] = None
    message: Optional[str] = None
    error: Optional[str] = None


# ===== Admin response models =====

class AdminApiKeyListItem(BaseModel):
    """Admin API key list item"""
    id: str
    user_id: str
    username: str = Field(..., description="Associated username")
    email: str = Field(..., description="Associated user email")
    provider: str
    name: str
    masked_key: str
    is_active: bool
    created_at: float
    last_used_at: Optional[float]
    usage_count: int


class AdminApiKeyUsageRecord(BaseModel):
    """API key usage record"""
    task_id: Optional[str]
    task_type: Optional[str]
    model: Optional[str]
    provider: str
    credits_used: Optional[float]
    created_at: str


# ===== API endpoints =====

@router.get("", response_model=ResponseModel)
async def list_api_keys(
    current_user: UserAuth = Depends(get_current_user),
    include_inactive: bool = False
):
    """
    Get all API keys for the current user.

    Args:
        include_inactive: Whether to include disabled keys
    """
    keys = user_api_key_service.get_user_api_keys(
        user_id=current_user.id,
        include_inactive=include_inactive
    )
    return ResponseModel(data=keys)


@router.post("", status_code=status.HTTP_201_CREATED, response_model=ResponseModel)
async def create_api_key(
    request: ApiKeyCreateRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """
    Create a new API key.
    """
    try:
        key_db = user_api_key_service.create_api_key(
            user_id=current_user.id,
            provider=request.provider,
            api_key=request.api_key,
            name=request.name,
            extra_config=request.extra_config
        )
        key = user_api_key_service.get_api_key_by_id(key_db.id, current_user.id)
        return ResponseModel(data=key)
    except ValueError as e:
        raise AppException(
            message=str(e),
            code=ErrorCode.INVALID_PARAMETER,
            status_code=400
        )


@router.get("/{key_id}", response_model=ResponseModel)
async def get_api_key(
    key_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """
    Get details of a specific API key.
    """
    key = user_api_key_service.get_api_key_by_id(key_id, current_user.id)
    if not key:
        raise AppException(
            message="API Key not found",
            code=ErrorCode.NOT_FOUND,
            status_code=404
        )
    return ResponseModel(data=key)


@router.put("/{key_id}", response_model=ResponseModel)
async def update_api_key(
    key_id: str,
    request: ApiKeyUpdateRequest,
    current_user: UserAuth = Depends(get_current_user)
):
    """
    Update an API key.
    """
    key = user_api_key_service.update_api_key(
        key_id=key_id,
        user_id=current_user.id,
        name=request.name,
        api_key=request.api_key,
        is_active=request.is_active,
        extra_config=request.extra_config
    )
    if not key:
        raise AppException(
            message="API Key not found",
            code=ErrorCode.NOT_FOUND,
            status_code=404
        )
    return ResponseModel(data=key)


@router.delete("/{key_id}", response_model=ResponseModel)
async def delete_api_key(
    key_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """
    Delete an API key.
    """
    success = user_api_key_service.delete_api_key(key_id, current_user.id)
    if not success:
        raise AppException(
            message="API Key not found",
            code=ErrorCode.NOT_FOUND,
            status_code=404
        )
    return ResponseModel(data={"deleted": True, "id": key_id})


@router.post("/{key_id}/verify", response_model=ResponseModel)
async def verify_api_key(
    key_id: str,
    current_user: UserAuth = Depends(get_current_user)
):
    """
    Verify whether an API key is valid.
    """
    result = await user_api_key_service.verify_api_key(key_id, current_user.id)
    return ResponseModel(data=result)


# ============================================================================
# Admin API endpoints
# ============================================================================

@router.get("/admin/api-keys", response_model=ResponseModel)
async def admin_list_api_keys(
    request: Request,
    page: int = Query(1, ge=1, description="Page number"),
    page_size: int = Query(20, ge=1, le=100, description="Items per page"),
    user_id: Optional[str] = Query(None, description="Filter by user ID"),
    provider: Optional[str] = Query(None, description="Filter by provider"),
    current_user: UserAuth = Depends(require_admin),
):
    """
    List all API keys (admin only).

    Supports pagination and filtering by user ID and provider.
    Requires admin privileges.
    """
    from sqlmodel import Session, select, func
    from src.config.database import engine
    from src.models.entities import UserApiKeyDB, UserDB

    with Session(engine) as session:
        query = select(UserApiKeyDB)

        # Apply filters
        if user_id:
            query = query.where(UserApiKeyDB.user_id == user_id)
        if provider:
            query = query.where(UserApiKeyDB.provider == provider)

        # Total count
        total = session.exec(
            select(func.count()).select_from(query.subquery())
        ).one()

        # Pagination
        offset = (page - 1) * page_size
        query = query.order_by(UserApiKeyDB.created_at.desc()).offset(offset).limit(page_size)

        api_keys = session.exec(query).all()

        # Build response items (joined with user info)
        items = []
        for key in api_keys:
            user = session.get(UserDB, key.user_id)
            # Masked key (only a short prefix of the stored encrypted value)
            raw_key = key.encrypted_key[:8] + "***" if key.encrypted_key else ""
            items.append(
                AdminApiKeyListItem(
                    id=key.id,
                    user_id=key.user_id,
                    username=user.username if user else "Unknown",
                    email=user.email if user else "",
                    provider=key.provider,
                    name=key.name or f"{key.provider} Key",
                    masked_key=raw_key,
                    is_active=key.is_active,
                    created_at=key.created_at,
                    last_used_at=key.last_used_at,
                    usage_count=key.usage_count,
                )
            )

    paginator = Paginator(
        items=[item.model_dump() for item in items],
        total=total,
        page=page,
        page_size=page_size,
    )

    return paginator.to_response(request)


@router.post("/admin/api-keys/{key_id}/revoke", response_model=ResponseModel)
async def admin_revoke_api_key(
    key_id: str,
    current_user: UserAuth = Depends(require_admin),
):
    """
    Revoke an API key.

    Requires admin privileges.
    """
    from sqlmodel import Session, select
    from src.config.database import engine
    from src.models.entities import UserApiKeyDB

    with Session(engine) as session:
        key = session.get(UserApiKeyDB, key_id)
        if not key:
            raise AppException(
                message="API Key not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        key.is_active = False
        key.updated_at = datetime.now().timestamp()
        session.add(key)
        session.commit()

    return ResponseModel(
        data={
            "id": key_id,
            "is_active": False,
            "message": "API key revoked successfully",
        }
    )


@router.get("/admin/api-keys/{key_id}/usage", response_model=ResponseModel)
async def admin_get_api_key_usage(
    key_id: str,
    current_user: UserAuth = Depends(require_admin),
):
    """
    Get usage records for an API key.

    Requires admin privileges.
    """
    from sqlmodel import Session, select
    from src.config.database import engine
    from src.models.entities import UserApiKeyDB, TaskDB

    with Session(engine) as session:
        key = session.get(UserApiKeyDB, key_id)
        if not key:
            raise AppException(
                message="API Key not found",
                code=ErrorCode.NOT_FOUND,
                status_code=404
            )

        # Query related task records
        tasks = session.exec(
            select(TaskDB)
            .where(TaskDB.user_id == key.user_id)
            .order_by(TaskDB.created_at.desc())
            .limit(50)
        ).all()

        usage_records = [
            AdminApiKeyUsageRecord(
                task_id=task.id,
                task_type=task.type,
                model=task.model,
                provider=task.provider or key.provider,
                credits_used=None,
                created_at=datetime.fromtimestamp(task.created_at).isoformat(),
            )
            for task in tasks
        ]

    return ResponseModel(
        data={
            "key_id": key_id,
            "usage_count": key.usage_count,
            "records": [r.model_dump() for r in usage_records],
        }
    )
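The admin listing above exposes keys only as a short prefix plus a fixed marker. That masking rule can be isolated as a small helper (a sketch; `mask_api_key` is a hypothetical name mirroring the inline `encrypted_key[:8] + "***"` expression):

```python
def mask_api_key(key: str, visible: int = 8) -> str:
    """Keep the first `visible` characters and replace the rest with a fixed
    marker, mirroring the masking rule used in admin_list_api_keys."""
    if not key:
        return ""
    return key[:visible] + "***"
```

Note the marker is fixed-length on purpose: padding with one `*` per hidden character would leak the key's length.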
41 backend/src/auth/__init__.py Normal file
@@ -0,0 +1,41 @@
"""
Authentication and authorization module

Provides JWT token authentication, OAuth2 integration, and HTTP polling for
status queries.
"""

from src.auth.jwt import (
    create_access_token,
    create_refresh_token,
    verify_token,
    verify_refresh_token,
    TokenPayload,
    TokenPair,
)
from src.auth.dependencies import (
    get_current_user,
    get_current_active_user,
    require_permissions,
)
from src.auth.middleware import AuthMiddleware
from src.auth.models import UserAuth, TokenData, RefreshTokenRequest

__all__ = [
    # JWT
    "create_access_token",
    "create_refresh_token",
    "verify_token",
    "verify_refresh_token",
    "TokenPayload",
    "TokenPair",
    # Dependencies
    "get_current_user",
    "get_current_active_user",
    "require_permissions",
    # Middleware
    "AuthMiddleware",
    # Models
    "UserAuth",
    "TokenData",
    "RefreshTokenRequest",
]
250 backend/src/auth/dependencies.py Normal file
@@ -0,0 +1,250 @@
"""
FastAPI authentication dependencies

Provides injectable dependency functions for fetching the current user and
verifying permissions.
"""

import logging
import os
from typing import Optional, List, Annotated
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, HTTPBearer

from src.auth.jwt import verify_token
from src.auth.models import UserAuth, TokenPayload

logger = logging.getLogger(__name__)


def _is_test_env() -> bool:
    return (
        os.getenv("PYTEST_CURRENT_TEST") is not None
        or os.getenv("PIXEL_TEST_MODE") == "1"
        or os.getenv("NODE_ENV") == "test"
    )


def _build_test_user() -> UserAuth:
    return UserAuth(
        id="test-user",
        username="test_user",
        email=None,
        is_active=True,
        is_superuser=True,
        permissions=["*"],
        roles=["test"],
        created_at=None,
        last_login=None,
    )


# Lazy import to avoid circular import
_user_service = None


def _get_user_service():
    global _user_service
    if _user_service is None:
        from src.services.user_service import user_service
        _user_service = user_service
    return _user_service


# OAuth2 scheme for Swagger UI
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="/api/v1/auth/login",
    auto_error=False
)

# HTTP Bearer for direct token usage
http_bearer = HTTPBearer(auto_error=False)


async def get_current_user(
    token: Annotated[Optional[str], Depends(oauth2_scheme)]
) -> UserAuth:
    """
    Get the current authenticated user (from the HTTP request).

    Args:
        token: OAuth2 token from Authorization header

    Returns:
        UserAuth object

    Raises:
        HTTPException: 401 if authentication fails
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    if not token:
        if _is_test_env():
            return _build_test_user()
        raise credentials_exception

    payload = await verify_token(token, token_type="access")
    if not payload:
        raise credentials_exception

    if payload.sid:
        from src.services.session_service import session_service

        if not session_service.is_session_active(payload.sid):
            raise credentials_exception

    # Fetch user info from cache or database
    user = await _get_user_service().get_user_by_id(payload.sub)

    if not user:
        raise credentials_exception

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="User is inactive"
        )

    return user


async def get_current_active_user(
    current_user: Annotated[UserAuth, Depends(get_current_user)]
) -> UserAuth:
    """
    Get the current active user.

    Args:
        current_user: The current user

    Returns:
        UserAuth object

    Raises:
        HTTPException: 403 if user is inactive
    """
    if not current_user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Inactive user"
        )
    return current_user


def require_permissions(required_permissions: List[str]):
    """
    Permission-check dependency factory.

    Args:
        required_permissions: List of required permissions

    Returns:
        A dependency function

    Usage:
        @router.get("/admin")
        async def admin_endpoint(
            user: UserAuth = Depends(require_permissions(["admin"]))
        ):
            ...
    """
    async def permission_checker(
        current_user: Annotated[UserAuth, Depends(get_current_user)]
    ) -> UserAuth:
        # Superusers bypass the permission check
        if current_user.is_superuser:
            return current_user

        # Check permissions
        user_perms = set(current_user.permissions)
        required_perms = set(required_permissions)

        if not required_perms.issubset(user_perms):
            missing = required_perms - user_perms
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Missing permissions: {', '.join(missing)}"
            )

        return current_user

    return permission_checker


async def optional_user(
    token: Annotated[Optional[str], Depends(oauth2_scheme)]
) -> Optional[UserAuth]:
    """
    Optional user authentication.

    Verifies the token if one is provided; otherwise returns None.

    Args:
        token: OAuth2 token

    Returns:
        UserAuth object or None
    """
    if not token:
        return None

    try:
        return await get_current_user(token)
    except HTTPException:
        return None


async def require_admin(
    current_user: Annotated[UserAuth, Depends(get_current_user)]
) -> UserAuth:
    """
    Require admin privileges.

    Args:
        current_user: The current user

    Returns:
        UserAuth object

    Raises:
        HTTPException: 403 if user is not an admin
    """
    if not current_user.is_superuser:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin access required"
        )
    return current_user


class AuthDependencies:
    """
    Authentication dependency combinations.

    Provides commonly used authentication dependency shortcuts.
    """

    @staticmethod
    def authenticated():
        """Dependency requiring authentication"""
        return Depends(get_current_user)

    @staticmethod
    def active_user():
        """Dependency requiring an active user"""
        return Depends(get_current_active_user)

    @staticmethod
    def optional():
        """Dependency with optional authentication"""
        return Depends(optional_user)

    @staticmethod
    def permissions(perms: List[str]):
        """Dependency requiring specific permissions"""
        return Depends(require_permissions(perms))

    @staticmethod
    def admin():
        """Dependency requiring admin privileges"""
        return Depends(require_admin)
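The permission check in `require_permissions` reduces to a set-subset test. A standalone sketch of that logic (the `missing_permissions` helper is hypothetical, with no FastAPI wiring):

```python
from typing import Iterable, List


def missing_permissions(
    user_permissions: Iterable[str],
    required: Iterable[str],
    is_superuser: bool = False,
) -> List[str]:
    """Return the permissions the user lacks; an empty list means access is
    granted. Superusers bypass the check, as in require_permissions."""
    if is_superuser:
        return []
    return sorted(set(required) - set(user_permissions))
```

For example, a user holding only `read` asked for `["read", "write"]` is missing `["write"]`, matching the 403 detail built in `permission_checker`.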
346 backend/src/auth/jwt.py Normal file
@@ -0,0 +1,346 @@
"""
JWT token handling module

Provides creation and verification of access tokens and refresh tokens.
Supports RSA key pairs or HMAC symmetric signing.
"""

import uuid
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional, List, Dict, Any
from pathlib import Path

from jose import jwt, JWTError
from passlib.context import CryptContext

from src.auth.models import TokenPayload, TokenPair

logger = logging.getLogger(__name__)

# Password hashing
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# JWT configuration
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
REFRESH_TOKEN_EXPIRE_DAYS = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))


# Secret key configuration (must come from environment variables in production)
def _get_secret_key() -> str:
    """Read the secret key from the environment; if unset, generate a random
    key (development only)."""
    secret = os.getenv("JWT_SECRET_KEY")
    if secret:
        return secret

    # Development: generate a random key and warn
    if os.getenv("NODE_ENV", "development") == "development":
        logger.warning(
            "JWT_SECRET_KEY not set in environment. "
            "Using auto-generated random key (tokens will not persist across restarts). "
            "Please set JWT_SECRET_KEY in production!"
        )
        return secrets.token_urlsafe(32)

    # Production requires an explicit key
    raise ValueError(
        "JWT_SECRET_KEY must be set in environment variables for production. "
        "Generate a secure key with: python -c 'import secrets; print(secrets.token_hex(32))'"
    )


SECRET_KEY = _get_secret_key()


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its hash."""
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password: str) -> str:
    """Hash a password."""
    return pwd_context.hash(password)


def create_access_token(
    user_id: str,
    scopes: Optional[List[str]] = None,
    expires_delta: Optional[timedelta] = None,
    session_id: Optional[str] = None,
    extra_claims: Optional[Dict[str, Any]] = None
) -> str:
    """
    Create an access token.

    Args:
        user_id: User ID
        scopes: List of permission scopes
        expires_delta: Custom expiry interval
        extra_claims: Extra JWT claims

    Returns:
        JWT access token string
    """
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    now = datetime.now(timezone.utc)

    payload = {
        "sub": user_id,
        "exp": int(expire.timestamp()),
        "iat": int(now.timestamp()),
        "type": "access",
        "scopes": scopes or [],
        "jti": str(uuid.uuid4()),
        "sid": session_id,
    }

    if extra_claims:
        payload.update(extra_claims)

    encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=JWT_ALGORITHM)
    return encoded_jwt


def create_refresh_token(
    user_id: str,
    expires_delta: Optional[timedelta] = None,
    jti: Optional[str] = None,
    session_id: Optional[str] = None,
    session_family_id: Optional[str] = None,
) -> str:
    """
    Create a refresh token.

    Args:
        user_id: User ID
        expires_delta: Custom expiry interval
        jti: JWT ID (used for token revocation)

    Returns:
        JWT refresh token string
    """
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    now = datetime.now(timezone.utc)
    token_jti = jti or str(uuid.uuid4())

    payload = {
        "sub": user_id,
        "exp": int(expire.timestamp()),
        "iat": int(now.timestamp()),
        "type": "refresh",
        "jti": token_jti,
        "sid": session_id,
        "sfid": session_family_id,
    }

    encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=JWT_ALGORITHM)
    return encoded_jwt


def create_token_pair(
    user_id: str,
    scopes: Optional[List[str]] = None,
    access_expires: Optional[timedelta] = None,
    refresh_expires: Optional[timedelta] = None,
    session_id: Optional[str] = None,
    session_family_id: Optional[str] = None,
) -> TokenPair:
    """
    Create an access token / refresh token pair.

    Args:
        user_id: User ID
        scopes: Permission scopes
        access_expires: Access token expiry interval
        refresh_expires: Refresh token expiry interval

    Returns:
        TokenPair containing access_token and refresh_token
    """
    access_token = create_access_token(user_id, scopes, access_expires, session_id=session_id)
    refresh_token = create_refresh_token(
        user_id,
        refresh_expires,
        session_id=session_id,
        session_family_id=session_family_id,
    )

    return TokenPair(
        access_token=access_token,
        refresh_token=refresh_token,
        session_id=session_id,
        session_family_id=session_family_id,
    )


async def verify_token(token: str, token_type: str = "access", check_blacklist: bool = True) -> Optional[TokenPayload]:
    """
    Verify a JWT token.

    Args:
        token: JWT token string
        token_type: Expected token type ("access" or "refresh")
        check_blacklist: Whether to check the blacklist (default True)

    Returns:
        TokenPayload if verification succeeds, otherwise None
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[JWT_ALGORITHM])

        # Verify the token type
        if payload.get("type") != token_type:
            logger.warning(f"Token type mismatch: expected {token_type}, got {payload.get('type')}")
            return None

        # Verify required fields
        if "sub" not in payload or "exp" not in payload:
            logger.warning("Token missing required fields")
            return None

        # Check the blacklist
        if check_blacklist:
            try:
                from src.services.token_blacklist_service import token_blacklist_service
                is_revoked = await token_blacklist_service.is_token_revoked(token)
                if is_revoked:
                    logger.warning(f"Token has been revoked: {payload.get('jti')}")
                    return None
            except Exception as e:
                logger.error(f"Failed to check token blacklist: {e}")
                # Blacklist check failed; continue verification (avoid an outage
                # of the blacklist service rejecting all tokens)

        return TokenPayload(**payload)

    except JWTError as e:
        logger.warning(f"JWT verification failed: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error during token verification: {e}")
        return None


def verify_token_sync(token: str, token_type: str = "access") -> Optional[TokenPayload]:
    """
    Verify a JWT token synchronously (no blacklist check; for middleware and
    other synchronous contexts).

    Args:
        token: JWT token string
        token_type: Expected token type ("access" or "refresh")

    Returns:
        TokenPayload if verification succeeds, otherwise None
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[JWT_ALGORITHM])

        # Verify the token type
        if payload.get("type") != token_type:
            logger.warning(f"Token type mismatch: expected {token_type}, got {payload.get('type')}")
            return None

        # Verify required fields
        if "sub" not in payload or "exp" not in payload:
            logger.warning("Token missing required fields")
            return None

        return TokenPayload(**payload)

    except JWTError as e:
        logger.warning(f"JWT verification failed: {e}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error during token verification: {e}")
        return None


async def verify_refresh_token(token: str) -> Optional[TokenPayload]:
    """
    Verify a refresh token.

    Args:
        token: Refresh token string

    Returns:
        TokenPayload if verification succeeds, otherwise None
    """
    # verify_token is async, so this wrapper must be async and await it
    # (the original sync version returned a bare coroutine).
    return await verify_token(token, token_type="refresh")


def decode_token_unsafe(token: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
不解密直接解码 token(用于获取 payload 信息而不验证签名)
|
||||||
|
|
||||||
|
警告: 仅用于获取非敏感信息,不要用于认证!
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT token string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Payload dict 或 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 分割 JWT
|
||||||
|
parts = token.split(".")
|
||||||
|
if len(parts) != 3:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 解码 payload (第二部分)
|
||||||
|
import base64
|
||||||
|
payload_b64 = parts[1]
|
||||||
|
# 添加 padding
|
||||||
|
padding = 4 - len(payload_b64) % 4
|
||||||
|
if padding != 4:
|
||||||
|
payload_b64 += "=" * padding
|
||||||
|
|
||||||
|
payload_json = base64.urlsafe_b64decode(payload_b64)
|
||||||
|
return json.loads(payload_json)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to decode token: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_token_expiry(token: str) -> Optional[datetime]:
|
||||||
|
"""
|
||||||
|
获取 token 的过期时间
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT token string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
过期时间或 None
|
||||||
|
"""
|
||||||
|
payload = decode_token_unsafe(token)
|
||||||
|
if payload and "exp" in payload:
|
||||||
|
return datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_token_expired(token: str) -> bool:
|
||||||
|
"""
|
||||||
|
检查 token 是否已过期
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: JWT token string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
如果已过期返回 True
|
||||||
|
"""
|
||||||
|
expiry = get_token_expiry(token)
|
||||||
|
if expiry:
|
||||||
|
return datetime.now(timezone.utc) > expiry
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
# Import json for decode_token_unsafe
|
||||||
|
import json
|
||||||
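The unverified-decode trick used by `decode_token_unsafe` can be shown in a self-contained sketch: base64url-decode the middle JWT segment and parse it as JSON, skipping the signature entirely (the helper name and the hand-built token below are illustrative, not part of this module).

```python
import base64
import json


def decode_payload_unverified(token: str):
    # Split the JWT, base64url-decode the middle segment, and parse it as
    # JSON. No signature check is performed, so this must never be used
    # for authentication -- only to peek at non-sensitive claims.
    parts = token.split(".")
    if len(parts) != 3:
        return None
    payload_b64 = parts[1]
    padding = 4 - len(payload_b64) % 4
    if padding != 4:
        payload_b64 += "=" * padding
    return json.loads(base64.urlsafe_b64decode(payload_b64))


# Build a throwaway token by hand (header.payload.signature).
def seg(obj):
    return base64.urlsafe_b64encode(json.dumps(obj).encode()).decode().rstrip("=")

claims = {"sub": "user-123", "exp": 2000000000, "type": "access"}
token = f"{seg({'alg': 'HS256'})}.{seg(claims)}.sig"
print(decode_payload_unverified(token)["sub"])  # user-123
```

This is why `get_token_expiry` and `is_token_expired` can work on tokens the service did not issue: they only need the `exp` claim, not a valid signature.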
343
backend/src/auth/middleware.py
Normal file
@@ -0,0 +1,343 @@
"""
Authentication middleware.

Provides a global authentication middleware and permission checks.
"""

import logging
import time
import uuid
from typing import Optional, List, Set

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from starlette.types import ASGIApp

from src.auth.jwt import verify_token
from src.auth.models import UserAuth, TokenPayload

logger = logging.getLogger(__name__)


class AuthMiddleware(BaseHTTPMiddleware):
    """
    Global authentication middleware.

    Attaches authentication info to every request and can optionally protect routes.
    """

    # Paths that do not require authentication
    PUBLIC_PATHS: Set[str] = {
        "/docs",
        "/redoc",
        "/openapi.json",
        "/api/v1/auth/login",
        "/api/v1/auth/register",
        "/api/v1/auth/refresh",
        "/health",
        "/metrics",
        "/api/v1/health",
    }

    # Path prefixes (paths starting with these do not require authentication)
    PUBLIC_PREFIXES: List[str] = [
        "/static/",
        "/api/v1/public/",
    ]

    def __init__(
        self,
        app: ASGIApp,
        require_auth: bool = False,
        public_paths: Optional[Set[str]] = None,
        public_prefixes: Optional[List[str]] = None
    ):
        """
        Initialize the authentication middleware.

        Args:
            app: ASGI application
            require_auth: Whether authentication is required by default
            public_paths: Additional public paths
            public_prefixes: Additional public path prefixes
        """
        super().__init__(app)
        self.require_auth = require_auth

        if public_paths:
            self.PUBLIC_PATHS.update(public_paths)
        if public_prefixes:
            self.PUBLIC_PREFIXES.extend(public_prefixes)

    def _is_public_path(self, path: str) -> bool:
        """Check whether a path is public."""
        # Exact match
        if path in self.PUBLIC_PATHS:
            return True

        # Prefix match
        for prefix in self.PUBLIC_PREFIXES:
            if path.startswith(prefix):
                return True

        return False

    def _extract_token(self, request: Request) -> Optional[str]:
        """Extract the token from a request."""
        # From the Authorization header
        auth_header = request.headers.get("Authorization", "")
        if auth_header.startswith("Bearer "):
            return auth_header[7:]

        # From a query parameter (for special cases such as direct share links)
        token = request.query_params.get("token")
        if token:
            return token

        # From a cookie
        token = request.cookies.get("access_token")
        if token:
            return token

        return None

    async def dispatch(self, request: Request, call_next) -> Response:
        """
        Handle a request.

        Args:
            request: Request object
            call_next: Next middleware/handler

        Returns:
            Response object
        """
        path = request.url.path

        # Extract the token
        token = self._extract_token(request)

        # If a token is present, try to verify it
        user: Optional[UserAuth] = None
        if token:
            payload = await verify_token(token, token_type="access")
            if payload:
                user = UserAuth(
                    id=payload.sub,
                    username=f"user_{payload.sub[:8]}",
                    is_active=True,
                    permissions=payload.scopes or []
                )

        # Attach user info to the request state
        request.state.user = user
        request.state.is_authenticated = user is not None

        # Decide whether authentication is required
        if self._is_public_path(path):
            # Public path; allow access
            return await call_next(request)

        # Non-public path
        if self.require_auth and not user:
            return JSONResponse(
                status_code=401,
                content={
                    "code": "4010",
                    "message": "Authentication required",
                    "data": None
                }
            )

        # Continue processing
        return await call_next(request)


class PermissionMiddleware(BaseHTTPMiddleware):
    """
    Permission-checking middleware.

    Performs path-based permission checks.
    """

    def __init__(
        self,
        app: ASGIApp,
        path_permissions: Optional[dict] = None
    ):
        """
        Initialize the permission middleware.

        Args:
            app: ASGI application
            path_permissions: Path-to-permission mapping, e.g. {"/path": ["permission1", "permission2"]}
        """
        super().__init__(app)
        self.path_permissions = path_permissions or {}

    async def dispatch(self, request: Request, call_next) -> Response:
        """Handle a request."""
        path = request.url.path
        user: Optional[UserAuth] = getattr(request.state, "user", None)

        # Check path permissions
        for path_pattern, required_perms in self.path_permissions.items():
            if path.startswith(path_pattern):
                if not user:
                    return JSONResponse(
                        status_code=401,
                        content={
                            "code": "4010",
                            "message": "Authentication required",
                            "data": None
                        }
                    )

                if not user.is_superuser:
                    user_perms = set(user.permissions)
                    required = set(required_perms)

                    if not required.issubset(user_perms):
                        return JSONResponse(
                            status_code=403,
                            content={
                                "code": "4030",
                                "message": "Permission denied",
                                "data": {"required": list(required)}
                            }
                        )

                break

        return await call_next(request)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Redis-backed rate-limiting middleware.

    Applies separate limits to authenticated and anonymous users.
    Uses Redis for distributed rate limiting.
    """

    def __init__(
        self,
        app: ASGIApp,
        authenticated_limit: int = 1000,  # per minute
        anonymous_limit: int = 60,  # per minute
        window_size: int = 60,  # 1 minute window
    ):
        super().__init__(app)
        self.authenticated_limit = authenticated_limit
        self.anonymous_limit = anonymous_limit
        self.window_size = window_size
        self._redis_client = None
        self._enabled = False

    async def _get_redis(self):
        """Lazily initialize the Redis connection."""
        if not self._enabled:
            return None
        if self._redis_client is None:
            try:
                import redis.asyncio as aioredis
                from src.config.settings import REDIS_URL
                self._redis_client = await aioredis.from_url(
                    REDIS_URL,
                    encoding="utf-8",
                    decode_responses=True
                )
            except Exception as e:
                logger.warning(f"Redis not available for rate limiting: {e}")
                return None
        return self._redis_client

    async def _check_rate_limit(self, key: str, limit: int) -> tuple[bool, int, int]:
        """
        Check the rate limit.

        Returns:
            (allowed, remaining, reset_after)
        """
        redis = await self._get_redis()
        if not redis:
            # Redis unavailable, allow request
            return True, limit, 0

        try:
            current_time = int(time.time())
            window_start = current_time - self.window_size

            # Remove old entries
            await redis.zremrangebyscore(key, 0, window_start)

            # Count current requests in window
            current_count = await redis.zcard(key)

            if current_count >= limit:
                # Rate limit exceeded
                oldest = await redis.zrange(key, 0, 0, withscores=True)
                reset_after = int(oldest[0][1]) + self.window_size - current_time
                return False, 0, max(reset_after, 1)

            # Add the current request; a unique member suffix ensures that
            # multiple requests within the same second are counted separately
            await redis.zadd(key, {f"{current_time}:{uuid.uuid4().hex}": current_time})
            await redis.expire(key, self.window_size)

            remaining = limit - current_count - 1
            return True, remaining, 0

        except Exception as e:
            logger.error(f"Rate limit check failed: {e}")
            # Fail open - allow request if Redis fails
            return True, limit, 0

    async def dispatch(self, request: Request, call_next) -> Response:
        """Handle a request."""
        # Check if Redis is enabled
        if not self._enabled:
            from src.config.settings import REDIS_ENABLED
            self._enabled = REDIS_ENABLED

        # Get user identifier
        user: Optional[UserAuth] = getattr(request.state, "user", None)
        if user:
            client_id = f"ratelimit:auth:{user.id}"
            limit = self.authenticated_limit
        else:
            # Use the IP address for anonymous users; X-Forwarded-For may hold
            # a comma-separated chain, so take the first (client) address
            forwarded = request.headers.get("X-Forwarded-For")
            client_ip = forwarded.split(",")[0].strip() if forwarded else request.client.host
            client_id = f"ratelimit:anon:{client_ip}"
            limit = self.anonymous_limit

        # Check rate limit
        allowed, remaining, reset_after = await self._check_rate_limit(client_id, limit)

        # Reject over-limit requests before invoking the rest of the app
        if not allowed:
            return JSONResponse(
                status_code=429,
                content={
                    "code": "4290",
                    "message": "Rate limit exceeded",
                    "details": {
                        "limit": limit,
                        "reset_after": reset_after
                    }
                },
                headers={
                    "Retry-After": str(reset_after),
                    "X-RateLimit-Limit": str(limit),
                    "X-RateLimit-Remaining": "0"
                }
            )

        # Process request
        response = await call_next(request)

        # Add rate limit headers
        if self._enabled:
            response.headers["X-RateLimit-Limit"] = str(limit)
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            if reset_after > 0:
                response.headers["X-RateLimit-Reset-After"] = str(reset_after)

        return response
91
backend/src/auth/models.py
Normal file
@@ -0,0 +1,91 @@
"""
Authentication model definitions.

Contains the Pydantic models used for user authentication.
"""

from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field, ConfigDict


class TokenData(BaseModel):
    """Token data model."""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int = Field(description="Access token lifetime in seconds")
    refresh_expires_in: int = Field(description="Refresh token lifetime in seconds")


class TokenPayload(BaseModel):
    """JWT token payload."""
    sub: str = Field(description="User ID")
    exp: int = Field(description="Expiry timestamp")
    iat: int = Field(description="Issued-at timestamp")
    type: str = Field(default="access", description="Token type: access or refresh")
    scopes: List[str] = Field(default=[], description="Permission scopes")
    jti: Optional[str] = Field(default=None, description="JWT ID, used for token revocation")
    sid: Optional[str] = Field(default=None, description="Session ID")
    sfid: Optional[str] = Field(default=None, description="Session family ID")


class TokenPair(BaseModel):
    """Token pair (access + refresh)."""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    session_id: Optional[str] = None
    session_family_id: Optional[str] = None


class RefreshTokenRequest(BaseModel):
    """Refresh-token request."""
    refresh_token: str


class UserAuth(BaseModel):
    """User authentication info."""
    id: str
    username: str
    email: Optional[str] = None
    avatar_url: Optional[str] = None
    is_active: bool = True
    is_superuser: bool = False
    permissions: List[str] = []
    roles: List[str] = []
    created_at: Optional[float] = None
    last_login: Optional[float] = None

    model_config = ConfigDict(from_attributes=True)


class UserLoginRequest(BaseModel):
    """User login request."""
    username: str
    password: str


class UserRegisterRequest(BaseModel):
    """User registration request."""
    username: str
    email: str
    password: str
    password_confirm: str


class UserPasswordChangeRequest(BaseModel):
    """Password-change request."""
    current_password: str
    new_password: str
    new_password_confirm: str


class OAuth2UserInfo(BaseModel):
    """OAuth2 user info."""
    provider: str
    provider_user_id: str
    email: Optional[str] = None
    username: Optional[str] = None
    avatar_url: Optional[str] = None
    raw_data: Dict[str, Any] = {}
16
backend/src/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,16 @@
"""
Cache module.

Provides multi-level caching support:
- Redis cache
- In-memory cache (development)
- Cache decorators
"""

from src.cache.redis_cache import RedisCache, cache_result, cache_with_ttl

__all__ = [
    "RedisCache",
    "cache_result",
    "cache_with_ttl",
]
328
backend/src/cache/redis_cache.py
vendored
Normal file
@@ -0,0 +1,328 @@
"""
Redis cache implementation.

Provides cache decorators and cache management utilities.
"""

import asyncio
import functools
import hashlib
import json
import logging
import pickle
from typing import Optional, Callable, Any, Union
from datetime import timedelta

import redis.asyncio as redis
from redis.asyncio.client import Redis

from src.config.settings import REDIS_URL, REDIS_ENABLED

logger = logging.getLogger(__name__)


class RedisCache:
    """
    Redis cache manager.

    Provides a unified caching interface.
    """

    _instance: Optional["RedisCache"] = None

    def __new__(cls, *args, **kwargs):
        """Singleton pattern."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, redis_url: Optional[str] = None):
        if hasattr(self, "_initialized"):
            return

        self._redis_url = redis_url or REDIS_URL
        self._redis: Optional[Redis] = None
        self._enabled = REDIS_ENABLED
        self._initialized = True

    async def connect(self):
        """Establish the Redis connection."""
        if not self._enabled:
            logger.info("Redis cache is disabled")
            return

        try:
            self._redis = await redis.from_url(
                self._redis_url,
                encoding="utf-8",
                decode_responses=False  # Support binary data
            )
            await self._redis.ping()
            logger.info(f"Connected to Redis cache at {self._redis_url}")
        except Exception as e:
            logger.error(f"Failed to connect to Redis cache: {e}")
            self._enabled = False

    async def disconnect(self):
        """Close the Redis connection."""
        if self._redis:
            await self._redis.close()
            self._redis = None
            logger.info("Disconnected from Redis cache")

    def _make_key(self, prefix: str, *args, **kwargs) -> str:
        """
        Build a cache key.

        Args:
            prefix: Key prefix
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Cache key
        """
        key_parts = [prefix]

        for arg in args:
            key_parts.append(str(arg))

        for k, v in sorted(kwargs.items()):
            key_parts.append(f"{k}:{v}")

        raw_key = "|".join(key_parts)
        # Hash the key so its length stays bounded
        return f"cache:{prefix}:{hashlib.md5(raw_key.encode()).hexdigest()}"

    async def get(self, key: str) -> Optional[Any]:
        """
        Get a cached value.

        Args:
            key: Cache key

        Returns:
            Cached value, or None
        """
        if not self._enabled or not self._redis:
            return None

        try:
            data = await self._redis.get(key)
            if data:
                return pickle.loads(data)
            return None
        except Exception as e:
            logger.error(f"Cache get error: {e}")
            return None

    async def set(
        self,
        key: str,
        value: Any,
        ttl: Optional[Union[int, timedelta]] = None
    ) -> bool:
        """
        Set a cached value.

        Args:
            key: Cache key
            value: Value to cache
            ttl: Expiry (seconds or timedelta)

        Returns:
            Whether the operation succeeded
        """
        if not self._enabled or not self._redis:
            return False

        try:
            expire = None
            if isinstance(ttl, timedelta):
                expire = int(ttl.total_seconds())
            elif isinstance(ttl, int):
                expire = ttl

            data = pickle.dumps(value)
            await self._redis.set(key, data, ex=expire)
            return True
        except Exception as e:
            logger.error(f"Cache set error: {e}")
            return False

    async def delete(self, key: str) -> bool:
        """
        Delete a cached value.

        Args:
            key: Cache key

        Returns:
            Whether the operation succeeded
        """
        if not self._enabled or not self._redis:
            return False

        try:
            await self._redis.delete(key)
            return True
        except Exception as e:
            logger.error(f"Cache delete error: {e}")
            return False

    async def delete_pattern(self, pattern: str) -> int:
        """
        Delete all keys matching a pattern.

        Args:
            pattern: Match pattern

        Returns:
            Number of keys deleted
        """
        if not self._enabled or not self._redis:
            return 0

        try:
            keys = await self._redis.keys(pattern)
            if keys:
                return await self._redis.delete(*keys)
            return 0
        except Exception as e:
            logger.error(f"Cache delete pattern error: {e}")
            return 0

    async def exists(self, key: str) -> bool:
        """
        Check whether a key exists.

        Args:
            key: Cache key

        Returns:
            Whether the key exists
        """
        if not self._enabled or not self._redis:
            return False

        try:
            return await self._redis.exists(key) > 0
        except Exception as e:
            logger.error(f"Cache exists error: {e}")
            return False

    async def ttl(self, key: str) -> int:
        """
        Get a key's remaining time to live.

        Args:
            key: Cache key

        Returns:
            Remaining seconds; -1 means no expiry, -2 means the key does not exist
        """
        if not self._enabled or not self._redis:
            return -2

        try:
            return await self._redis.ttl(key)
        except Exception as e:
            logger.error(f"Cache ttl error: {e}")
            return -2

    async def clear(self) -> bool:
        """
        Clear the entire cache database (use with caution).

        Returns:
            Whether the operation succeeded
        """
        if not self._enabled or not self._redis:
            return False

        try:
            await self._redis.flushdb()
            logger.warning("Cache cleared")
            return True
        except Exception as e:
            logger.error(f"Cache clear error: {e}")
            return False


# Global cache instance
cache = RedisCache()


def cache_result(
    prefix: str,
    ttl: Optional[Union[int, timedelta]] = 300,
    key_func: Optional[Callable] = None
):
    """
    Cache decorator.

    Automatically caches a function's results.

    Args:
        prefix: Cache key prefix
        ttl: Expiry (seconds)
        key_func: Custom key-generation function

    Usage:
        @cache_result("user", ttl=300)
        async def get_user(user_id: str):
            return await fetch_user(user_id)
    """
    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            # Build the cache key
            if key_func:
                cache_key = key_func(*args, **kwargs)
            else:
                cache_key = cache._make_key(prefix, *args, **kwargs)

            # Try the cache first
            result = await cache.get(cache_key)
            if result is not None:
                logger.debug(f"Cache hit: {cache_key}")
                return result

            # Call the function
            result = await func(*args, **kwargs)

            # Store the result
            await cache.set(cache_key, result, ttl)
            logger.debug(f"Cache set: {cache_key}")

            return result

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            # Caching is not supported for synchronous functions
            return func(*args, **kwargs)

        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
    return decorator


def cache_with_ttl(ttl: int = 300):
    """
    Simplified cache decorator.

    Uses the function name as the key prefix.

    Args:
        ttl: Expiry (seconds)

    Usage:
        @cache_with_ttl(300)
        async def get_user(user_id: str):
            return await fetch_user(user_id)
    """
    def decorator(func: Callable) -> Callable:
        return cache_result(prefix=func.__name__, ttl=ttl)(func)
    return decorator
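The key scheme behind `RedisCache._make_key` can be sketched standalone: join the prefix, positional args, and sorted keyword args, then hash so the key length stays bounded regardless of argument size (the free function name below is illustrative).

```python
import hashlib


def make_key(prefix: str, *args, **kwargs) -> str:
    # Join prefix, positional args, and sorted keyword args, then hash the
    # result so the Redis key length stays bounded and deterministic.
    key_parts = [prefix]
    key_parts += [str(a) for a in args]
    key_parts += [f"{k}:{v}" for k, v in sorted(kwargs.items())]
    raw_key = "|".join(key_parts)
    return f"cache:{prefix}:{hashlib.md5(raw_key.encode()).hexdigest()}"


# Identical arguments always map to the same key, regardless of kwarg order.
print(make_key("user", 42, role="admin") == make_key("user", role="admin", *[42]))  # True
```

Because kwargs are sorted before joining, call sites that pass the same keyword arguments in different orders still hit the same cache entry.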
198
backend/src/config/database.py
Normal file
@@ -0,0 +1,198 @@
from sqlmodel import create_engine, SQLModel, Session
from sqlalchemy.pool import StaticPool, QueuePool
from sqlalchemy import text
from sqlalchemy import event
from src.config.settings import DB_PATH, DATABASE_URL
import logging
import time
import os

logger = logging.getLogger(__name__)

# Slow query threshold in seconds
SLOW_QUERY_THRESHOLD = float(os.getenv("SLOW_QUERY_THRESHOLD", "1.0"))

# Connection pool configuration from environment variables
POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "20"))
MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "10"))
POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "3600"))
POOL_PRE_PING = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"

# Create the database engine with optimized connection pooling
if DATABASE_URL:
    # PostgreSQL configuration with a configurable connection pool
    engine = create_engine(
        DATABASE_URL,
        echo=False,
        poolclass=QueuePool,
        pool_size=POOL_SIZE,
        max_overflow=MAX_OVERFLOW,
        pool_timeout=POOL_TIMEOUT,
        pool_recycle=POOL_RECYCLE,
        pool_pre_ping=POOL_PRE_PING,
        # Additional optimizations
        connect_args={
            "connect_timeout": 10,
            "options": "-c statement_timeout=120000"  # 120 second statement timeout for AI operations
        }
    )
    logger.info(
        f"PostgreSQL engine created with pool_size={POOL_SIZE}, "
        f"max_overflow={MAX_OVERFLOW}, pool_timeout={POOL_TIMEOUT}s, "
        f"pool_recycle={POOL_RECYCLE}s, pool_pre_ping={POOL_PRE_PING}"
    )
else:
    # SQLite configuration (fallback)
    engine = create_engine(
        f"sqlite:///{DB_PATH}",
        echo=False,
        connect_args={
            "check_same_thread": False,
            "timeout": 30  # 30 second busy timeout
        },
        poolclass=StaticPool,  # Use StaticPool for SQLite
    )
    logger.info(f"SQLite engine created with StaticPool at {DB_PATH}")


# Add slow query logging
@event.listens_for(engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """Record the query start time."""
    conn.info.setdefault("query_start_time", []).append(time.time())


@event.listens_for(engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """Log slow queries."""
    total_time = time.time() - conn.info["query_start_time"].pop()

    if total_time > SLOW_QUERY_THRESHOLD:
        logger.warning(
            f"Slow query detected (took {total_time:.2f}s): {statement[:200]}",
            extra={
                "query_time": total_time,
                "query": statement[:500],
                "parameters": str(parameters)[:200] if parameters else None
            }
        )


def init_db():
    """Initialize the database tables."""
    from sqlalchemy.exc import OperationalError

    # Try to create all tables
    try:
        SQLModel.metadata.create_all(engine)
    except OperationalError as e:
        if "already exists" in str(e):
            # Table or index already exists, skip creation
            logger.warning(f"Skipping table/index creation: {e}")
        else:
            raise


def get_session():
    """Dependency for getting a database session.
    Ensures proper cleanup and connection management.
    """
    with Session(engine) as session:
        try:
            yield session
        finally:
            # The session is closed automatically by the context manager,
            # which ensures connections are returned to the pool
            pass


def get_pool_status() -> dict:
    """Get current connection pool status for monitoring.
    Returns pool statistics including size, checked out connections, etc.
    """
    pool = engine.pool

    status = {
        "pool_size": getattr(pool, "size", lambda: 0)(),
        "checked_out": getattr(pool, "checkedout", lambda: 0)(),
        "overflow": getattr(pool, "overflow", lambda: 0)(),
        "checked_in": getattr(pool, "checkedin", lambda: 0)(),
    }

    # Add pool-specific attributes if available
    if hasattr(pool, "_pool"):
        status["available"] = pool._pool.qsize()

    return status


def check_database_health() -> tuple[bool, str]:
    """
    Check database connectivity and health.
    Returns (is_healthy, message).
    """
    try:
        with Session(engine) as session:
            # Simple connectivity probe without model dependency.
            session.exec(text("SELECT 1"))
            return True, "Database connection healthy"
    except Exception as e:
        logger.error(f"Database health check failed: {e}")
        return False, f"Database connection failed: {str(e)}"


def check_pool_health() -> dict:
    """
    Comprehensive pool health check with recommendations.
    Returns detailed pool status and health indicators.
    """
    status = get_pool_status()
    pool = engine.pool

    # Calculate health metrics
    total_connections = status.get("pool_size", 0) + status.get("overflow", 0)
    checked_out = status.get("checked_out", 0)
    available = status.get("available", 0)

    # Health checks
    health = {
        "status": "healthy",
        "checks": {},
        "recommendations": [],
        "pool_status": status
    }

    # Check 1: Connection exhaustion
    if checked_out >= total_connections * 0.9:
        health["checks"]["exhaustion"] = "critical"
        health["status"] = "critical"
        health["recommendations"].append(
            f"Pool near exhaustion: {checked_out}/{total_connections} connections in use. "
|
||||||
|
f"Consider increasing DB_POOL_SIZE or DB_MAX_OVERFLOW."
|
||||||
|
)
|
||||||
|
elif checked_out >= total_connections * 0.75:
|
||||||
|
health["checks"]["exhaustion"] = "warning"
|
||||||
|
health["status"] = "warning"
|
||||||
|
health["recommendations"].append(
|
||||||
|
f"Pool usage high: {checked_out}/{total_connections} connections in use."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
health["checks"]["exhaustion"] = "healthy"
|
||||||
|
|
||||||
|
# Check 2: Available connections
|
||||||
|
if available == 0 and total_connections > 0:
|
||||||
|
health["checks"]["availability"] = "warning"
|
||||||
|
if health["status"] == "healthy":
|
||||||
|
health["status"] = "warning"
|
||||||
|
else:
|
||||||
|
health["checks"]["availability"] = "healthy"
|
||||||
|
|
||||||
|
# Check 3: Pool size configuration
|
||||||
|
if total_connections < 5:
|
||||||
|
health["recommendations"].append(
|
||||||
|
"Pool size may be too small for production workloads. "
|
||||||
|
"Consider DB_POOL_SIZE >= 10"
|
||||||
|
)
|
||||||
|
|
||||||
|
return health
|
||||||
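The exhaustion thresholds used above (90% of connections in use is critical, 75% is a warning) can be factored out as a pure function for unit testing. The sketch below is illustrative only; `classify_exhaustion` is not part of the module:

```python
def classify_exhaustion(checked_out: int, total: int) -> str:
    """Mirror check_pool_health()'s pool-exhaustion thresholds:
    >= 90% of connections in use -> critical, >= 75% -> warning."""
    if total <= 0:
        return "healthy"  # no pool to exhaust (e.g. SQLite with NullPool)
    if checked_out >= total * 0.9:
        return "critical"
    if checked_out >= total * 0.75:
        return "warning"
    return "healthy"


print(classify_exhaustion(27, 30))  # 27/30 = 90% -> critical
print(classify_exhaustion(23, 30))  # ~77%       -> warning
print(classify_exhaustion(10, 30))  # ~33%       -> healthy
```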
217
backend/src/config/database_async.py
Normal file
@@ -0,0 +1,217 @@
"""
Async database configuration module.

Provides SQLAlchemy 2.0 AsyncSession support for fully asynchronous database
operations. Supports PostgreSQL (asyncpg) and SQLite (aiosqlite).
"""

import logging
import os
import time
from typing import AsyncGenerator

from sqlalchemy import text
from sqlalchemy.ext.asyncio import (
    create_async_engine,
    AsyncSession,
    async_sessionmaker,
    AsyncEngine,
)
from sqlalchemy.pool import NullPool
from sqlmodel import SQLModel

from src.config.settings import DB_PATH, DATABASE_URL

logger = logging.getLogger(__name__)

# Slow query threshold in seconds
SLOW_QUERY_THRESHOLD = float(os.getenv("SLOW_QUERY_THRESHOLD", "1.0"))

# Connection pool configuration from environment variables
POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "20"))
MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "10"))
POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "3600"))
POOL_PRE_PING = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"

# Global engine instance
async_engine: AsyncEngine | None = None
async_session_maker: async_sessionmaker[AsyncSession] | None = None


def _get_async_url() -> str:
    """Generate the async database URL from settings."""
    if DATABASE_URL:
        # Convert a PostgreSQL URL to its async-driver equivalent
        if DATABASE_URL.startswith("postgresql://"):
            return DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://", 1)
        elif DATABASE_URL.startswith("postgresql+psycopg2://"):
            return DATABASE_URL.replace("postgresql+psycopg2://", "postgresql+asyncpg://", 1)
        return DATABASE_URL
    else:
        # SQLite async URL
        return f"sqlite+aiosqlite:///{DB_PATH}"


def init_async_engine() -> AsyncEngine:
    """Initialize the async database engine with optimized connection pooling."""
    global async_engine

    if async_engine is not None:
        return async_engine

    async_url = _get_async_url()

    if "postgresql" in async_url:
        # PostgreSQL async configuration
        async_engine = create_async_engine(
            async_url,
            echo=False,
            pool_size=POOL_SIZE,
            max_overflow=MAX_OVERFLOW,
            pool_timeout=POOL_TIMEOUT,
            pool_recycle=POOL_RECYCLE,
            pool_pre_ping=POOL_PRE_PING,
            connect_args={
                "timeout": 10,
                "command_timeout": 30,
            },
        )
        logger.info(
            f"Async PostgreSQL engine created with pool_size={POOL_SIZE}, "
            f"max_overflow={MAX_OVERFLOW}, pool_timeout={POOL_TIMEOUT}s"
        )
    else:
        # SQLite async configuration
        async_engine = create_async_engine(
            async_url,
            echo=False,
            poolclass=NullPool,  # SQLite doesn't support connection pooling well
            connect_args={
                "timeout": 30,
                "check_same_thread": False,
            },
        )
        logger.info(f"Async SQLite engine created at {DB_PATH}")

    return async_engine


def init_async_session_maker() -> async_sessionmaker[AsyncSession]:
    """Initialize the async session maker."""
    global async_session_maker

    if async_session_maker is not None:
        return async_session_maker

    engine = init_async_engine()

    async_session_maker = async_sessionmaker(
        engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
    )

    return async_session_maker


async def init_db_async():
    """Initialize database tables asynchronously."""
    engine = init_async_engine()

    async with engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)

    logger.info("Database tables initialized (async)")


async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Dependency for getting an async database session.

    Usage:
        @router.get("/items")
        async def get_items(session: AsyncSession = Depends(get_async_session)):
            ...
    """
    session_maker = init_async_session_maker()

    async with session_maker() as session:
        try:
            yield session
        finally:
            await session.close()


async def get_async_session_context() -> AsyncSession:
    """
    Get an async session for use as a context manager.

    Usage:
        async with get_async_session_context() as session:
            result = await session.execute(...)
    """
    session_maker = init_async_session_maker()
    return session_maker()


async def get_pool_status_async() -> dict:
    """
    Get current connection pool status for monitoring.
    Returns pool statistics including size, checked-out connections, etc.
    """
    engine = init_async_engine()
    pool = engine.pool

    status = {
        "pool_size": getattr(pool, "size", lambda: 0)(),
        "checked_out": getattr(pool, "checkedout", lambda: 0)(),
        "overflow": getattr(pool, "overflow", lambda: 0)(),
    }

    return status


async def check_database_health_async() -> tuple[bool, str]:
    """
    Check database connectivity and health asynchronously.
    Returns (is_healthy, message).
    """
    try:
        engine = init_async_engine()
        async with engine.connect() as conn:
            start_time = time.time()
            await conn.execute(text("SELECT 1"))
            elapsed = time.time() - start_time

            return True, f"Database connection healthy ({elapsed:.3f}s)"
    except Exception as e:
        logger.error(f"Database health check failed: {e}")
        return False, f"Database connection failed: {str(e)}"


async def close_async_engine():
    """Close the async engine and clean up connections."""
    global async_engine, async_session_maker

    if async_engine is not None:
        await async_engine.dispose()
        async_engine = None
        async_session_maker = None
        logger.info("Async database engine closed")


# Migration helper for the sync-to-async transition
async def migrate_sync_to_async():
    """
    Migration helper to ensure both sync and async engines are initialized.
    During the transition period, both can coexist.
    """
    # Initialize the async engine and session maker
    init_async_engine()
    init_async_session_maker()

    # The sync engine is already initialized in database.py
    logger.info("Both sync and async database engines are ready")
465
backend/src/config/generation_options.json
Normal file
@@ -0,0 +1,465 @@
{
  "image": {
    "aspectRatios": {
      "label": "比例",
      "options": [
        { "value": "1:1", "label": "1:1 (正方形)" },
        { "value": "3:4", "label": "3:4 (竖版)" },
        { "value": "4:3", "label": "4:3 (横版)" },
        { "value": "9:16", "label": "9:16 (手机竖版)" },
        { "value": "16:9", "label": "16:9 (宽屏)" }
      ],
      "default": "16:9"
    },
    "resolutions": {
      "label": "分辨率",
      "options": [
        { "value": "1K", "label": "1K (标准)", "description": "1280x720 或 1024x1024" },
        { "value": "2K", "label": "2K (高清)", "description": "2560x1440 或 2048x2048" },
        { "value": "4K", "label": "4K (超高清)", "description": "3840x2160 或 4096x4096" }
      ],
      "default": "1K"
    },
    "counts": {
      "label": "生成数量",
      "options": [
        { "value": 1, "label": "1张" },
        { "value": 2, "label": "2张" },
        { "value": 3, "label": "3张" },
        { "value": 4, "label": "4张" }
      ],
      "default": 1,
      "min": 1,
      "max": 4
    },
    "templates": {
      "label": "模板类型",
      "options": [
        { "value": "general", "label": "通用", "description": "通用图片生成" },
        { "value": "character_white_bg", "label": "角色白底图", "description": "角色图片,白色背景" },
        { "value": "character_three_view", "label": "角色三视图", "description": "角色正面、侧面、背面视图" },
        { "value": "storyboard_integrated", "label": "分镜出图", "description": "基于分镜生成图片" }
      ],
      "default": "general"
    }
  },
  "video": {
    "aspectRatios": {
      "label": "比例",
      "options": [
        { "value": "1:1", "label": "1:1 (正方形)" },
        { "value": "3:4", "label": "3:4 (竖版)" },
        { "value": "4:3", "label": "4:3 (横版)" },
        { "value": "9:16", "label": "9:16 (手机竖版)" },
        { "value": "16:9", "label": "16:9 (宽屏)" }
      ],
      "default": "16:9"
    },
    "resolutions": {
      "label": "分辨率",
      "options": [
        { "value": "720p", "label": "720P (1280x720)", "width": 1280, "height": 720 },
        { "value": "1080p", "label": "1080P (1920x1080)", "width": 1920, "height": 1080 }
      ],
      "default": "1080p"
    },
    "durations": {
      "label": "时长",
      "options": [
        { "value": "2s", "label": "2秒" },
        { "value": "3s", "label": "3秒" },
        { "value": "4s", "label": "4秒" },
        { "value": "5s", "label": "5秒" },
        { "value": "6s", "label": "6秒" },
        { "value": "7s", "label": "7秒" },
        { "value": "8s", "label": "8秒" },
        { "value": "10s", "label": "10秒" }
      ],
      "default": "5s",
      "description": "视频时长(秒)"
    },
    "counts": {
      "label": "生成数量",
      "options": [
        { "value": 1, "label": "1个" },
        { "value": 2, "label": "2个" },
        { "value": 3, "label": "3个" },
        { "value": 4, "label": "4个" }
      ],
      "default": 1,
      "min": 1,
      "max": 4
    }
  },
  "audio": {
    "durations": {
      "label": "时长",
      "options": [
        { "value": "5s", "label": "5秒" },
        { "value": "10s", "label": "10秒" },
        { "value": "30s", "label": "30秒" },
        { "value": "60s", "label": "1分钟" }
      ],
      "default": "10s"
    },
    "formats": {
      "label": "音频格式",
      "options": [
        { "value": "mp3", "label": "MP3" },
        { "value": "wav", "label": "WAV" }
      ],
      "default": "mp3"
    }
  },
  "music": {
    "formats": {
      "label": "音频格式",
      "options": [
        { "value": "mp3", "label": "MP3" },
        { "value": "wav", "label": "WAV" },
        { "value": "pcm", "label": "PCM" }
      ],
      "default": "mp3"
    },
    "sampleRates": {
      "label": "采样率",
      "options": [
        { "value": 16000, "label": "16kHz" },
        { "value": 24000, "label": "24kHz" },
        { "value": 32000, "label": "32kHz" },
        { "value": 44100, "label": "44.1kHz" }
      ],
      "default": 44100
    },
    "bitrates": {
      "label": "码率",
      "options": [
        { "value": 32000, "label": "32kbps" },
        { "value": 64000, "label": "64kbps" },
        { "value": 128000, "label": "128kbps" },
        { "value": 256000, "label": "256kbps" }
      ],
      "default": 256000
    },
    "outputFormats": {
      "label": "返回格式",
      "options": [
        { "value": "url", "label": "URL (24小时有效)" },
        { "value": "hex", "label": "HEX" }
      ],
      "default": "url"
    }
  },
  "script": {
    "movieTones": {
      "label": "整体基调",
      "options": [
        { "value": "悬疑/惊悚", "label": "悬疑/惊悚 (Suspense/Thriller)", "en": "Suspense/Thriller" },
        { "value": "古装/权谋", "label": "古装/权谋 (Historical/Political)", "en": "Historical/Political" },
        { "value": "现代/都市", "label": "现代/都市 (Modern/Urban)", "en": "Modern/Urban" },
        { "value": "科幻/未来", "label": "科幻/未来 (Sci-Fi/Future)", "en": "Sci-Fi/Future" },
        { "value": "喜剧/荒诞", "label": "喜剧/荒诞 (Comedy/Absurd)", "en": "Comedy/Absurd" },
        { "value": "动作/犯罪", "label": "动作/犯罪 (Action/Crime)", "en": "Action/Crime" },
        { "value": "爱情/治愈", "label": "爱情/治愈 (Romance/Healing)", "en": "Romance/Healing" },
        { "value": "奇幻/仙侠", "label": "奇幻/仙侠 (Fantasy/Xianxia)", "en": "Fantasy/Xianxia" },
        { "value": "现实/人文", "label": "现实/人文 (Realistic/Humanistic)", "en": "Realistic/Humanistic" }
      ],
      "default": "现代/都市"
    },
    "targetAudiences": {
      "label": "目标受众",
      "options": [
        { "value": "全年龄段", "label": "全年龄段 (All Ages)", "en": "All Ages" },
        { "value": "青少年", "label": "青少年 (Teenagers)", "en": "Teenagers" },
        { "value": "年轻女性", "label": "年轻女性 (Young Women)", "en": "Young Women" },
        { "value": "年轻男性", "label": "年轻男性 (Young Men)", "en": "Young Men" },
        { "value": "成年观众", "label": "成年观众 (Adults)", "en": "Adults" },
        { "value": "家庭观众", "label": "家庭观众 (Family)", "en": "Family" },
        { "value": "资深影迷", "label": "资深影迷 (Cinephiles)", "en": "Cinephiles" }
      ],
      "default": "全年龄段"
    }
  },
  "director": {
    "narrativeStyles": {
      "label": "叙事手法",
      "options": [
        { "value": "线性叙事", "label": "线性叙事 (Linear)", "en": "Linear Narrative" },
        { "value": "非线性/插叙", "label": "非线性/插叙 (Non-linear)", "en": "Non-linear Narrative" },
        { "value": "多线并行", "label": "多线并行 (Multi-line)", "en": "Multi-line Narrative" },
        { "value": "倒叙", "label": "倒叙 (Flashback)", "en": "Flashback" },
        { "value": "意识流", "label": "意识流 (Stream of Consciousness)", "en": "Stream of Consciousness" },
        { "value": "伪纪录片", "label": "伪纪录片 (Mockumentary)", "en": "Mockumentary" }
      ],
      "default": "线性叙事"
    },
    "editingPaces": {
      "label": "剪辑节奏",
      "options": [
        { "value": "缓慢/沉浸", "label": "缓慢/沉浸 (Slow/Immersive)", "en": "Slow/Immersive" },
        { "value": "明快/流畅", "label": "明快/流畅 (Brisk/Fluid)", "en": "Brisk/Fluid" },
        { "value": "快速/凌厉", "label": "快速/凌厉 (Fast/Sharp)", "en": "Fast/Sharp" },
        { "value": "极速/碎片化", "label": "极速/碎片化 (Rapid/Fragmented)", "en": "Rapid/Fragmented" },
        { "value": "舒缓/诗意", "label": "舒缓/诗意 (Soothing/Poetic)", "en": "Soothing/Poetic" }
      ],
      "default": "明快/流畅"
    },
    "styles": {
      "label": "导演风格",
      "options": [
        { "value": "孔笙 (Kong Sheng)", "label": "孔笙 (厚重/写实/山海情)", "en": "Kong Sheng Style" },
        { "value": "郑晓龙 (Zheng Xiaolong)", "label": "郑晓龙 (宫廷/传奇/甄嬛传)", "en": "Zheng Xiaolong Style" },
        { "value": "张黎 (Zhang Li)", "label": "张黎 (历史/权谋/大明王朝)", "en": "Zhang Li Style" },
        { "value": "辛爽 (Xin Shuang)", "label": "辛爽 (悬疑/美学/漫长的季节)", "en": "Xin Shuang Style" },
        { "value": "王家卫 (Wong Kar-wai)", "label": "王家卫 (繁花/光影/霓虹)", "en": "Wong Kar-wai Style" },
        { "value": "李路 (Li Lu)", "label": "李路 (史诗/人世间/人民的名义)", "en": "Li Lu Style" },
        { "value": "曹盾 (Cao Dun)", "label": "曹盾 (视觉/长镜头/长安十二时辰)", "en": "Cao Dun Style" },
        { "value": "王伟 (Wang Wei)", "label": "王伟 (硬汉/刑侦/白夜追凶)", "en": "Wang Wei Style" },
        { "value": "汪俊 (Wang Jun)", "label": "汪俊 (都市/细腻/小欢喜)", "en": "Wang Jun Style" },
        { "value": "徐纪周 (Xu Jizhou)", "label": "徐纪周 (群像/狂飙/快节奏)", "en": "Xu Jizhou Style" },
        { "value": "吕行 (Lu Xing)", "label": "吕行 (犯罪/人性/无证之罪)", "en": "Lu Xing Style" },
        { "value": "丁黑 (Ding Hei)", "label": "丁黑 (警察荣誉/那年花开)", "en": "Ding Hei Style" }
      ],
      "default": "孔笙 (Kong Sheng)"
    }
  },
  "character": {
    "genders": {
      "label": "性别",
      "options": [
        { "value": "男", "label": "男" },
        { "value": "女", "label": "女" },
        { "value": "未知", "label": "未知" }
      ],
      "default": "未知"
    },
    "costumeStyles": {
      "label": "服装风格",
      "options": [
        { "value": "现代日常", "label": "现代日常 (Modern Daily)", "en": "Modern Daily" },
        { "value": "古装汉服", "label": "古装汉服 (Ancient Hanfu)", "en": "Ancient Hanfu" },
        { "value": "民国风情", "label": "民国风情 (Republic Era)", "en": "Republic Era" },
        { "value": "赛博科幻", "label": "赛博科幻 (Cyberpunk Sci-Fi)", "en": "Cyberpunk Sci-Fi" },
        { "value": "职业制服", "label": "职业制服 (Professional Uniform)", "en": "Professional Uniform" },
        { "value": "街头潮流", "label": "街头潮流 (Streetwear)", "en": "Streetwear" },
        { "value": "极简森系", "label": "极简森系 (Minimalist Mori)", "en": "Minimalist Mori" },
        { "value": "奢华礼服", "label": "奢华礼服 (Luxury/Formal)", "en": "Luxury/Formal" }
      ],
      "default": "现代日常"
    },
    "roles": {
      "label": "角色定位",
      "options": [
        { "value": "主角", "label": "主角" },
        { "value": "配角", "label": "配角" },
        { "value": "反派", "label": "反派" },
        { "value": "龙套", "label": "龙套" },
        { "value": "群演", "label": "群演" }
      ],
      "default": "配角"
    },
    "emotions": {
      "label": "情绪基调",
      "options": [
        { "value": "平静", "label": "平静 (Neutral)", "en": "Neutral/Calm" },
        { "value": "喜悦", "label": "喜悦 (Happy)", "en": "Happy/Joyful" },
        { "value": "悲伤", "label": "悲伤 (Sad)", "en": "Sad/Sorrowful" },
        { "value": "愤怒", "label": "愤怒 (Angry)", "en": "Angry/Furious" },
        { "value": "恐惧", "label": "恐惧 (Fearful)", "en": "Fearful/Scared" },
        { "value": "惊讶", "label": "惊讶 (Surprised)", "en": "Surprised/Shocked" },
        { "value": "自信", "label": "自信 (Confident)", "en": "Confident/Bold" },
        { "value": "思索", "label": "思索 (Thinking)", "en": "Thinking/Pensive" }
      ],
      "default": "平静"
    }
  },
  "storyboard": {
    "shotTypes": {
      "label": "镜头类型",
      "options": [
        { "value": "大远景 (ELS)", "label": "大远景 (ELS)", "en": "Extreme Long Shot (ELS)" },
        { "value": "远景 (LS)", "label": "远景 (LS)", "en": "Long Shot (LS)" },
        { "value": "全景 (FS)", "label": "全景 (FS)", "en": "Full Shot (FS)" },
        { "value": "中远景 (MLS)", "label": "中远景 (MLS)", "en": "Medium Long Shot (MLS)" },
        { "value": "中景 (MS)", "label": "中景 (MS)", "en": "Medium Shot (MS)" },
        { "value": "中特写 (MCU)", "label": "中特写 (MCU)", "en": "Medium Close-Up (MCU)" },
        { "value": "特写 (CU)", "label": "特写 (CU)", "en": "Close-Up (CU)" },
        { "value": "大特写 (ECU)", "label": "大特写 (ECU)", "en": "Extreme Close-Up (ECU)" },
        { "value": "建立镜头", "label": "建立镜头", "en": "Establishing Shot" },
        { "value": "主观镜头 (POV)", "label": "主观镜头 (POV)", "en": "Point of View (POV)" },
        { "value": "过肩镜头 (OTS)", "label": "过肩镜头 (OTS)", "en": "Over the Shoulder (OTS)" }
      ],
      "default": "中景 (MS)"
    },
    "cameraMovements": {
      "label": "运镜方式",
      "options": [
        { "value": "固定镜头 (Static)", "label": "固定镜头 (Static)", "en": "Static" },
        { "value": "左摇 (Pan Left)", "label": "左摇 (Pan Left)", "en": "Pan Left" },
        { "value": "右摇 (Pan Right)", "label": "右摇 (Pan Right)", "en": "Pan Right" },
        { "value": "上仰 (Tilt Up)", "label": "上仰 (Tilt Up)", "en": "Tilt Up" },
        { "value": "下俯 (Tilt Down)", "label": "下俯 (Tilt Down)", "en": "Tilt Down" },
        { "value": "推镜头 (Zoom In)", "label": "推镜头 (Zoom In)", "en": "Zoom In" },
        { "value": "拉镜头 (Zoom Out)", "label": "拉镜头 (Zoom Out)", "en": "Zoom Out" },
        { "value": "前移 (Dolly In)", "label": "前移 (Dolly In)", "en": "Dolly In" },
        { "value": "后移 (Dolly Out)", "label": "后移 (Dolly Out)", "en": "Dolly Out" },
        { "value": "跟随 (Tracking)", "label": "跟随 (Tracking)", "en": "Tracking" },
        { "value": "环绕 (Arc)", "label": "环绕 (Arc)", "en": "Arc" },
        { "value": "手持 (Handheld)", "label": "手持 (Handheld)", "en": "Handheld" }
      ],
      "default": "固定镜头 (Static)"
    },
    "transitions": {
      "label": "转场效果",
      "options": [
        { "value": "切 (Cut)", "label": "切 (Cut)", "en": "Cut" },
        { "value": "叠化 (Dissolve)", "label": "叠化 (Dissolve)", "en": "Dissolve" },
        { "value": "淡入 (Fade In)", "label": "淡入 (Fade In)", "en": "Fade In" },
        { "value": "淡出 (Fade Out)", "label": "淡出 (Fade Out)", "en": "Fade Out" },
        { "value": "划像 (Wipe)", "label": "划像 (Wipe)", "en": "Wipe" },
        { "value": "圈入 (Iris In)", "label": "圈入 (Iris In)", "en": "Iris In" },
        { "value": "圈出 (Iris Out)", "label": "圈出 (Iris Out)", "en": "Iris Out" },
        { "value": "匹配剪辑 (Match Cut)", "label": "匹配剪辑 (Match Cut)", "en": "Match Cut" },
        { "value": "跳接 (Jump Cut)", "label": "跳接 (Jump Cut)", "en": "Jump Cut" }
      ],
      "default": "切 (Cut)"
    },
    "compositions": {
      "label": "构图方式",
      "options": [
        { "value": "三分法 (Rule of Thirds)", "label": "三分法 (Rule of Thirds)", "en": "Rule of Thirds" },
        { "value": "中心构图 (Center Framed)", "label": "中心构图 (Center Framed)", "en": "Center Framed" },
        { "value": "对称构图 (Symmetrical)", "label": "对称构图 (Symmetrical)", "en": "Symmetrical" },
        { "value": "引导线 (Leading Lines)", "label": "引导线 (Leading Lines)", "en": "Leading Lines" },
        { "value": "对角线 (Diagonal)", "label": "对角线 (Diagonal)", "en": "Diagonal" },
        { "value": "框架构图 (Framing)", "label": "框架构图 (Framing)", "en": "Framing" },
        { "value": "极简留白 (Minimalist/Negative Space)", "label": "极简留白 (Minimalist/Negative Space)", "en": "Minimalist/Negative Space" },
        { "value": "黄金螺旋 (Golden Spiral)", "label": "黄金螺旋 (Golden Spiral)", "en": "Golden Spiral" }
      ],
      "default": "三分法 (Rule of Thirds)"
    }
  },
  "scene": {
    "timesOfDay": {
      "label": "拍摄时段",
      "options": [
        { "value": "清晨", "label": "清晨 (Dawn)", "en": "Dawn" },
        { "value": "早晨", "label": "早晨 (Morning)", "en": "Morning" },
        { "value": "正午", "label": "正午 (Noon)", "en": "Noon" },
        { "value": "下午", "label": "下午 (Afternoon)", "en": "Afternoon" },
        { "value": "黄金时刻", "label": "黄金时刻 (Golden Hour)", "en": "Golden Hour" },
        { "value": "傍晚", "label": "傍晚 (Dusk)", "en": "Dusk" },
        { "value": "蓝调时刻", "label": "蓝调时刻 (Blue Hour)", "en": "Blue Hour" },
        { "value": "夜晚", "label": "夜晚 (Night)", "en": "Night" },
        { "value": "深夜", "label": "深夜 (Late Night)", "en": "Late Night" }
      ],
      "default": "早晨"
    },
    "environmentTypes": {
      "label": "空间类型",
      "options": [
        { "value": "室内", "label": "室内 (Interior)", "en": "Interior" },
        { "value": "室外", "label": "室外 (Exterior)", "en": "Exterior" },
        { "value": "半室外", "label": "半室外 (Semi-Exterior)", "en": "Semi-Exterior" }
      ],
      "default": "室内"
    },
    "weather": {
      "label": "天气环境",
      "options": [
        { "value": "晴朗", "label": "晴朗 (Clear)", "en": "Clear/Sunny" },
        { "value": "多云", "label": "多云 (Cloudy)", "en": "Partly Cloudy" },
        { "value": "阴天", "label": "阴天 (Overcast)", "en": "Overcast" },
        { "value": "小雨", "label": "小雨 (Rainy)", "en": "Light Rain" },
        { "value": "大雨", "label": "大雨 (Heavy Rain)", "en": "Heavy Rain" },
        { "value": "暴风雨", "label": "暴风雨 (Storm)", "en": "Storm" },
        { "value": "小雪", "label": "小雪 (Snowy)", "en": "Light Snow" },
        { "value": "大雪", "label": "大雪 (Heavy Snow)", "en": "Heavy Snow" },
        { "value": "大雾", "label": "大雾 (Foggy)", "en": "Dense Fog" },
        { "value": "沙尘", "label": "沙尘 (Dusty)", "en": "Dusty/Sandstorm" }
      ],
      "default": "晴朗"
    }
  },
  "cinematic": {
    "visualStyles": {
      "label": "视觉画风",
      "options": [
        { "value": "现实主义/纪录片感", "label": "现实主义/纪录片感 (Realistic/Documentary)", "en": "Realistic/Documentary" },
        { "value": "电影质感/胶片风", "label": "电影质感/胶片风 (Cinematic/Film)", "en": "Cinematic/Film" },
        { "value": "赛博朋克/霓虹", "label": "赛博朋克/霓虹 (Cyberpunk/Neon)", "en": "Cyberpunk/Neon" },
        { "value": "古风/水墨", "label": "古风/水墨 (Ancient/Ink Wash)", "en": "Ancient/Ink Wash" },
        { "value": "极简主义/冷淡", "label": "极简主义/冷淡 (Minimalist/Cold)", "en": "Minimalist/Cold" },
        { "value": "唯美/梦幻", "label": "唯美/梦幻 (Aesthetic/Dreamy)", "en": "Aesthetic/Dreamy" },
        { "value": "暗黑/哥特", "label": "暗黑/哥特 (Dark/Gothic)", "en": "Dark/Gothic" },
        { "value": "动画/二次元", "label": "动画/二次元 (Anime/2D)", "en": "Anime/2D" }
      ],
      "default": "电影质感/胶片风"
    },
    "cameraAngles": {
      "label": "镜头角度",
      "options": [
        { "value": "平视", "label": "平视 (Eye Level)", "en": "Eye Level" },
        { "value": "俯拍", "label": "俯拍 (High Angle)", "en": "High Angle" },
        { "value": "仰拍", "label": "仰拍 (Low Angle)", "en": "Low Angle" },
        { "value": "顶拍", "label": "顶拍 (Top Down)", "en": "Top Down/Bird's Eye" },
        { "value": "底拍", "label": "底拍 (Worm's Eye)", "en": "Worm's Eye View" },
        { "value": "斜角镜头", "label": "斜角镜头 (Dutch Angle)", "en": "Dutch Angle" }
      ],
      "default": "平视"
    },
    "lighting": {
      "label": "灯光风格",
      "options": [
        { "value": "电影感光效", "label": "电影感 (Cinematic)", "en": "Cinematic Lighting" },
        { "value": "自然光", "label": "自然光 (Natural)", "en": "Natural Lighting" },
        { "value": "柔光", "label": "柔光 (Soft)", "en": "Soft Lighting" },
        { "value": "强反差/明暗对照", "label": "强反差 (High Contrast)", "en": "Chiaroscuro/High Contrast" },
        { "value": "轮廓光", "label": "轮廓光 (Rim Lighting)", "en": "Rim Lighting" },
        { "value": "三点式亮光", "label": "三点式亮光 (Three-point Lighting)", "en": "Three-point Lighting" },
        { "value": "工作室光效", "label": "工作室 (Studio)", "en": "Studio Lighting" }
      ],
      "default": "电影感光效"
    },
    "colorStyle": {
      "label": "色调氛围",
      "options": [
        { "value": "电影感", "label": "电影感 (Cinematic)", "en": "Cinematic" },
        { "value": "暖色调", "label": "暖色调 (Warm Tones)", "en": "Warm Tones" },
        { "value": "冷色调", "label": "冷色调 (Cold Tones)", "en": "Cold Tones" },
        { "value": "黑白", "label": "黑白 (Black and White)", "en": "Black and White" },
        { "value": "复古/胶片感", "label": "复古/胶片 (Vintage)", "en": "Vintage Film" },
        { "value": "赛博朋克", "label": "赛博朋克 (Cyberpunk)", "en": "Cyberpunk" },
        { "value": "高饱和", "label": "高饱和 (Vibrant)", "en": "Vibrant" },
        { "value": "低饱和/灰色调", "label": "低饱和 (Muted)", "en": "Muted/Desaturated" }
      ],
      "default": "电影感"
    },
    "lenses": {
      "label": "镜头焦距",
      "options": [
        { "value": "广角镜头", "label": "超广角 (14-24mm)", "en": "Ultra Wide Lens" },
        { "value": "标准广角", "label": "广角 (35mm)", "en": "Wide Angle Lens" },
        { "value": "标准镜头", "label": "标准 (50mm)", "en": "Standard/Nifty Fifty" },
        { "value": "人像镜头", "label": "长焦 (85mm)", "en": "Portrait/Short Telephoto" },
        { "value": "远摄镜头", "label": "超长焦 (200mm+)", "en": "Telephoto Lens" },
        { "value": "鱼眼镜头", "label": "鱼眼 (Fisheye)", "en": "Fisheye Lens" },
        { "value": "变焦镜头", "label": "变焦 (Anamorphic)", "en": "Anamorphic Lens" }
      ],
      "default": "标准镜头"
    },
    "focus": {
      "label": "焦点控制",
      "options": [
        { "value": "自动对焦", "label": "自动对焦 (Auto)", "en": "Auto Focus" },
        { "value": "浅景深/虚化", "label": "浅景深 (Shallow Bokeh)", "en": "Shallow Depth of Field" },
        { "value": "大深景", "label": "大深景 (Deep Focus)", "en": "Deep Depth of Field" },
        { "value": "焦点转移", "label": "变焦对焦 (Rack Focus)", "en": "Rack Focus" },
        { "value": "微距焦点", "label": "微距 (Macro)", "en": "Macro Focus" },
        { "value": "移轴效果", "label": "移轴 (Tilt-Shift)", "en": "Tilt-Shift" }
      ],
      "default": "自动对焦"
    }
  }
}
16
backend/src/config/script_agent_config.json
Normal file
@@ -0,0 +1,16 @@
{
  "roles": {
    "story_architect": "moonshot-v1-8k",
    "character_consultant": "moonshot-v1-8k",
    "scriptwriter": "moonshot-v1-8k",
    "director": "moonshot-v1-8k",
    "auditor": "moonshot-v1-8k",
    "moderator": "moonshot-v1-8k",
    "psychologist": "moonshot-v1-8k",
    "visualizer": "moonshot-v1-8k",
    "continuity_manager": "moonshot-v1-8k",
    "showrunner": "moonshot-v1-8k",
    "chief_editor": "moonshot-v1-8k",
    "specialist": "moonshot-v1-8k"
  }
}
5
backend/src/config/services/alibaba/provider.json
Normal file
@@ -0,0 +1,5 @@
{
  "id": "aliyun",
  "name": "阿里云",
  "description": "阿里云提供的包括图像超分、视频超分等服务"
}
8
backend/src/config/services/alibaba/upscale.json
Normal file
@@ -0,0 +1,8 @@
{
  "videoenhan": {
    "name": "阿里巴巴图像超分",
    "class": "backend.src.services.post_process.super_resolution.SuperResolutionService",
    "args": [],
    "enabled": true
  }
}
16
backend/src/config/services/anthropic/provider.json
Normal file
@@ -0,0 +1,16 @@
{
  "id": "anthropic",
  "name": "Anthropic",
  "description": "Claude 系列模型",
  "dashboard_url": "https://console.anthropic.com/settings/keys",
  "helpUrl": "https://console.anthropic.com/settings/keys",
  "fields": [
    {
      "name": "apiKey",
      "label": "API Key",
      "placeholder": "sk-ant-...",
      "required": true,
      "type": "password"
    }
  ]
}
488
backend/src/config/services/dashscope/audio.json
Normal file
@@ -0,0 +1,488 @@
{
  "cosyvoice-v3-plus": {
    "name": "CosyVoice-V3-Plus",
    "class": "backend.src.services.provider.dashscope.audio.DashScopeAudioService",
    "args": ["cosyvoice-v3-plus"],
    "voices": [
      {
        "id": "longanyang",
        "name": "龙安洋",
        "gender": "male",
        "desc": "阳光大男孩"
      },
      {
        "id": "longanhuan",
        "name": "龙安欢",
        "gender": "female",
        "desc": "欢脱元气女"
      },
      {
        "id": "longhuhu_v3",
        "name": "龙呼呼",
        "gender": "female",
        "desc": "天真烂漫女童"
      },
      {
        "id": "longpaopao_v3",
        "name": "龙泡泡",
        "gender": "female",
        "desc": "飞天泡泡音"
      },
      {
        "id": "longjielidou_v3",
        "name": "龙杰力豆",
        "gender": "male",
        "desc": "阳光顽皮男童"
      },
      {
        "id": "longxian_v3",
        "name": "龙仙",
        "gender": "female",
        "desc": "豪放可爱女童"
      },
      {
        "id": "longling_v3",
        "name": "龙铃",
        "gender": "female",
        "desc": "稚气呆板女童"
      },
      {
        "id": "longshanshan_v3",
        "name": "龙闪闪",
        "gender": "female",
        "desc": "戏剧化童声"
      },
      {
        "id": "longniuniu_v3",
        "name": "龙牛牛",
        "gender": "male",
        "desc": "阳光男童声"
      },
      {
        "id": "longjiaxin_v3",
        "name": "龙嘉欣",
        "gender": "female",
        "desc": "优雅粤语女"
      },
      {
        "id": "longjiayi_v3",
        "name": "龙嘉怡",
        "gender": "female",
        "desc": "知性粤语女"
      },
      {
        "id": "longanyue_v3",
        "name": "龙安粤",
        "gender": "male",
        "desc": "欢脱粤语男"
      },
      {
        "id": "longlaotie_v3",
        "name": "龙老铁",
        "gender": "male",
        "desc": "东北直率男"
      },
      {
        "id": "longshange_v3",
        "name": "龙陕哥",
        "gender": "male",
        "desc": "原味陕北男"
      },
      {
        "id": "longanmin_v3",
        "name": "龙安闽",
        "gender": "female",
        "desc": "清纯闽南女"
      },
      {
        "id": "loongkyong_v3",
        "name": "loongkyong",
        "gender": "female",
        "desc": "韩语女"
      },
      {
        "id": "loongriko_v3",
        "name": "Riko",
        "gender": "female",
        "desc": "二次元日语女"
      },
      {
        "id": "loongtomoka_v3",
        "name": "loongtomoka",
        "gender": "female",
        "desc": "日语女"
      },
      {
        "id": "longfei_v3",
        "name": "龙飞",
        "gender": "male",
        "desc": "热血磁性男"
      },
      {
        "id": "longxiaochun_v3",
        "name": "龙小淳",
        "gender": "female",
        "desc": "清丽温柔女"
      },
      {
        "id": "longxiaoxia_v3",
        "name": "龙小夏",
        "gender": "female",
        "desc": "活泼甜美女"
      },
      {
        "id": "longshu_v3",
        "name": "龙舒",
        "gender": "female",
        "desc": "知性温婉女"
      },
      {
        "id": "longyue_v3",
        "name": "龙悦",
        "gender": "male",
        "desc": "阳光青年男"
      },
      {
        "id": "longcheng_v3",
        "name": "龙城",
        "gender": "male",
        "desc": "成熟稳重男"
      },
      {
        "id": "longhua_v3",
        "name": "龙华",
        "gender": "male",
        "desc": "标准男声"
      },
      {
        "id": "longwan_v3",
        "name": "龙婉",
        "gender": "female",
        "desc": "温婉知性女"
      },
      {
        "id": "longjing_v3",
        "name": "龙静",
        "gender": "female",
        "desc": "标准女声"
      },
      {
        "id": "longmiao_v3",
        "name": "龙淼",
        "gender": "female",
        "desc": "标准女声"
      },
      {
        "id": "longshuo_v3",
        "name": "龙硕",
        "gender": "male",
        "desc": "标准男声"
      },
      {
        "id": "longxiang_v3",
        "name": "龙翔",
        "gender": "male",
        "desc": "标准男声"
      },
      {
        "id": "longyuan_v3",
        "name": "龙源",
        "gender": "male",
        "desc": "标准男声"
      }
    ],
    "enabled": true
  },
  "qwen3-tts-flash": {
    "name": "Qwen3-TTS-Flash",
    "class": "backend.src.services.provider.dashscope.audio.DashScopeAudioService",
    "args": ["qwen3-tts-flash"],
    "voices": [
      {
        "id": "Cherry",
        "name": "芊悦",
        "gender": "female",
        "desc": "阳光积极、亲切自然小姐姐"
      },
      {
        "id": "Serena",
        "name": "苏瑶",
        "gender": "female",
        "desc": "温柔小姐姐"
      },
      {
        "id": "Ethan",
        "name": "晨煦",
        "gender": "male",
        "desc": "阳光、温暖、活力、朝气"
      },
      {
        "id": "Chelsie",
        "name": "千雪",
        "gender": "female",
        "desc": "二次元虚拟女友"
      },
      {
        "id": "Momo",
        "name": "茉兔",
        "gender": "female",
        "desc": "撒娇搞怪,逗你开心"
      },
      {
        "id": "Vivian",
        "name": "十三",
        "gender": "female",
        "desc": "拽拽的、可爱的小暴躁"
      },
      { "id": "Moon", "name": "月白", "gender": "male", "desc": "率性帅气" },
      {
        "id": "Maia",
        "name": "四月",
        "gender": "female",
        "desc": "知性与温柔的碰撞"
      },
      { "id": "Kai", "name": "凯", "gender": "male", "desc": "耳朵的一场SPA" },
      {
        "id": "Nofish",
        "name": "不吃鱼",
        "gender": "male",
        "desc": "不会翘舌音的设计师"
      },
      {
        "id": "Bella",
        "name": "萌宝",
        "gender": "female",
        "desc": "喝酒不打醉拳的小萝莉"
      },
      {
        "id": "Jennifer",
        "name": "詹妮弗",
        "gender": "female",
        "desc": "品牌级、电影质感般美语女声"
      },
      {
        "id": "Ryan",
        "name": "甜茶",
        "gender": "male",
        "desc": "节奏拉满,戏感炸裂"
      },
      {
        "id": "Katerina",
        "name": "卡捷琳娜",
        "gender": "female",
        "desc": "御姐音色,韵律回味十足"
      },
      {
        "id": "Aiden",
        "name": "艾登",
        "gender": "male",
        "desc": "精通厨艺的美语大男孩"
      },
      {
        "id": "Eldric Sage",
        "name": "沧明子",
        "gender": "male",
        "desc": "沉稳睿智的老者"
      },
      {
        "id": "Mia",
        "name": "乖小妹",
        "gender": "female",
        "desc": "温顺如春水,乖巧如初雪"
      },
      {
        "id": "Mochi",
        "name": "沙小弥",
        "gender": "male",
        "desc": "聪明伶俐的小大人"
      },
      {
        "id": "Bellona",
        "name": "燕铮莺",
        "gender": "female",
        "desc": "金戈铁马,千面人声"
      },
      {
        "id": "Vincent",
        "name": "田叔",
        "gender": "male",
        "desc": "沙哑烟嗓,江湖豪情"
      },
      {
        "id": "Bunny",
        "name": "萌小姬",
        "gender": "female",
        "desc": "萌属性爆棚的小萝莉"
      },
      {
        "id": "Neil",
        "name": "阿闻",
        "gender": "male",
        "desc": "字正腔圆的新闻主持人"
      },
      {
        "id": "Elias",
        "name": "墨讲师",
        "gender": "female",
        "desc": "严谨又通俗的知识讲解"
      },
      {
        "id": "Arthur",
        "name": "徐大爷",
        "gender": "male",
        "desc": "质朴嗓音,奇闻异事"
      },
      {
        "id": "Nini",
        "name": "邻家妹妹",
        "gender": "female",
        "desc": "又软又黏的甜蜜嗓音"
      },
      {
        "id": "Ebona",
        "name": "诡婆婆",
        "gender": "female",
        "desc": "幽暗低语,神秘诡异"
      },
      {
        "id": "Seren",
        "name": "小婉",
        "gender": "female",
        "desc": "温和舒缓,助眠音色"
      },
      {
        "id": "Pip",
        "name": "顽屁小孩",
        "gender": "male",
        "desc": "调皮捣蛋却充满童真"
      },
      {
        "id": "Stella",
        "name": "少女阿月",
        "gender": "female",
        "desc": "甜到发腻的迷糊少女音"
      },
      {
        "id": "Bodega",
        "name": "博德加",
        "gender": "male",
        "desc": "热情的西班牙大叔"
      },
      {
        "id": "Sonrisa",
        "name": "索尼莎",
        "gender": "female",
        "desc": "热情开朗的拉美大姐"
      },
      {
        "id": "Alek",
        "name": "阿列克",
        "gender": "male",
        "desc": "战斗民族的冷与暖"
      },
      {
        "id": "Dolce",
        "name": "多尔切",
        "gender": "male",
        "desc": "慵懒的意大利大叔"
      },
      {
        "id": "Sohee",
        "name": "素熙",
        "gender": "female",
        "desc": "温柔开朗的韩国欧尼"
      },
      {
        "id": "Ono Anna",
        "name": "小野杏",
        "gender": "female",
        "desc": "鬼灵精怪的青梅竹马"
      },
      {
        "id": "Lenn",
        "name": "莱恩",
        "gender": "male",
        "desc": "理性底色的德国青年"
      },
      {
        "id": "Emilien",
        "name": "埃米尔安",
        "gender": "male",
        "desc": "浪漫的法国大哥哥"
      },
      {
        "id": "Andre",
        "name": "安德雷",
        "gender": "male",
        "desc": "声音磁性,沉稳男生"
      },
      {
        "id": "Radio Gol",
        "name": "拉迪奥·戈尔",
        "gender": "male",
        "desc": "足球诗人解说员"
      },
      {
        "id": "Jada",
        "name": "上海-阿珍",
        "gender": "female",
        "desc": "风风火火的沪上阿姐"
      },
      {
        "id": "Dylan",
        "name": "北京-晓东",
        "gender": "male",
        "desc": "北京胡同里长大的少年"
      },
      {
        "id": "Li",
        "name": "南京-老李",
        "gender": "male",
        "desc": "耐心的瑜伽老师"
      },
      {
        "id": "Marcus",
        "name": "陕西-秦川",
        "gender": "male",
        "desc": "面宽话短,心实声沉"
      },
      {
        "id": "Roy",
        "name": "闽南-阿杰",
        "gender": "male",
        "desc": "诙谐直爽的台湾哥仔"
      },
      {
        "id": "Peter",
        "name": "天津-李彼得",
        "gender": "male",
        "desc": "天津相声,专业捧哏"
      },
      {
        "id": "Sunny",
        "name": "四川-晴儿",
        "gender": "female",
        "desc": "甜到你心里的川妹子"
      },
      {
        "id": "Eric",
        "name": "四川-程川",
        "gender": "male",
        "desc": "跳脱市井的成都男子"
      },
      {
        "id": "Rocky",
        "name": "粤语-阿强",
        "gender": "male",
        "desc": "幽默风趣,在线陪聊"
      },
      {
        "id": "Kiki",
        "name": "粤语-阿清",
        "gender": "female",
        "desc": "甜美的港妹闺蜜"
      }
    ],
    "enabled": true
  }
}
143
backend/src/config/services/dashscope/image.json
Normal file
@@ -0,0 +1,143 @@
{
  "z-image": {
    "name": "Z-Image",
    "class": "backend.src.services.provider.dashscope.image.ZImageService",
    "args": [
      "z-image"
    ],
    "capabilities": {
      "supportsRefImage": false,
      "supportsLora": false
    },
    "variants": {
      "t2i": "z-image-turbo"
    },
    "resolutions": {
      "1K": {
        "16:9": "1536*864",
        "9:16": "864*1536",
        "1:1": "1280*1280",
        "4:3": "1280*960",
        "3:4": "960*1280",
        "21:9": "1680*720",
        "9:21": "720*1680"
      },
      "2K": {
        "16:9": "2048*1152",
        "9:16": "1152*2048",
        "21:9": "2016*864",
        "9:21": "864*2016"
      }
    },
    "counts": {
      "min": 1,
      "max": 4
    },
    "enabled": true
  },
  "wan2.6-image": {
    "name": "Wan 2.6",
    "class": "backend.src.services.provider.dashscope.image.WanImageService",
    "args": [
      "wan2.6-image"
    ],
    "capabilities": {
      "supportsRefImage": true,
      "supportsLora": false
    },
    "variants": {
      "t2i": "wan2.6-t2i",
      "i2i": "wan2.6-image"
    },
    "resolutions": {
      "1K": {
        "16:9": "1280*720",
        "9:16": "720*1280",
        "1:1": "1280*1280",
        "4:3": "1280*960",
        "3:4": "960*1280"
      },
      "2K": {
        "16:9": "2560*1440",
        "9:16": "1440*2560",
        "1:1": "2048*2048",
        "4:3": "2560*1920",
        "3:4": "1920*2560"
      }
    },
    "counts": {
      "min": 1,
      "max": 4
    },
    "enabled": true
  },
  "wan2.5-image": {
    "name": "Wan 2.5",
    "class": "backend.src.services.provider.dashscope.image.WanImageService",
    "args": [
      "wan2.5-image"
    ],
    "capabilities": {
      "supportsRefImage": true,
      "supportsLora": false
    },
    "variants": {
      "t2i": "wan2.5-t2i-preview",
      "i2i": "wan2.5-i2i-preview"
    },
    "resolutions": {
      "1K": {
        "16:9": "1280*720",
        "9:16": "720*1280",
        "1:1": "1280*1280",
        "4:3": "1280*960",
        "3:4": "960*1280"
      },
      "2K": {
        "16:9": "2560*1440",
        "9:16": "1440*2560",
        "1:1": "2048*2048",
        "4:3": "2560*1920",
        "3:4": "1920*2560"
      }
    },
    "counts": {
      "min": 1,
      "max": 4
    },
    "enabled": true
  },
  "qwen-image": {
    "name": "Qwen Image",
    "class": "backend.src.services.provider.dashscope.image.QwenImageService",
    "args": [
      "qwen-image"
    ],
    "capabilities": {
      "supportsRefImage": true,
      "supportsLora": false
    },
    "variants": {
      "t2i": "qwen-image-plus",
      "i2i": "qwen-image-edit-plus"
    },
    "resolutions": {
      "1K": {
        "16:9": "1664*928",
        "9:16": "928*1664",
        "1:1": "1328*1328",
        "4:3": "1472*1140",
        "3:4": "1140*1472"
      },
      "2K": {
        "16:9": "2560*1440",
        "9:16": "1440*2560",
        "1:1": "2048*2048",
        "4:3": "2560*1920",
        "3:4": "1920*2560"
      }
    },
    "counts": {"min": 1, "max": 4},
    "enabled": true
  }
}
39
backend/src/config/services/dashscope/llm.json
Normal file
@@ -0,0 +1,39 @@
{
  "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
  "qwen-plus": {
    "name": "Qwen Plus",
    "class": "backend.src.services.provider.openai_service.OpenAIService",
    "args": ["qwen-plus"],
    "enabled": true
  },
  "qwen3-max": {
    "name": "Qwen Max",
    "class": "backend.src.services.provider.openai_service.OpenAIService",
    "args": ["qwen3-max"],
    "enabled": true
  },
  "qwen-flash": {
    "name": "Qwen Flash",
    "class": "backend.src.services.provider.openai_service.OpenAIService",
    "args": ["qwen-flash"],
    "enabled": true
  },
  "deepseek": {
    "name": "Deepseek",
    "class": "backend.src.services.provider.openai_service.OpenAIService",
    "args": ["deepseek-v3.2"],
    "enabled": true
  },
  "kimi-k2.5": {
    "name": "Kimi K2.5",
    "class": "backend.src.services.provider.openai_service.OpenAIService",
    "args": ["kimi-k2.5"],
    "enabled": true
  },
  "MiniMax-M2.1": {
    "name": "MiniMax M2.1",
    "class": "backend.src.services.provider.openai_service.OpenAIService",
    "args": ["MiniMax-M2.1"],
    "enabled": true
  }
}
16
backend/src/config/services/dashscope/provider.json
Normal file
@@ -0,0 +1,16 @@
{
  "id": "dashscope",
  "name": "阿里云百炼",
  "description": "阿里云提供的大模型服务",
  "dashboard_url": "https://dashscope.console.aliyun.com/",
  "helpUrl": "https://dashscope.console.aliyun.com/apiKey",
  "fields": [
    {
      "name": "apiKey",
      "label": "API Key",
      "placeholder": "sk-...",
      "required": true,
      "type": "password"
    }
  ]
}
196
backend/src/config/services/dashscope/video.json
Normal file
@@ -0,0 +1,196 @@
{
  "wan2.6-video": {
    "name": "Wan 2.6",
    "class": "backend.src.services.provider.dashscope.video.WanVideoService",
    "args": [
      "wan2.6-video"
    ],
    "capabilities": {
      "supportsTextToVideo": true,
      "supportsImageToVideo": true,
      "supportsVideoToVideo": true,
      "supportsFirstFrame": true,
      "supportsLastFrame": false,
      "supportsMultiImage": true,
      "supportsMultiVideo": true,
      "supportsAudio": true,
      "supportsShotType": true,
      "supportsNegativePrompt": true,
      "supportsLora": false
    },
    "variants": {
      "t2v": "wan2.6-t2v",
      "i2v": "wan2.6-i2v",
      "r2v": "wan2.6-r2v"
    },
    "resolutions": {
      "720P": {
        "16:9": "1280*720",
        "9:16": "720*1280",
        "1:1": "1280*1280",
        "4:3": "1280*960",
        "3:4": "960*1280"
      },
      "1080P": {
        "16:9": "1920*1080",
        "9:16": "1080*1920",
        "1:1": "1920*1920",
        "4:3": "1920*1440",
        "3:4": "1440*1920"
      }
    },
    "durations": {
      "min": 2,
      "max": 10
    },
    "counts": {
      "min": 1,
      "max": 4
    },
    "enabled": true
  },
  "wan2.6-video-flash": {
    "name": "Wan 2.6 Flash",
    "class": "backend.src.services.provider.dashscope.video.WanVideoService",
    "args": [
      "wan2.6-video-flash"
    ],
    "capabilities": {
      "supportsTextToVideo": true,
      "supportsImageToVideo": true,
      "supportsVideoToVideo": true,
      "supportsFirstFrame": true,
      "supportsLastFrame": false,
      "supportsMultiImage": true,
      "supportsMultiVideo": true,
      "supportsAudio": true,
      "supportsShotType": true,
      "supportsNegativePrompt": true,
      "supportsLora": false
    },
    "variants": {
      "t2v": "wan2.6-t2v",
      "i2v": "wan2.6-i2v-flash",
      "r2v": "wan2.6-r2v-flash"
    },
    "resolutions": {
      "720P": {
        "16:9": "1280*720",
        "9:16": "720*1280",
        "1:1": "1280*1280",
        "4:3": "1280*960",
        "3:4": "960*1280"
      },
      "1080P": {
        "16:9": "1920*1080",
        "9:16": "1080*1920",
        "1:1": "1920*1920",
        "4:3": "1920*1440",
        "3:4": "1440*1920"
      }
    },
    "durations": {
      "min": 2,
      "max": 10
    },
    "counts": {
      "min": 1,
      "max": 4
    },
    "enabled": true
  },
  "wan2.5-video": {
    "name": "Wan 2.5",
    "class": "backend.src.services.provider.dashscope.video.WanVideoService",
    "args": [
      "wan2.5-video"
    ],
    "capabilities": {
      "supportsTextToVideo": true,
      "supportsImageToVideo": true,
      "supportsFirstFrame": true,
      "supportsLastFrame": false,
      "supportsMultiImage": false,
      "supportsMultiVideo": false,
      "supportsLora": false
    },
    "variants": {
      "t2v": "wan2.5-t2v-preview",
      "i2v": "wan2.5-i2v-preview"
    },
    "resolutions": {
      "720P": {
        "16:9": "1280*720",
        "9:16": "720*1280",
        "1:1": "1280*1280",
        "4:3": "1280*960",
        "3:4": "960*1280"
      },
      "1080P": {
        "16:9": "1920*1080",
        "9:16": "1080*1920",
        "1:1": "1920*1920",
        "4:3": "1920*1440",
        "3:4": "1440*1920"
      }
    },
    "durations": {
      "values": [
        5,
        10
      ]
    },
    "counts": {
      "min": 1,
      "max": 4
    },
    "enabled": true
  },
  "wan2.2-video": {
    "name": "Wan 2.2",
    "class": "backend.src.services.provider.dashscope.video.WanVideoService",
    "args": [
      "wan2.2-video"
    ],
    "capabilities": {
      "supportsTextToVideo": true,
      "supportsImageToVideo": true,
      "supportsFirstFrame": true,
      "supportsLastFrame": true,
      "supportsMultiImage": false,
      "supportsMultiVideo": false,
      "supportsLora": false
    },
    "variants": {
      "t2v": "wan2.2-t2v-plus",
      "i2v": "wan2.2-i2v-flash",
      "kf2v": "wan2.2-kf2v-flash"
    },
    "resolutions": {
      "720P": {
        "16:9": "1280*720",
        "9:16": "720*1280",
        "1:1": "1280*1280",
        "4:3": "1280*960",
        "3:4": "960*1280"
      },
      "1080P": {
        "16:9": "1920*1080",
        "9:16": "1080*1920",
        "1:1": "1920*1920",
        "4:3": "1920*1440",
        "3:4": "1440*1920"
      }
    },
    "durations": {
      "values": [
        5
      ]
    },
    "counts": {
      "min": 1,
      "max": 4
    },
    "enabled": true
  }
}
7
backend/src/config/services/default.json
Normal file
@@ -0,0 +1,7 @@
{
  "llm": "qwen-plus",
  "image": "z-image",
  "video": "wan2.6-video",
  "audio": "qwen3-tts-flash",
  "upscale": "ali-videoenhan/videoenhan"
}
27
backend/src/config/services/google/llm.json
Normal file
@@ -0,0 +1,27 @@
{
  "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
  "gemini-1.5-pro": {
    "name": "Gemini-1.5-Pro",
    "class": "backend.src.services.provider.openai_service.OpenAIService",
    "args": [
      "gemini-1.5-pro"
    ],
    "enabled": false
  },
  "gemini-1.5-flash": {
    "name": "Gemini-1.5-Flash",
    "class": "backend.src.services.provider.openai_service.OpenAIService",
    "args": [
      "gemini-1.5-flash"
    ],
    "enabled": false
  },
  "gemini-2.0-flash": {
    "name": "Gemini-2.0-Flash",
    "class": "backend.src.services.provider.openai_service.OpenAIService",
    "args": [
      "gemini-2.0-flash"
    ],
    "enabled": false
  }
}
16
backend/src/config/services/google/provider.json
Normal file
@@ -0,0 +1,16 @@
{
  "id": "google",
  "name": "Google",
  "description": "Gemini 系列模型",
  "dashboard_url": "https://aistudio.google.com/app/apikey",
  "helpUrl": "https://aistudio.google.com/app/apikey",
  "fields": [
    {
      "name": "apiKey",
      "label": "API Key",
      "placeholder": "...",
      "required": true,
      "type": "password"
    }
  ]
}
27
backend/src/config/services/kling/provider.json
Normal file
@@ -0,0 +1,27 @@
{
  "id": "kling",
  "name": "可灵 AI",
  "description": "快手视频生成 - 需要 Access Key 和 Secret Key",
  "dashboard_url": "https://klingai.kuaishou.com/",
  "param_mapping": {
    "api_key": "access_key",
    "api_secret": "secret_key"
  },
  "helpUrl": "https://klingai.kuaishou.com/",
  "fields": [
    {
      "name": "accessKey",
      "label": "Access Key",
      "placeholder": "Access Key",
      "required": true,
      "type": "text"
    },
    {
      "name": "secretKey",
      "label": "Secret Key",
      "placeholder": "Secret Key",
      "required": true,
      "type": "password"
    }
  ]
}
124
backend/src/config/services/kling/video.json
Normal file
@@ -0,0 +1,124 @@
{
    "kling-v2-5-turbo": {
        "name": "Kling V2.5 Turbo",
        "class": "src.services.provider.kling.KlingVideoService",
        "args": [
            "kling-v2-5-turbo"
        ],
        "capabilities": {
            "supportsTextToVideo": true,
            "supportsImageToVideo": true,
            "supportsFirstFrame": true,
            "supportsLastFrame": true,
            "supportsMultiImage": false,
            "supportsMultiVideo": false,
            "supportsAudio": false,
            "supportsNegativePrompt": true,
            "supportsCameraControl": true,
            "supportsLora": false
        },
        "variants": {
            "t2v": "kling-v2-5-turbo",
            "i2v": "kling-v2-5-turbo",
            "kf2v": "kling-v2-5-turbo"
        },
        "modes": ["std", "pro"],
        "resolutions": {
            "720P": {
                "16:9": "1280*720",
                "9:16": "720*1280",
                "1:1": "1024*1024"
            },
            "1080P": {
                "16:9": "1920*1080",
                "9:16": "1080*1920",
                "1:1": "1024*1024"
            }
        },
        "durations": {
            "values": [5, 10]
        },
        "counts": {"min": 1, "max": 4},
        "enabled": true
    },
    "kling-v2-6": {
        "name": "Kling V2.6",
        "class": "src.services.provider.kling.KlingVideoService",
        "args": [
            "kling-v2-6"
        ],
        "capabilities": {
            "supportsTextToVideo": true,
            "supportsImageToVideo": true,
            "supportsFirstFrame": true,
            "supportsLastFrame": true,
            "supportsMultiImage": false,
            "supportsMultiVideo": false,
            "supportsAudio": true,
            "supportsNegativePrompt": true,
            "supportsCameraControl": true,
            "supportsLora": false
        },
        "variants": {
            "t2v": "kling-v2-6",
            "i2v": "kling-v2-6",
            "kf2v": "kling-v2-6"
        },
        "modes": ["pro"],
        "resolutions": {
            "720P": {
                "16:9": "1280*720",
                "9:16": "720*1280",
                "1:1": "1024*1024"
            },
            "1080P": {
                "16:9": "1920*1080",
                "9:16": "1080*1920",
                "1:1": "1024*1024"
            }
        },
        "durations": {
            "values": [5, 10]
        },
        "counts": {"min": 1, "max": 4},
        "enabled": true
    },
    "kling-video-o1": {
        "name": "Kling Omni (O1)",
        "class": "src.services.provider.kling.KlingVideoService",
        "args": [
            "kling-video-o1"
        ],
        "capabilities": {
            "supportsTextToVideo": true,
            "supportsImageToVideo": true,
            "supportsVideoToVideo": true,
            "supportsFirstFrame": true,
            "supportsLastFrame": true,
            "supportsMultiImage": true,
            "supportsMultiVideo": true,
            "supportsAudio": true,
            "supportsNegativePrompt": true,
            "supportsLora": false
        },
        "variants": {
            "t2v": "kling-video-o1",
            "i2v": "kling-video-o1",
            "v2v": "kling-video-o1",
            "kf2v": "kling-video-o1"
        },
        "resolutions": {
            "720P": {
                "16:9": "1280*720",
                "9:16": "720*1280",
                "1:1": "1024*1024"
            }
        },
        "durations": {
            "min": 3,
            "max": 10
        },
        "counts": {"min": 1, "max": 4},
        "enabled": true
    }
}
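Each model entry above pairs a model ID with a dotted "class" import path and constructor "args". The loader in `backend/src/services` is not part of this diff, so how it resolves that path is an assumption; a minimal sketch using `importlib` follows, with a stand-in stdlib class path so the snippet does not depend on the project's own modules:

```python
import importlib

# Registry entry shaped like the kling/video.json models above. The "class"
# value here is a stand-in (pathlib.PurePosixPath); the real config uses
# src.services.provider.kling.KlingVideoService.
ENTRY = {
    "class": "pathlib.PurePosixPath",
    "args": ["kling-v2-5-turbo"],
}

def resolve(entry: dict):
    """Split the dotted path into module and class, import it, instantiate with args."""
    module_path, _, class_name = entry["class"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(*entry["args"])

obj = resolve(ENTRY)
```

Keeping the class as a string in JSON lets new providers be registered without touching the loader, at the cost of import errors surfacing only at resolve time.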
34
backend/src/config/services/midjourney/image.json
Normal file
@@ -0,0 +1,34 @@
{
    "midjourney": {
        "name": "Midjourney",
        "class": "src.services.provider.midjourney.MidjourneyImageService",
        "args": [
            "midjourney"
        ],
        "capabilities": {
            "supportsRefImage": false,
            "supportsLora": false
        },
        "variants": {
            "t2i": "midjourney"
        },
        "resolutions": {
            "1K": {
                "1:1": "1024*1024",
                "16:9": "1280*720",
                "9:16": "720*1280",
                "4:3": "1280*960",
                "3:4": "960*1280"
            },
            "2K": {
                "1:1": "1456*1456",
                "16:9": "1456*816",
                "9:16": "816*1456",
                "4:3": "1232*928",
                "3:4": "928*1232"
            }
        },
        "counts": {"min": 1, "max": 4},
        "enabled": false
    }
}
23
backend/src/config/services/midjourney/provider.json
Normal file
@@ -0,0 +1,23 @@
{
    "id": "midjourney",
    "name": "Midjourney(悠船)",
    "description": "悠船 API - 需要 App ID 和 Secret Key",
    "dashboard_url": "https://ali.youchuan.cn/",
    "helpUrl": "https://ali.youchuan.cn/",
    "fields": [
        {
            "name": "apiKey",
            "label": "App ID",
            "placeholder": "应用 ID",
            "required": true,
            "type": "text"
        },
        {
            "name": "apiSecret",
            "label": "Secret Key",
            "placeholder": "密钥",
            "required": true,
            "type": "password"
        }
    ]
}
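The provider.json "fields" arrays are declarative credential forms: each field carries a name, label, placeholder, required flag, and input type. The backend's own validation code is not in this diff; a hedged sketch of how such a schema could be checked against submitted admin credentials (function name is illustrative):

```python
def validate_credentials(fields: list[dict], submitted: dict) -> list[str]:
    """Return error messages for required credential fields that are missing or blank.

    `fields` follows the shape of the provider.json "fields" arrays; this is a
    sketch, not the backend's actual validator.
    """
    errors = []
    for field in fields:
        value = str(submitted.get(field["name"], "")).strip()
        if field.get("required") and not value:
            errors.append(f'{field["label"]} is required')
    return errors

# Field schema matching midjourney/provider.json above.
MIDJOURNEY_FIELDS = [
    {"name": "apiKey", "label": "App ID", "required": True, "type": "text"},
    {"name": "apiSecret", "label": "Secret Key", "required": True, "type": "password"},
]
```

Because the schema also drives the frontend form (label, placeholder, password masking via "type"), server-side and client-side validation can share the same JSON.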
378
backend/src/config/services/minimax/audio.json
Normal file
@@ -0,0 +1,378 @@
{
    "speech-2.8-hd": {
        "name": "speech-2.8-hd",
        "class": "src.services.provider.minimax.MiniMaxAudioService",
        "args": [
            "speech-2.8-hd"
        ],
        "voices": [
            {
                "id": "male-qn-qingse",
                "name": "青涩青年",
                "gender": "male",
                "desc": "青涩青年风格中文普通话男声"
            },
            {
                "id": "male-qn-jingying",
                "name": "精英青年",
                "gender": "male",
                "desc": "精英青年风格中文普通话男声"
            },
            {
                "id": "male-qn-badao",
                "name": "霸道青年",
                "gender": "male",
                "desc": "沉稳有力的中文普通话青年男声"
            },
            {
                "id": "male-qn-daxuesheng",
                "name": "青年大学生",
                "gender": "male",
                "desc": "青年大学生风格中文普通话男声"
            },
            {
                "id": "female-shaonv",
                "name": "少女",
                "gender": "female",
                "desc": "清亮少女风格中文普通话女声"
            },
            {
                "id": "female-yujie",
                "name": "御姐",
                "gender": "female",
                "desc": "成熟干练风格中文普通话女声"
            },
            {
                "id": "female-chengshu",
                "name": "成熟女性",
                "gender": "female",
                "desc": "成熟女性风格中文普通话女声"
            },
            {
                "id": "female-tianmei",
                "name": "甜美女性",
                "gender": "female",
                "desc": "甜美风格中文普通话女声"
            },
            {
                "id": "Chinese (Mandarin)_News_Anchor",
                "name": "新闻女声",
                "gender": "female",
                "desc": "专业播音腔的中年女性新闻主播,标准普通话"
            },
            {
                "id": "Chinese (Mandarin)_Male_Announcer",
                "name": "播报男声",
                "gender": "male",
                "desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
            },
            {
                "id": "Chinese (Mandarin)_Gentleman",
                "name": "温润男声",
                "gender": "male",
                "desc": "温润磁性的青年男性声音,标准普通话"
            },
            {
                "id": "Chinese (Mandarin)_Sweet_Lady",
                "name": "甜美女声",
                "gender": "female",
                "desc": "温柔甜美的青年女性声音,标准普通话"
            },
            {
                "id": "Cantonese_ProfessionalHost(F)",
                "name": "粤语专业女主持",
                "gender": "female",
                "desc": "中性、专业的青年女性粤语主持人声音"
            },
            {
                "id": "Cantonese_ProfessionalHost(M)",
                "name": "粤语专业男主持",
                "gender": "male",
                "desc": "中性、专业的青年男性粤语主持人声音"
            }
        ],
        "enabled": true
    },
    "speech-2.6-hd": {
        "name": "speech-2.6-hd",
        "class": "src.services.provider.minimax.MiniMaxAudioService",
        "args": [
            "speech-2.6-hd"
        ],
        "voices": [
            {
                "id": "male-qn-qingse",
                "name": "青涩青年",
                "gender": "male",
                "desc": "青涩青年风格中文普通话男声"
            },
            {
                "id": "male-qn-jingying",
                "name": "精英青年",
                "gender": "male",
                "desc": "精英青年风格中文普通话男声"
            },
            {
                "id": "male-qn-badao",
                "name": "霸道青年",
                "gender": "male",
                "desc": "沉稳有力的中文普通话青年男声"
            },
            {
                "id": "male-qn-daxuesheng",
                "name": "青年大学生",
                "gender": "male",
                "desc": "青年大学生风格中文普通话男声"
            },
            {
                "id": "female-shaonv",
                "name": "少女",
                "gender": "female",
                "desc": "清亮少女风格中文普通话女声"
            },
            {
                "id": "female-yujie",
                "name": "御姐",
                "gender": "female",
                "desc": "成熟干练风格中文普通话女声"
            },
            {
                "id": "female-chengshu",
                "name": "成熟女性",
                "gender": "female",
                "desc": "成熟女性风格中文普通话女声"
            },
            {
                "id": "female-tianmei",
                "name": "甜美女性",
                "gender": "female",
                "desc": "甜美风格中文普通话女声"
            },
            {
                "id": "Chinese (Mandarin)_News_Anchor",
                "name": "新闻女声",
                "gender": "female",
                "desc": "专业、播音腔的中年女性新闻主播,标准普通话"
            },
            {
                "id": "Chinese (Mandarin)_Male_Announcer",
                "name": "播报男声",
                "gender": "male",
                "desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
            },
            {
                "id": "Chinese (Mandarin)_Gentleman",
                "name": "温润男声",
                "gender": "male",
                "desc": "温润磁性的青年男性声音,标准普通话"
            },
            {
                "id": "Chinese (Mandarin)_Sweet_Lady",
                "name": "甜美女声",
                "gender": "female",
                "desc": "温柔甜美的青年女性声音,标准普通话"
            },
            {
                "id": "Cantonese_ProfessionalHost(F)",
                "name": "粤语专业女主持",
                "gender": "female",
                "desc": "中性、专业的青年女性粤语主持人声音"
            },
            {
                "id": "Cantonese_ProfessionalHost(M)",
                "name": "粤语专业男主持",
                "gender": "male",
                "desc": "中性、专业的青年男性粤语主持人声音"
            }
        ],
        "enabled": true
    },
    "speech-2.8-turbo": {
        "name": "speech-2.8-turbo",
        "class": "src.services.provider.minimax.MiniMaxAudioService",
        "args": [
            "speech-2.8-turbo"
        ],
        "voices": [
            {
                "id": "male-qn-qingse",
                "name": "青涩青年",
                "gender": "male",
                "desc": "青涩青年风格中文普通话男声"
            },
            {
                "id": "male-qn-jingying",
                "name": "精英青年",
                "gender": "male",
                "desc": "精英青年风格中文普通话男声"
            },
            {
                "id": "male-qn-badao",
                "name": "霸道青年",
                "gender": "male",
                "desc": "沉稳有力的中文普通话青年男声"
            },
            {
                "id": "male-qn-daxuesheng",
                "name": "青年大学生",
                "gender": "male",
                "desc": "青年大学生风格中文普通话男声"
            },
            {
                "id": "female-shaonv",
                "name": "少女",
                "gender": "female",
                "desc": "清亮少女风格中文普通话女声"
            },
            {
                "id": "female-yujie",
                "name": "御姐",
                "gender": "female",
                "desc": "成熟干练风格中文普通话女声"
            },
            {
                "id": "female-chengshu",
                "name": "成熟女性",
                "gender": "female",
                "desc": "成熟女性风格中文普通话女声"
            },
            {
                "id": "female-tianmei",
                "name": "甜美女性",
                "gender": "female",
                "desc": "甜美风格中文普通话女声"
            },
            {
                "id": "Chinese (Mandarin)_News_Anchor",
                "name": "新闻女声",
                "gender": "female",
                "desc": "专业、播音腔的中年女性新闻主播,标准普通话"
            },
            {
                "id": "Chinese (Mandarin)_Male_Announcer",
                "name": "播报男声",
                "gender": "male",
                "desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
            },
            {
                "id": "Chinese (Mandarin)_Gentleman",
                "name": "温润男声",
                "gender": "male",
                "desc": "温润磁性的青年男性声音,标准普通话"
            },
            {
                "id": "Chinese (Mandarin)_Sweet_Lady",
                "name": "甜美女声",
                "gender": "female",
                "desc": "温柔甜美的青年女性声音,标准普通话"
            },
            {
                "id": "Cantonese_ProfessionalHost(F)",
                "name": "粤语专业女主持",
                "gender": "female",
                "desc": "中性、专业的青年女性粤语主持人声音"
            },
            {
                "id": "Cantonese_ProfessionalHost(M)",
                "name": "粤语专业男主持",
                "gender": "male",
                "desc": "中性、专业的青年男性粤语主持人声音"
            }
        ],
        "enabled": true
    },
    "speech-2.6-turbo": {
        "name": "speech-2.6-turbo",
        "class": "src.services.provider.minimax.MiniMaxAudioService",
        "args": [
            "speech-2.6-turbo"
        ],
        "voices": [
            {
                "id": "male-qn-qingse",
                "name": "青涩青年",
                "gender": "male",
                "desc": "青涩青年风格中文普通话男声"
            },
            {
                "id": "male-qn-jingying",
                "name": "精英青年",
                "gender": "male",
                "desc": "精英青年风格中文普通话男声"
            },
            {
                "id": "male-qn-badao",
                "name": "霸道青年",
                "gender": "male",
                "desc": "沉稳有力的中文普通话青年男声"
            },
            {
                "id": "male-qn-daxuesheng",
                "name": "青年大学生",
                "gender": "male",
                "desc": "青年大学生风格中文普通话男声"
            },
            {
                "id": "female-shaonv",
                "name": "少女",
                "gender": "female",
                "desc": "清亮少女风格中文普通话女声"
            },
            {
                "id": "female-yujie",
                "name": "御姐",
                "gender": "female",
                "desc": "成熟干练风格中文普通话女声"
            },
            {
                "id": "female-chengshu",
                "name": "成熟女性",
                "gender": "female",
                "desc": "成熟女性风格中文普通话女声"
            },
            {
                "id": "female-tianmei",
                "name": "甜美女性",
                "gender": "female",
                "desc": "甜美风格中文普通话女声"
            },
            {
                "id": "Chinese (Mandarin)_News_Anchor",
                "name": "新闻女声",
                "gender": "female",
                "desc": "专业、播音腔的中年女性新闻主播,标准普通话"
            },
            {
                "id": "Chinese (Mandarin)_Male_Announcer",
                "name": "播报男声",
                "gender": "male",
                "desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
            },
            {
                "id": "Chinese (Mandarin)_Gentleman",
                "name": "温润男声",
                "gender": "male",
                "desc": "温润磁性的青年男性声音,标准普通话"
            },
            {
                "id": "Chinese (Mandarin)_Sweet_Lady",
                "name": "甜美女声",
                "gender": "female",
                "desc": "温柔甜美的青年女性声音,标准普通话"
            },
            {
                "id": "Cantonese_ProfessionalHost(F)",
                "name": "粤语专业女主持",
                "gender": "female",
                "desc": "中性、专业的青年女性粤语主持人声音"
            },
            {
                "id": "Cantonese_ProfessionalHost(M)",
                "name": "粤语专业男主持",
                "gender": "male",
                "desc": "中性、专业的青年男性粤语主持人声音"
            }
        ],
        "enabled": true
    }
}
64
backend/src/config/services/minimax/image.json
Normal file
@@ -0,0 +1,64 @@
{
    "image-01": {
        "name": "image-01",
        "class": "src.services.provider.minimax.MiniMaxImageService",
        "capabilities": {
            "supportsRefImage": false,
            "supportsLora": false
        },
        "resolutions": {
            "1K": {
                "1:1": "1024x1024",
                "16:9": "1280x720",
                "9:16": "720x1280",
                "4:3": "1152x864",
                "3:4": "864x1152",
                "3:2": "1248x832",
                "2:3": "832x1248",
                "21:9": "1344x576"
            },
            "2K": {
                "1:1": "2048x2048",
                "16:9": "2560x1440",
                "9:16": "1440x2560",
                "4:3": "2304x1728",
                "3:4": "1728x2304",
                "3:2": "2496x1664",
                "2:3": "1664x2496",
                "21:9": "2688x1152"
            }
        },
        "counts": {"min": 1, "max": 4}
    },
    "image-01-live": {
        "name": "image-01-live",
        "class": "src.services.provider.minimax.MiniMaxImageService",
        "capabilities": {
            "supportsRefImage": true,
            "supportsLora": false
        },
        "resolutions": {
            "1K": {
                "1:1": "1024x1024",
                "16:9": "1280x720",
                "9:16": "720x1280",
                "4:3": "1152x864",
                "3:4": "864x1152",
                "3:2": "1248x832",
                "2:3": "832x1248",
                "21:9": "1344x576"
            },
            "2K": {
                "1:1": "2048x2048",
                "16:9": "2560x1440",
                "9:16": "1440x2560",
                "4:3": "2304x1728",
                "3:4": "1728x2304",
                "3:2": "2496x1664",
                "2:3": "1664x2496",
                "21:9": "2688x1152"
            }
        },
        "counts": {"min": 1, "max": 4}
    }
}
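The "resolutions" maps key a tier ("1K", "2K") and aspect ratio to a pixel-size string. The separator differs per provider in this commit: minimax writes "1024x1024" while the kling and midjourney configs write "1024*1024". Whether the backend normalizes this is not visible in the diff; the sketch below assumes a defensive normalizer (helper name is illustrative):

```python
# Tier/aspect map in the shape of minimax/image.json above (trimmed).
RESOLUTIONS = {
    "1K": {"1:1": "1024x1024", "16:9": "1280x720"},
    "2K": {"1:1": "2048x2048", "16:9": "2560x1440"},
}

def pick_size(tier: str, aspect: str) -> tuple[int, int]:
    """Resolve a (tier, aspect) pair to integer pixel dimensions.

    Normalizes "*" to "x" so the same helper works for the kling/midjourney
    configs, which use the other separator -- an assumption, not shown in the diff.
    """
    size = RESOLUTIONS[tier][aspect]
    w, h = size.replace("*", "x").split("x")
    return int(w), int(h)
```

A missing tier or aspect raises `KeyError`, which a caller would presumably translate into a validation error for the API client.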
11
backend/src/config/services/minimax/llm.json
Normal file
@@ -0,0 +1,11 @@
{
    "base_url": "https://api.minimaxi.com/v1",
    "MiniMax-M2.1": {
        "name": "MiniMax M2.1",
        "class": "src.services.provider.openai_service.OpenAIService",
        "args": [
            "MiniMax-M2.1"
        ],
        "enabled": true
    }
}
22
backend/src/config/services/minimax/music.json
Normal file
@@ -0,0 +1,22 @@
{
    "lyrics-2.5": {
        "name": "lyrics-2.5",
        "class": "src.services.provider.minimax.MiniMaxMusicService",
        "args": [
            "music-2.5"
        ],
        "type": "lyrics",
        "enabled": true,
        "is_default": true
    },
    "music-2.5": {
        "name": "music-2.5",
        "class": "src.services.provider.minimax.MiniMaxMusicService",
        "args": [
            "music-2.5"
        ],
        "type": "music",
        "enabled": true,
        "is_default": true
    }
}
23
backend/src/config/services/minimax/provider.json
Normal file
@@ -0,0 +1,23 @@
{
    "id": "minimax",
    "name": "MiniMax",
    "description": "海螺AI",
    "dashboard_url": "https://platform.minimaxi.com/",
    "helpUrl": "https://platform.minimaxi.com/user-center/basic-information/interface-key",
    "fields": [
        {
            "name": "apiKey",
            "label": "API Key",
            "placeholder": "sk-api-...",
            "required": true,
            "type": "password"
        },
        {
            "name": "groupId",
            "label": "Group ID (可选)",
            "placeholder": "分组 ID",
            "required": false,
            "type": "text"
        }
    ]
}
68
backend/src/config/services/minimax/video.json
Normal file
@@ -0,0 +1,68 @@
{
    "MiniMax-Hailuo-2.3": {
        "name": "Hailuo 2.3",
        "class": "src.services.provider.minimax.MiniMaxVideoService",
        "capabilities": {
            "supportsTextToVideo": true,
            "supportsImageToVideo": true,
            "supportsFirstFrame": true,
            "supportsLastFrame": true,
            "supportsMultiImage": false,
            "supportsMultiVideo": false,
            "supportsAudio": false,
            "supportsNegativePrompt": false,
            "supportsLora": false
        },
        "durations": {
            "values": [
                6,
                10
            ]
        },
        "counts": {"min": 1, "max": 4}
    },
    "MiniMax-Hailuo-02": {
        "name": "Hailuo 02",
        "class": "src.services.provider.minimax.MiniMaxVideoService",
        "capabilities": {
            "supportsTextToVideo": true,
            "supportsImageToVideo": true,
            "supportsFirstFrame": true,
            "supportsLastFrame": true,
            "supportsMultiImage": false,
            "supportsMultiVideo": false,
            "supportsAudio": false,
            "supportsNegativePrompt": false,
            "supportsLora": false
        },
        "durations": {
            "values": [
                6,
                10
            ]
        },
        "counts": {"min": 1, "max": 4}
    },
    "T2V-01-Director": {
        "name": "T2V-01-Director",
        "class": "src.services.provider.minimax.MiniMaxVideoService",
        "capabilities": {
            "supportsTextToVideo": true,
            "supportsImageToVideo": true,
            "supportsFirstFrame": true,
            "supportsLastFrame": true,
            "supportsMultiImage": false,
            "supportsMultiVideo": false,
            "supportsAudio": false,
            "supportsNegativePrompt": false,
            "supportsLora": false
        },
        "durations": {
            "values": [
                6,
                10
            ]
        },
        "counts": {"min": 1, "max": 4}
    }
}
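The video configs in this commit use two different "durations" shapes: an enumerated list (`{"values": [6, 10]}` here, `{"values": [5, 10]}` for Kling V2.5/V2.6) and a range (`{"min": 3, "max": 10}` for kling-video-o1). Any shared validator presumably has to accept both; this helper is a sketch with an illustrative name, not the backend's actual code:

```python
def duration_allowed(durations: dict, seconds: int) -> bool:
    """Check a requested clip length against either "durations" shape.

    {"values": [...]} enumerates the only allowed lengths;
    {"min": ..., "max": ...} allows any length in the inclusive range.
    """
    if "values" in durations:
        return seconds in durations["values"]
    return durations["min"] <= seconds <= durations["max"]
```

Checking `"values"` first means a config that (incorrectly) carried both shapes would be treated as enumerated, which is the stricter interpretation.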
Some files were not shown because too many files have changed in this diff.