# -*- coding: utf-8 -*-
"""
Data models and tool implementations for the email-search agent.

The module exposes pydantic/dataclass models describing queries, answers,
emails and evaluation rubrics, plus two tools (`search_emails_tool` and
`read_email_tool`) that read from a SQLite email database.

Modified from https://github.com/OpenPipe/ART/blob/art-e/
"""
import datetime
import os
import sqlite3
from dataclasses import dataclass
from typing import Any, List, Optional

from pydantic import BaseModel, Field, field_validator
from agentscope import logger

# Path of the SQLite email database, read once at import time.
DEFAULT_DB_PATH = os.environ.get("DEFAULT_EMAIL_DB_PATH")

# Lazily created, module-wide read-only connection (see `get_conn`).
conn: Optional[sqlite3.Connection] = None


def get_conn() -> sqlite3.Connection:
    """Get or create the shared read-only database connection.

    The connection is created on first use and cached in the module-level
    ``conn`` variable. It is opened in read-only URI mode with
    ``check_same_thread=False``.

    Returns:
        The cached `sqlite3.Connection`.

    Raises:
        sqlite3.OperationalError: If ``DEFAULT_EMAIL_DB_PATH`` was not set
            when the module was imported, or if the database cannot be
            opened.
    """
    global conn
    if conn is None:
        if not DEFAULT_DB_PATH:
            # Without this guard the URI would become "file:None?mode=ro"
            # and fail with an unhelpful "unable to open database file".
            raise sqlite3.OperationalError(
                "DEFAULT_EMAIL_DB_PATH environment variable is not set; "
                "cannot open the email database.",
            )
        conn = sqlite3.connect(
            f"file:{DEFAULT_DB_PATH}?mode=ro",
            uri=True,
            check_same_thread=False,
        )
    return conn


class QueryModel(BaseModel):
    """Model for email search query."""

    id: int
    question: str
    answer: str
    message_ids: List[str]  # message_ids (strings) of referenced emails
    how_realistic: float
    inbox_address: str
    query_date: str

    @field_validator("query_date", mode="before")
    @classmethod
    def format_date(cls, v: Any) -> str:
        """Format date/datetime objects as a 'YYYY-MM-DD' string.

        Any other value is passed through unchanged so that pydantic
        validates it as a string.
        """
        # `datetime.datetime` is a subclass of `datetime.date`, so this one
        # check covers both.
        if isinstance(v, datetime.date):
            return v.strftime("%Y-%m-%d")
        return v


class AnswerModel(BaseModel):
    """Model for agent's answer with sources."""

    answer: str = Field(
        description=(
            "It should be called with the answer and the sources. "
            "If you cannot find the answer, you should return "
            "'I don't know' with an empty list of sources."
        ),
    )
    sources: List[str] = Field(
        description=(
            "a list of message ids that are relevant to the query. "
            "Usually there will be only one. If you cannot find the "
            "answer, you should return an empty list."
        ),
    )


class Email(BaseModel):
    """Model representing an email."""

    message_id: str
    date: str  # ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
    subject: Optional[str] = None
    from_address: Optional[str] = None
    to_addresses: List[str] = Field(default_factory=list)
    cc_addresses: List[str] = Field(default_factory=list)
    bcc_addresses: List[str] = Field(default_factory=list)
    body: Optional[str] = None
    file_name: Optional[str] = None


@dataclass
class SearchResult:
    """Result from email search."""

    message_id: str
    snippet: str


class FinalRubric(BaseModel):
    """Rubric for evaluating agent performance."""

    answer_correct: bool = False
    sources_correct: bool = False
    num_turns: int = 0
    attempted_answer: bool = False
    ever_found_right_email: bool = False
    ever_read_right_email: bool = False
    cant_parse_tool_call: bool = False
    bad_tool_call_name: bool = False
    bad_tool_call_args: bool = False
    ran_out_of_turns: bool = False
    returned_i_dont_know: bool = False
    num_sources: int = 0
    ever_tried_to_read_invalid_email: bool = False
    prompt_tokens: int = 0
    completion_tokens: int = 0


# Define tools for agent


