168 lines
5.6 KiB
Python
168 lines
5.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""The SQLite session class."""
|
|
import json
|
|
import os
|
|
import sqlite3
|
|
|
|
from agentscope import logger
|
|
from agentscope.module import StateModule
|
|
from agentscope.session import SessionBase
|
|
|
|
|
|
class SqliteSession(SessionBase):
    """A session that uses SQLite for storage.

    Every session is stored as a single row of the ``as_session`` table,
    keyed by ``session_id``. The state dictionaries of all state modules
    passed to :meth:`save_session_state` are serialized together into one
    JSON document held in the ``session_data`` column.
    """

    def __init__(
        self,
        sqlite_path: str,
    ) -> None:
        """Initialize the session.

        Args:
            sqlite_path (`str`):
                The path to the SQLite database file.
        """
        self.sqlite_path = sqlite_path

    async def save_session_state(
        self,
        session_id: str,
        **state_modules_mapping: StateModule,
    ) -> None:
        """Save the session state to the SQLite database.

        The ``as_session`` table is created on first use. If a row with the
        same ``session_id`` already exists, its ``session_data`` and
        ``updated_at`` columns are overwritten.

        Args:
            session_id (`str`):
                The session id, used as the primary key of the row.
            **state_modules_mapping (`StateModule`):
                Keyword arguments mapping a name to the state module whose
                ``state_dict()`` is stored under that name.
        """
        # Serialize first so that a failure in `state_dict()` or
        # `json.dumps` happens before any database connection is opened.
        session_data = {
            name: module.state_dict()
            for name, module in state_modules_mapping.items()
        }
        json_data = json.dumps(session_data)

        conn = sqlite3.connect(self.sqlite_path)
        try:
            # `with conn` only commits (or rolls back on error) the
            # transaction; it does NOT close the connection, hence the
            # surrounding try/finally.
            with conn:
                cursor = conn.cursor()
                try:
                    cursor.execute(
                        """
                        CREATE TABLE IF NOT EXISTS as_session (
                            session_id TEXT,
                            session_data JSON,
                            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                            PRIMARY KEY (session_id)
                        )
                        """,
                    )

                    # Insert or replace the session data
                    cursor.execute(
                        """
                        INSERT INTO as_session (session_id, session_data, updated_at)
                        VALUES (?, json(?), CURRENT_TIMESTAMP)
                        ON CONFLICT(session_id) DO UPDATE SET
                            session_data = excluded.session_data,
                            updated_at = excluded.updated_at
                        """,
                        (session_id, json_data),
                    )
                finally:
                    # Release the cursor even when a statement fails.
                    cursor.close()
        finally:
            conn.close()

    async def load_session_state(
        self,
        session_id: str,
        allow_not_exist: bool = True,
        **state_modules_mapping: StateModule,
    ) -> None:
        """Load the session state from the SQLite database into the modules.

        Args:
            session_id (`str`):
                The session id.
            allow_not_exist (`bool`, defaults to `True`):
                Whether to allow the session to not exist. If `True`, a
                missing database file, missing ``as_session`` table or
                missing row is logged and the method returns without
                touching the modules. If `False`, a `ValueError` is raised
                in those cases.
            **state_modules_mapping (`StateModule`):
                Keyword arguments mapping a name to the state module that
                receives the stored state via ``load_state_dict``.

        Raises:
            `ValueError`:
                If the database file, the session table or the session row
                does not exist and `allow_not_exist` is `False`; or if the
                stored session data has no entry for one of the given
                module names (regardless of `allow_not_exist`).
        """
        if not os.path.exists(self.sqlite_path):
            if allow_not_exist:
                logger.info(
                    "SQLite database %s does not exist. "
                    "Skipping load for session_id %s.",
                    self.sqlite_path,
                    session_id,
                )
                return
            raise ValueError(
                "Failed to load session state because the SQLite database "
                f"file '{self.sqlite_path}' does not exist.",
            )

        conn = sqlite3.connect(self.sqlite_path)
        try:
            cursor = conn.cursor()
            try:
                # If the table does not exist, return
                cursor.execute(
                    """
                    SELECT name FROM sqlite_master WHERE type='table' AND
                    name='as_session';
                    """,
                )
                if cursor.fetchone() is None:
                    if allow_not_exist:
                        logger.info(
                            "Session table does not exist in database %s. "
                            "Skipping load for session_id %s.",
                            self.sqlite_path,
                            session_id,
                        )
                        return

                    raise ValueError(
                        "Failed to load session state because the session "
                        "table 'as_session' does not exist in database "
                        f"{self.sqlite_path}.",
                    )

                # Query the session data
                cursor.execute(
                    "SELECT session_data FROM as_session WHERE session_id = ?",
                    (session_id,),
                )
                row = cursor.fetchone()

                if row is None:
                    if allow_not_exist:
                        logger.info(
                            "Session_id %s does not exist in database %s. "
                            "Skip loading.",
                            session_id,
                            self.sqlite_path,
                        )
                        return

                    raise ValueError(
                        f"Failed to load session state for session_id "
                        f"{session_id} does not exist.",
                    )

                session_data = json.loads(row[0])

                # Modules are restored one at a time; a missing name aborts
                # the loop, leaving earlier modules already loaded.
                for name, module in state_modules_mapping.items():
                    if name in session_data:
                        module.load_state_dict(session_data[name])
                    else:
                        raise ValueError(
                            f"State module '{name}' not found in session "
                            "data.",
                        )
                logger.info(
                    "Load session state for session_id %s from "
                    "database %s successfully.",
                    session_id,
                    self.sqlite_path,
                )

            finally:
                cursor.close()
        finally:
            # The sqlite3 connection context manager would not close the
            # connection, so close it explicitly on every exit path.
            conn.close()
|