stock/tests/test_config_api.py
ZhangPeng 9aecdd036c Initial commit: OpenClaw Trading - AI多智能体量化交易系统
- 添加项目核心代码和配置
- 添加前端界面 (Next.js)
- 添加单元测试
- 更新 .gitignore 排除缓存和依赖
2026-02-27 03:47:40 +08:00

437 lines
14 KiB
Python

"""Tests for configuration management API endpoints.
Tests the REST API endpoints for configuration management:
- GET /api/config - Returns current configuration
- POST /api/config - Saves and validates configuration
- GET /api/config/schema - Returns JSON schema for form generation
"""
import json
import tempfile
from pathlib import Path
from typing import Any
from unittest.mock import patch
import pytest
import yaml
from fastapi.testclient import TestClient
from openclaw.dashboard.app import create_app
from openclaw.core.config import (
ConfigLoader,
CostStructure,
LLMConfig,
OpenClawConfig,
SurvivalThresholds,
get_config,
reload_config,
set_config,
)
@pytest.fixture
def app():
    """Build a brand-new dashboard application instance for the current test."""
    application = create_app()
    return application
@pytest.fixture
def client(app):
    """Wrap the per-test application in a FastAPI/Starlette ``TestClient``.

    Args:
        app: The application produced by the ``app`` fixture.

    Returns:
        A client that issues in-process HTTP requests against ``app``.
    """
    # NOTE(review): the client is not entered as a context manager, so per the
    # Starlette TestClient docs the app's startup/shutdown (lifespan) handlers
    # are not run here — confirm that is intended.
    http_client = TestClient(app)
    return http_client
@pytest.fixture(autouse=True)
def reset_config():
    """Clear the process-wide configuration singleton around every test.

    Runs automatically for each test: the cached global config is dropped
    before the test body executes and again after it finishes, so values
    installed by one test cannot leak into the next through the in-memory
    cache.
    """
    # NOTE(review): only the in-memory config is cleared; if POST /api/config
    # persists to a file on disk, that file is not restored here — confirm.

    def _drop_cached_config() -> None:
        set_config(None)

    _drop_cached_config()  # setup: start from a clean slate
    yield
    _drop_cached_config()  # teardown: leave no state behind
@pytest.fixture
def temp_config_file(tmp_path: Path) -> Path:
    """Write a YAML config file with non-default values and return its path.

    The file lives in pytest's per-test ``tmp_path`` directory, so it is
    discarded automatically after the test.

    Args:
        tmp_path: pytest's temporary directory for this test.

    Returns:
        Path to ``test_config.yaml`` inside ``tmp_path``.
    """
    costs = {
        "llm_input_per_1m": 3.0,
        "llm_output_per_1m": 12.0,
        "market_data_per_call": 0.02,
        "trade_fee_rate": 0.002,
    }
    thresholds = {
        "thriving_multiplier": 2.5,
        "stable_multiplier": 1.3,
        "struggling_multiplier": 0.7,
        "bankrupt_multiplier": 0.1,
    }
    document = {
        "initial_capital": {"trader": 15000.0, "analyst": 7500.0},
        "simulation_days": 60,
        "log_level": "DEBUG",
        "cost_structure": costs,
        "survival_thresholds": thresholds,
    }

    target = tmp_path / "test_config.yaml"
    target.write_text(yaml.dump(document))
    return target
@pytest.fixture
def valid_config_payload() -> dict[str, Any]:
    """Build a complete, valid request body for POST /api/config.

    Every top-level section is populated. ``simulation_days`` (90) and the
    trader's starting capital (20000.0) deliberately differ from the
    defaults, so save/reload tests can tell the posted values apart.
    A fresh dict is built on every call.
    """
    capital = {"trader": 20000.0, "analyst": 10000.0, "risk_manager": 8000.0}

    costs = {
        "llm_input_per_1m": 2.5,
        "llm_output_per_1m": 10.0,
        "market_data_per_call": 0.01,
        "trade_fee_rate": 0.001,
    }

    thresholds = {
        "thriving_multiplier": 3.0,
        "stable_multiplier": 1.5,
        "struggling_multiplier": 0.8,
        "bankrupt_multiplier": 0.1,
    }

    # Optional provider settings are sent explicitly as None (JSON null).
    openai_settings = {
        "model": "gpt-4o",
        "temperature": 0.7,
        "timeout": 30,
        "api_key": None,
        "base_url": None,
        "max_tokens": None,
    }

    payload: dict[str, Any] = {
        "initial_capital": capital,
        "simulation_days": 90,
        "log_level": "INFO",
        "data_dir": "./data",
        "cost_structure": costs,
        "survival_thresholds": thresholds,
        "llm_providers": {"openai": openai_settings},
    }
    return payload