def search_emails_tool(
    inbox: str,
    keywords: List[str],
    from_addr: Optional[str] = None,
    to_addr: Optional[str] = None,
    sent_after: Optional[str] = None,
    sent_before: Optional[str] = None,
    max_results: int = 10,
) -> List[SearchResult]:
    """
    Searches the email database based on keywords, inbox, sender, recipient,
    and date range.

    Args:
        inbox: The email address of the user performing the search.
            Results include emails sent from or to (inc. cc/bcc) this
            address.
        keywords: A list of keywords that must all appear in the subject
            or body.
        from_addr: Optional email address to filter emails sent *from*.
        to_addr: Optional email address to filter emails sent *to*
            (inc. cc/bcc).
        sent_after: Optional date string 'YYYY-MM-DD'. Filters for emails
            sent on or after this date.
        sent_before: Optional date string 'YYYY-MM-DD'. Filters for emails
            sent before this date.
        max_results: The maximum number of results to return.
            Must be between 0 and 10 inclusive.

    Returns:
        A list of SearchResult objects, each containing 'message_id' and
        'snippet', ordered by date (newest first). Returns an empty list
        if no results are found.

    Raises:
        ValueError: If `keywords` is empty or `max_results` is outside
            the range 0..10.
        sqlite3.Error: If the database cannot be opened or the query
            fails; database errors are not swallowed.
    """
    # Validate arguments before touching the database.
    if not keywords:
        raise ValueError("No keywords provided for search.")
    if max_results > 10:
        raise ValueError("max_results must be less than or equal to 10.")
    if max_results < 0:
        # SQLite treats a negative LIMIT as "no upper bound", which would
        # silently bypass the cap of 10.
        raise ValueError("max_results must be non-negative.")

    # --- Build Query ---
    params: List[str | int] = []
    where_clauses: List[str] = []

    # 1. Keywords (FTS)
    # FTS5 default is AND, so just join keywords. Each keyword is wrapped in
    # double quotes (with embedded quotes doubled) so it is treated as a
    # literal phrase rather than FTS query syntax.
    fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords)
    where_clauses.append("emails_fts MATCH ?")
    params.append(fts_query)

    # 2. Inbox filter (must be from OR to/cc/bcc the inbox user)
    # Intended to use the composite index idx_recipients_address_email
    # (defined outside this file).
    where_clauses.append(
        """
        (e.from_address = ? OR EXISTS (
            SELECT 1 FROM recipients r_inbox
            WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.id
        ))
        """,
    )
    params.extend([inbox, inbox])

    # 3. Optional From filter
    if from_addr:
        where_clauses.append("e.from_address = ?")
        params.append(from_addr)

    # 4. Optional To filter (includes to, cc, bcc)
    if to_addr:
        where_clauses.append(
            """
            EXISTS (
                SELECT 1 FROM recipients r_to
                WHERE r_to.recipient_address = ? AND r_to.email_id = e.id
            )
            """,
        )
        params.append(to_addr)

    # 5. Optional Sent After filter
    if sent_after:
        # Assumes date format 'YYYY-MM-DD'
        # Compare against the start of the day (inclusive)
        where_clauses.append("e.date >= ?")
        params.append(f"{sent_after} 00:00:00")

    # 6. Optional Sent Before filter
    if sent_before:
        # Assumes date format 'YYYY-MM-DD'
        # Compare against the start of the day (exclusive)
        where_clauses.append("e.date < ?")
        params.append(f"{sent_before} 00:00:00")

    # --- Construct Final Query ---
    # snippet(<table>, <column>, <start_match>, <end_match>,
    #         <ellipsis>, <max_tokens>)
    # A column index of -1 lets FTS5 pick the column automatically.
    # Both match markers are empty strings, so matches are not highlighted.
    # Results are ordered by date, newest first. That note is kept here
    # rather than as an inline SQL "--" comment, which would swallow the
    # LIMIT clause if the statement were ever collapsed onto one line.
    sql = f"""
        SELECT
            e.message_id,
            snippet(emails_fts, -1, '', '', ' ... ', 15) as snippet
        FROM
            emails e JOIN emails_fts fts ON e.id = fts.rowid
        WHERE
            {" AND ".join(where_clauses)}
        ORDER BY
            e.date DESC
        LIMIT ?;
    """
    params.append(max_results)

    # --- Execute and Fetch ---
    logger.debug("Executing SQL: %s", sql)
    logger.debug("With params: %s", params)
    cursor = get_conn().cursor()
    try:
        cursor.execute(sql, params)
        results = cursor.fetchall()
    finally:
        # Release the cursor even if the query fails.
        cursor.close()

    # Format results
    formatted_results = [
        SearchResult(message_id=row[0], snippet=row[1]) for row in results
    ]
    logger.info("Search found %d results.", len(formatted_results))
    return formatted_results


def read_email_tool(message_id: str) -> Optional[Email]:
    """
    Retrieves a single email by its message_id from the database.

    Args:
        message_id: The unique identifier of the email to retrieve.

    Returns:
        An Email object containing the details of the found email,
        or None if the email is not found.

    Raises:
        sqlite3.Error: If the database cannot be opened or a query fails;
            database errors are not swallowed.
    """
    cursor = get_conn().cursor()
    try:
        # --- Query for Email Core Details ---
        email_sql = """
            SELECT id, message_id, date, subject, from_address, body,
                   file_name
            FROM emails
            WHERE message_id = ?;
        """
        cursor.execute(email_sql, (message_id,))
        email_row = cursor.fetchone()

        if not email_row:
            logger.warning(
                "Email with message_id '%s' not found.",
                message_id,
            )
            return None

        (
            email_pk_id,
            msg_id,
            date,
            subject,
            from_addr,
            body,
            file_name,
        ) = email_row

        # Diagnostic output, logged at debug level.
        logger.debug("[read_email_tool] input_message_id=%s", message_id)
        logger.debug(
            "[read_email_tool] db: id=%s, message_id=%s",
            email_pk_id,
            msg_id,
        )

        # Recipients are keyed by emails.id (rather than message_id).
        recipients_sql = """
            SELECT recipient_address, recipient_type
            FROM recipients
            WHERE email_id = ?;
        """
        cursor.execute(recipients_sql, (email_pk_id,))
        recipient_rows = cursor.fetchall()
    finally:
        # Release the cursor on every exit path.
        cursor.close()

    # Bucket recipient addresses by type; unknown or NULL types are ignored.
    recipients_by_type: dict[str, List[str]] = {"to": [], "cc": [], "bcc": []}
    for addr, rtype in recipient_rows:
        bucket = recipients_by_type.get((rtype or "").lower())
        if bucket is not None:
            bucket.append(addr)

    # --- Construct Email Object ---
    return Email(
        message_id=msg_id,
        date=date,
        subject=subject,
        from_address=from_addr,
        to_addresses=recipients_by_type["to"],
        cc_addresses=recipients_by_type["cc"],
        bcc_addresses=recipients_by_type["bcc"],
        body=body,
        file_name=file_name,
    )


__all__ = [
    "QueryModel",
    "AnswerModel",
    "FinalRubric",
    "Email",
    "SearchResult",
    "search_emails_tool",
    "read_email_tool",
    "get_conn",
]