- FastAPI backend with SQLModel, Alembic migrations, AgentScope agents - Next.js 15 frontend with React 19, Tailwind, Zustand, React Flow - Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.) - All HTTP clients migrated from sync requests to async httpx - Admin-managed API keys via environment variables - SSRF vulnerability fixed in ensure_url()
180 lines
5.7 KiB
Python
180 lines
5.7 KiB
Python
"""
|
|
Integration tests for error handler middleware
|
|
|
|
Tests the error handler middleware integration with FastAPI.
|
|
"""
|
|
import pytest
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.testclient import TestClient
|
|
from src.middlewares.error_handler import setup_error_handler
|
|
from src.utils.errors import (
|
|
ProjectNotFoundException,
|
|
TaskTimeoutException,
|
|
InvalidParameterException,
|
|
ModelNotFoundException,
|
|
RateLimitExceededException
|
|
)
|
|
|
|
|
|
@pytest.fixture
def app():
    """Build a minimal FastAPI application wired with the error handler.

    Each ``/test/...`` route either raises one specific exception or returns
    normally, so the tests can observe how ``setup_error_handler`` turns the
    outcome into an HTTP response.

    Returns:
        The configured ``FastAPI`` instance.
    """
    test_app = FastAPI()

    # Attach the error handler under test.
    setup_error_handler(test_app)

    # --- Routes raising business / parameter / model / rate-limit errors ---

    @test_app.get("/test/business-error")
    async def business_error():
        raise ProjectNotFoundException(project_id="test_123")

    @test_app.get("/test/system-error")
    async def system_error():
        raise TaskTimeoutException(task_id="task_123", timeout=300)

    @test_app.get("/test/invalid-param")
    async def invalid_param():
        raise InvalidParameterException(field="email", reason="Invalid format")

    @test_app.get("/test/model-not-found")
    async def model_not_found():
        raise ModelNotFoundException(model_id="flux-pro")

    @test_app.get("/test/rate-limit")
    async def rate_limit():
        raise RateLimitExceededException(limit=100, window=60)

    # --- Route raising a non-project exception ---

    @test_app.get("/test/unexpected-error")
    async def unexpected_error():
        raise ValueError("Unexpected error")

    # --- Route that succeeds ---

    @test_app.get("/test/success")
    async def success():
        return {"message": "Success"}

    return test_app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Provide a ``TestClient`` bound to the error-handling test app.

    Args:
        app: The FastAPI application produced by the ``app`` fixture.

    Returns:
        A ``TestClient`` that sends requests to ``app``.
    """
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
class TestErrorHandlerMiddleware:
    """Integration tests for the handler installed by ``setup_error_handler``.

    Each test calls one route of the ``app`` fixture. It then checks the HTTP
    status code, the JSON error body (``code``, ``message``, ``details``,
    ``request_id``, ``timestamp``) and/or the ``X-Request-ID`` /
    ``X-Timestamp`` response headers.
    """

    def test_business_exception_response(self, client):
        """A business exception returns 400 with the full error body and headers."""
        response = client.get("/test/business-error")

        assert response.status_code == 400
        data = response.json()

        # Check response format
        assert "code" in data
        assert "message" in data
        assert "details" in data
        assert "request_id" in data
        assert "timestamp" in data

        # Check error details
        assert data["code"] == "2001"
        assert "Project not found" in data["message"]
        assert data["details"]["project_id"] == "test_123"

        # Check headers
        assert "X-Request-ID" in response.headers
        assert "X-Timestamp" in response.headers

    def test_system_exception_response(self, client):
        """A task-timeout (system) exception returns 500 with its details."""
        response = client.get("/test/system-error")

        assert response.status_code == 500
        data = response.json()

        assert data["code"] == "3003"
        assert "timeout" in data["message"].lower()
        assert data["details"]["task_id"] == "task_123"
        # The exception is raised with ``timeout=300``; the body exposes it
        # under the key ``timeout_seconds``.
        assert data["details"]["timeout_seconds"] == 300

    def test_invalid_parameter_exception(self, client):
        """An invalid-parameter exception returns 400 with field and reason."""
        response = client.get("/test/invalid-param")

        assert response.status_code == 400
        data = response.json()

        assert data["code"] == "1001"
        assert data["details"]["field"] == "email"
        assert data["details"]["reason"] == "Invalid format"

    def test_model_not_found_exception(self, client):
        """A model-not-found exception returns 400 with the model id."""
        response = client.get("/test/model-not-found")

        assert response.status_code == 400
        data = response.json()

        assert data["code"] == "4001"
        assert data["details"]["model_id"] == "flux-pro"

    def test_rate_limit_exception(self, client):
        """A rate-limit exception returns 429 with its limit and window."""
        response = client.get("/test/rate-limit")

        assert response.status_code == 429
        data = response.json()

        assert data["code"] == "1007"
        assert data["details"]["limit"] == 100
        # Raised with ``window=60``; the body exposes it as ``window_seconds``.
        assert data["details"]["window_seconds"] == 60

    def test_unexpected_exception_response(self, app):
        """An unhandled ``ValueError`` returns a generic 500 error body.

        The shared ``client`` fixture uses ``TestClient``'s default
        ``raise_server_exceptions=True``. With that setting the client can
        re-raise the server-side ``ValueError`` inside the test instead of
        returning the 500 response the handler produced. A dedicated client
        with ``raise_server_exceptions=False`` lets the test inspect the
        response in either case.
        """
        server_error_client = TestClient(app, raise_server_exceptions=False)
        response = server_error_client.get("/test/unexpected-error")

        assert response.status_code == 500
        data = response.json()

        assert data["code"] == "1000"
        assert "internal error" in data["message"].lower()
        assert "request_id" in data
        assert "timestamp" in data

    def test_success_response_has_headers(self, client):
        """A successful response also carries request-ID and timestamp headers."""
        response = client.get("/test/success")

        assert response.status_code == 200
        assert "X-Request-ID" in response.headers
        assert "X-Timestamp" in response.headers

    def test_request_id_consistency(self, client):
        """The body's ``request_id`` equals the ``X-Request-ID`` header."""
        response = client.get("/test/business-error")

        data = response.json()
        assert data["request_id"] == response.headers["X-Request-ID"]

    def test_timestamp_format(self, client):
        """The body's ``timestamp`` is an ISO-8601 UTC string ending in ``Z``."""
        from datetime import datetime

        response = client.get("/test/business-error")

        data = response.json()
        timestamp = data["timestamp"]

        # Check ISO format (ends with Z for UTC)
        assert timestamp.endswith("Z")
        assert "T" in timestamp

        # Verify it's a valid ISO timestamp. ``fromisoformat`` on older
        # Python versions does not accept the ``Z`` suffix, so swap only the
        # trailing ``Z`` (asserted above) for an explicit UTC offset.
        datetime.fromisoformat(timestamp[:-1] + "+00:00")
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failure (non-zero exit) when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|