class TestGetConfig:
    """Exercise the read side of the configuration API (GET /api/config)."""

    def test_get_config_returns_valid_config(self, client: TestClient) -> None:
        """GET /api/config answers 200 with every top-level section present."""
        resp = client.get("/api/config")
        assert resp.status_code == 200

        payload = resp.json()
        for section in (
            "initial_capital",
            "simulation_days",
            "log_level",
            "cost_structure",
            "survival_thresholds",
        ):
            assert section in payload

    def test_get_config_default_values(self, client: TestClient) -> None:
        """With no config file loaded, GET /api/config reports built-in defaults."""
        resp = client.get("/api/config")
        assert resp.status_code == 200

        payload = resp.json()
        capital = payload["initial_capital"]
        costs = payload["cost_structure"]

        assert capital["trader"] == 10000.0
        assert capital["analyst"] == 5000.0
        assert payload["simulation_days"] == 30
        assert payload["log_level"] == "INFO"
        assert costs["llm_input_per_1m"] == 2.5
        assert costs["trade_fee_rate"] == 0.001

    def test_get_config_loaded_from_file(
        self, client: TestClient, temp_config_file: Path
    ) -> None:
        """GET /api/config reflects the values stored in a YAML file on disk."""
        # Redirect the loader to the fixture file, then clear the cached config
        # so the request below has to load it from that path.
        with patch.object(
            ConfigLoader, "_resolve_config_path", return_value=temp_config_file
        ):
            set_config(None)
            resp = client.get("/api/config")

        assert resp.status_code == 200
        payload = resp.json()
        assert payload["initial_capital"]["trader"] == 15000.0
        assert payload["simulation_days"] == 60
        assert payload["log_level"] == "DEBUG"
