"""Tests for JWT session claims and the refresh-token session lifecycle of ``SessionService``."""

from datetime import datetime, timedelta

import pytest
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, SQLModel, create_engine

import src.services.session_service as session_service_module
from src.auth.jwt import create_token_pair, decode_token_unsafe
from src.models.entities import UserDB
from src.services.session_service import SessionService


@pytest.fixture
def engine():
    """Yield an in-memory SQLite engine with every SQLModel table created.

    ``StaticPool`` reuses a single connection, so all sessions opened against
    this engine see the same in-memory database. ``check_same_thread=False``
    lets that one connection be used from whichever thread the test touches it.
    The engine is disposed on teardown so the pooled connection does not
    outlive the test.
    """
    test_engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(test_engine)
    yield test_engine
    test_engine.dispose()


@pytest.fixture
def session_service(engine, monkeypatch):
    """Return a ``SessionService`` bound to the test engine.

    The service module's ``engine`` attribute is swapped for the in-memory
    engine. ``monkeypatch`` restores the original attribute after the test.
    """
    monkeypatch.setattr(session_service_module, "engine", engine)
    return SessionService()


@pytest.fixture
def user_id(engine):
    """Persist a minimal active user and return its id.

    The session tests create sessions for this user id.
    """
    now = datetime.now().timestamp()
    user_id = "user-test"
    user = UserDB(
        id=user_id,
        username="session_user",
        email="session@example.com",
        password_hash="hashed-password",
        is_active=True,
        is_superuser=False,
        permissions=[],
        roles=["user"],
        created_at=now,
        updated_at=now,
    )
    with Session(engine) as session:
        session.add(user)
        session.commit()
    return user_id


def test_create_token_pair_includes_session_claims():
    """Access and refresh tokens carry the session id; refresh also carries the family id."""
    tokens = create_token_pair(
        user_id="user-1",
        scopes=["user"],
        session_id="session-1",
        session_family_id="family-1",
    )

    access_payload = decode_token_unsafe(tokens.access_token)
    refresh_payload = decode_token_unsafe(tokens.refresh_token)

    assert tokens.session_id == "session-1"
    assert tokens.session_family_id == "family-1"
    assert access_payload["sid"] == "session-1"
    assert refresh_payload["sid"] == "session-1"
    assert refresh_payload["sfid"] == "family-1"


def test_session_service_create_and_validate_refresh_token(session_service, user_id):
    """A freshly created session validates with its own refresh token and is active."""
    created = session_service.create_session(
        user_id=user_id,
        refresh_token="refresh-token-1",
        session_id="session-1",
        session_family_id="family-1",
    )

    validated = session_service.validate_refresh_token(created.id, "refresh-token-1")

    assert created.id == "session-1"
    assert created.session_family_id == "family-1"
    assert validated is not None
    assert validated.id == created.id
    assert session_service.is_session_active(created.id) is True


def test_rotate_refresh_token_revokes_previous_session(session_service, user_id):
    """Rotation creates a new session in the same family and marks the old one as rotated."""
    created = session_service.create_session(
        user_id=user_id,
        refresh_token="refresh-token-1",
        session_id="session-1",
        session_family_id="family-1",
    )

    rotated = session_service.rotate_refresh_token(
        created.id,
        "refresh-token-1",
        "refresh-token-2",
        new_session_id="session-2",
    )
    previous = session_service.get_session(created.id)

    assert rotated is not None
    assert rotated.id == "session-2"
    assert rotated.session_family_id == "family-1"
    assert previous is not None
    assert previous.status == "rotated"
    assert previous.replaced_by_session_id == "session-2"
    assert session_service.validate_refresh_token("session-2", "refresh-token-2") is not None


def test_refresh_token_reuse_revokes_session_family(session_service, user_id):
    """Presenting a wrong refresh token for a session revokes it and leaves no family member active."""
    created = session_service.create_session(
        user_id=user_id,
        refresh_token="refresh-token-1",
        session_id="session-1",
        session_family_id="family-1",
    )

    rotated = session_service.rotate_refresh_token(
        created.id,
        "refresh-token-1",
        "refresh-token-2",
        new_session_id="session-2",
    )

    invalidated = session_service.validate_refresh_token("session-2", "wrong-refresh-token")
    # include_inactive=True, so this also returns rotated/revoked sessions.
    all_sessions = session_service.list_user_sessions(user_id, include_inactive=True)

    assert rotated is not None
    assert invalidated is None
    assert all(session.status in {"rotated", "revoked"} for session in all_sessions)
    # Also guarantees the list is non-empty, so the all() above is not vacuous.
    assert any(
        session.id == "session-2" and session.status == "revoked"
        for session in all_sessions
    )


def test_expired_session_is_not_active(session_service, user_id):
    """A session past its expiry fails validation, is revoked as expired, and is not active."""
    expired_session = session_service.create_session(
        user_id=user_id,
        refresh_token="refresh-token-expired",
        session_id="session-expired",
        expires_at=(datetime.now() - timedelta(minutes=1)).timestamp(),
    )

    validated = session_service.validate_refresh_token(expired_session.id, "refresh-token-expired")
    expired = session_service.get_session(expired_session.id)

    assert validated is None
    assert expired is not None
    assert expired.status == "revoked"
    assert expired.revoked_reason == "expired"
    assert session_service.is_session_active(expired_session.id) is False