class TestPostConfig:
    """Exercise the write side of the configuration API (POST /api/config)."""

    @staticmethod
    def _post_expecting_error(
        client: TestClient, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """POST ``payload`` and assert that the API reports it as an error.

        The endpoint signals invalid input with HTTP 200 and a body whose
        ``status`` is ``"error"``, not with a 422.

        Returns:
            The decoded response body, for further assertions.
        """
        resp = client.post("/api/config", json=payload)
        assert resp.status_code == 200
        body = resp.json()
        assert body.get("status") == "error"
        return body

    def test_post_valid_config_saves_successfully(
        self, client: TestClient, valid_config_payload: dict[str, Any]
    ) -> None:
        """A fully valid payload is accepted and acknowledged with a message."""
        resp = client.post("/api/config", json=valid_config_payload)
        assert resp.status_code == 200

        body = resp.json()
        assert body["status"] == "success"
        assert "message" in body

    def test_post_config_updates_values(
        self, client: TestClient, valid_config_payload: dict[str, Any]
    ) -> None:
        """Values saved through POST are what a subsequent GET returns."""
        saved = client.post("/api/config", json=valid_config_payload)
        assert saved.status_code == 200

        stored = client.get("/api/config").json()
        assert stored["simulation_days"] == 90
        assert stored["initial_capital"]["trader"] == 20000.0
        assert stored["log_level"] == "INFO"

    def test_post_config_partial_update(self, client: TestClient) -> None:
        """Posting a subset of fields changes only those fields."""
        saved = client.post(
            "/api/config", json={"simulation_days": 45, "log_level": "WARNING"}
        )
        assert saved.status_code == 200

        stored = client.get("/api/config").json()
        assert stored["simulation_days"] == 45
        assert stored["log_level"] == "WARNING"
        # Sections that were not posted keep their default values.
        assert stored["initial_capital"]["trader"] == 10000.0

    def test_post_invalid_config_rejected(self, client: TestClient) -> None:
        """A negative simulation length is reported as a validation error."""
        body = self._post_expecting_error(client, {"simulation_days": -5})
        message = body.get("message", "").lower()
        assert "simulation_days" in message or "validation" in message

    def test_post_invalid_log_level_rejected(self, client: TestClient) -> None:
        """An unknown log level name is reported as an error."""
        self._post_expecting_error(client, {"log_level": "INVALID"})

    def test_post_invalid_cost_structure_rejected(self, client: TestClient) -> None:
        """A negative per-token LLM cost is reported as an error."""
        self._post_expecting_error(
            client, {"cost_structure": {"llm_input_per_1m": -1.0}}
        )

    def test_post_invalid_survival_thresholds_rejected(
        self, client: TestClient
    ) -> None:
        """A thriving multiplier that is not above 1 is reported as an error."""
        self._post_expecting_error(
            client, {"survival_thresholds": {"thriving_multiplier": 0.5}}
        )

    def test_post_invalid_trade_fee_rate_rejected(self, client: TestClient) -> None:
        """A trade fee rate above 1 (i.e. over 100%) is reported as an error."""
        self._post_expecting_error(
            client, {"cost_structure": {"trade_fee_rate": 1.5}}
        )
class TestGetConfigSchema:
    """Exercise the schema endpoint used for form generation (GET /api/config/schema)."""

    def test_get_config_schema_returns_schema(self, client: TestClient) -> None:
        """The endpoint answers 200 with something shaped like a JSON schema."""
        resp = client.get("/api/config/schema")
        assert resp.status_code == 200

        schema = resp.json()
        assert any(marker in schema for marker in ("title", "$defs", "properties"))

    def test_get_config_schema_contains_expected_fields(
        self, client: TestClient
    ) -> None:
        """Core config fields appear somewhere in the schema (snake or squashed case)."""
        resp = client.get("/api/config/schema")
        assert resp.status_code == 200

        # Serialize and lowercase so the field may sit under properties or $defs
        # and in either snake_case or a camelCase-like spelling.
        flattened = json.dumps(resp.json()).lower()
        for snake_name in ("initial_capital", "simulation_days", "log_level"):
            squashed = snake_name.replace("_", "")
            assert snake_name in flattened or squashed in flattened
class TestConfigRoundtrip:
    """Tests for config save/load roundtrip."""

    def test_config_roundtrip(
        self, client: TestClient, valid_config_payload: dict[str, Any]
    ) -> None:
        """Test that config can be saved and then loaded with the same values."""
        response = client.post("/api/config", json=valid_config_payload)
        assert response.status_code == 200
        # The API reports validation failures with HTTP 200 and
        # ``"status": "error"``, so the status code alone does not prove the
        # save succeeded.
        assert response.json()["status"] == "success"

        response = client.get("/api/config")
        assert response.status_code == 200
        data = response.json()
        assert data["simulation_days"] == valid_config_payload["simulation_days"]
        assert data["log_level"] == valid_config_payload["log_level"]
        assert (
            data["initial_capital"]["trader"]
            == valid_config_payload["initial_capital"]["trader"]
        )

    def test_multiple_updates_preserve_values(self, client: TestClient) -> None:
        """Test that a later partial update keeps values set by an earlier one."""
        # The first update must use a NON-default log level. The default is
        # "INFO", so posting "INFO" here would make the retention check below
        # pass even if the second update reset everything to defaults.
        config1 = {"simulation_days": 30, "log_level": "DEBUG"}
        response = client.post("/api/config", json=config1)
        assert response.status_code == 200
        assert response.json()["status"] == "success"

        # Second update touches only simulation_days.
        config2 = {"simulation_days": 60}
        response = client.post("/api/config", json=config2)
        assert response.status_code == 200
        assert response.json()["status"] == "success"

        response = client.get("/api/config")
        assert response.status_code == 200
        data = response.json()
        assert data["simulation_days"] == 60
        assert data["log_level"] == "DEBUG"  # retained from the first update
class TestConfigValidation:
    """Check that out-of-range values are refused by POST /api/config."""

    @staticmethod
    def _post_expecting_error(
        client: TestClient, payload: dict[str, Any]
    ) -> dict[str, Any]:
        """POST ``payload`` and assert that the API reports it as an error.

        The endpoint signals invalid input with HTTP 200 and a body whose
        ``status`` is ``"error"``.

        Returns:
            The decoded response body, for further assertions.
        """
        resp = client.post("/api/config", json=payload)
        assert resp.status_code == 200
        body = resp.json()
        assert body.get("status") == "error"
        return body

    def test_negative_initial_capital_rejected(self, client: TestClient) -> None:
        """A negative starting balance for an agent is reported as an error."""
        self._post_expecting_error(client, {"initial_capital": {"trader": -1000.0}})

    def test_zero_simulation_days_rejected(self, client: TestClient) -> None:
        """A simulation length of zero days is reported as an error."""
        self._post_expecting_error(client, {"simulation_days": 0})

    def test_invalid_temperature_rejected(self, client: TestClient) -> None:
        """An LLM sampling temperature above the <= 2 limit is reported as an error."""
        bad_provider = {"model": "gpt-4o", "temperature": 3.0}
        self._post_expecting_error(client, {"llm_providers": {"openai": bad_provider}})

    def test_nested_validation_error_detail(self, client: TestClient) -> None:
        """A failure inside a nested section yields an error status with a message.

        Only the presence of ``message`` is checked, not its wording.
        """
        body = self._post_expecting_error(
            client, {"cost_structure": {"llm_input_per_1m": -5.0}}
        )
        assert "message" in body
class TestErrorHandling:
    """Tests for API error handling."""

    def test_malformed_json_rejected(self, client: TestClient) -> None:
        """Test that a syntactically invalid JSON body is answered with 422."""
        # Raw text bodies are sent with ``content=``. Passing a str via ``data=``
        # is deprecated in httpx (which backs Starlette's TestClient); ``data=``
        # is meant for form fields.
        response = client.post(
            "/api/config",
            content="{invalid json",
            headers={"Content-Type": "application/json"},
        )
        assert response.status_code == 422

    def test_empty_body_rejected(self, client: TestClient) -> None:
        """Test that an empty JSON object is ACCEPTED.

        Despite the method name, ``{}`` is a valid payload: every field falls
        back to its default.
        """
        response = client.post("/api/config", json={})
        assert response.status_code == 200
        # HTTP 200 is also returned for validation errors, so the body's status
        # is what actually proves the empty config was accepted.
        assert response.json()["status"] == "success"

    def test_invalid_content_type_handled(self, client: TestClient) -> None:
        """Test that a non-JSON content type is refused."""
        response = client.post(
            "/api/config",
            content="not json",
            headers={"Content-Type": "text/plain"},
        )
        # FastAPI should return 415 Unsupported Media Type or 422.
        assert response.status_code in [415, 422]