# -*- coding: utf-8 -*- """ This file defines Dataclass and tool implementations. Modified from https://github.com/OpenPipe/ART/blob/art-e/ """ import datetime import os import sqlite3 from dataclasses import dataclass from typing import Any, List, Optional from pydantic import BaseModel, Field, field_validator from agentscope import logger DEFAULT_DB_PATH = os.environ.get("DEFAULT_EMAIL_DB_PATH") conn = None def get_conn() -> sqlite3.Connection: """Get or create a database connection.""" global conn if conn is None: conn = sqlite3.connect( f"file:{DEFAULT_DB_PATH}?mode=ro", uri=True, check_same_thread=False, ) return conn class QueryModel(BaseModel): """Model for email search query.""" id: int question: str answer: str message_ids: List[str] # message_ids (strings) of referenced emails how_realistic: float inbox_address: str query_date: str @field_validator("query_date", mode="before") @classmethod def format_date(cls, v: Any) -> str: """Format date to string if it's a datetime object.""" if isinstance(v, datetime.datetime): return v.strftime("%Y-%m-%d") return v class AnswerModel(BaseModel): """Model for agent's answer with sources.""" answer: str = Field( description=( "It should be called with the answer and the sources. " "If you cannot find the answer, you should return " "'I don't know' with an empty list of sources." ), ) sources: List[str] = Field( description=( "a list of message ids that are relevant to the query. " "Usually there will be only one. If you cannot find the " "answer, you should return an empty list." ), ) class Email(BaseModel): """Model representing an email.""" message_id: str date: str # ISO 8601 string 'YYYY-MM-DD HH:MM:SS' subject: Optional[str] = None from_address: Optional[str] = None to_addresses: List[str] = Field(default_factory=list) cc_addresses: List[str] = Field(default_factory=list) bcc_addresses: List[str] = Field(default_factory=list) body: Optional[str] = None file_name: Optional[str] = None @dataclass class SearchResult: """Result from email search.""" message_id: str snippet: str class FinalRubric(BaseModel): """Rubric for evaluating agent performance.""" answer_correct: bool = False sources_correct: bool = False num_turns: int = 0 attempted_answer: bool = False ever_found_right_email: bool = False ever_read_right_email: bool = False cant_parse_tool_call: bool = False bad_tool_call_name: bool = False bad_tool_call_args: bool = False ran_out_of_turns: bool = False returned_i_dont_know: bool = False num_sources: int = 0 ever_tried_to_read_invalid_email: bool = False prompt_tokens: int = 0 completion_tokens: int = 0 # Define tools for agent def search_emails_tool( inbox: str, keywords: List[str], from_addr: Optional[str] = None, to_addr: Optional[str] = None, sent_after: Optional[str] = None, sent_before: Optional[str] = None, max_results: int = 10, ) -> List[SearchResult]: """ Searches the email database based on keywords, inbox, sender, recipient, and date range. Args: inbox: The email address of the user performing the search. Results include emails sent from or to (inc. cc/bcc) this address. keywords: A list of keywords that must all appear in the subject or body. from_addr: Optional email address to filter emails sent *from*. to_addr: Optional email address to filter emails sent *to* (inc. cc/bcc). sent_after: Optional date string 'YYYY-MM-DD'. Filters for emails sent on or after this date. sent_before: Optional date string 'YYYY-MM-DD'. Filters for emails sent before this date. max_results: The maximum number of results to return. Cannot exceed 10. Returns: A list of SearchResult objects, each containing 'message_id' and 'snippet'. Returns an empty list if no results are found or an error occurs. """ # Initialize sql and params sql: Optional[str] = None params: List[str | int] = [] cursor = get_conn().cursor() # --- Build Query --- where_clauses: List[str] = [] # 1. Keywords (FTS) if not keywords: raise ValueError("No keywords provided for search.") if max_results > 10: raise ValueError("max_results must be less than or equal to 10.") # FTS5 default is AND, so just join keywords. Escape quotes for safety. fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords) where_clauses.append("emails_fts MATCH ?") params.append(fts_query) # 2. Inbox filter (must be from OR to/cc/bcc the inbox user) # Use the composite index idx_recipients_address_email here where_clauses.append( """ (e.from_address = ? OR EXISTS ( SELECT 1 FROM recipients r_inbox WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.id )) """, ) params.extend([inbox, inbox]) # 3. Optional From filter if from_addr: where_clauses.append("e.from_address = ?") params.append(from_addr) # 4. Optional To filter (includes to, cc, bcc) # Use composite index idx_recipients_address_email if to_addr: where_clauses.append( """ EXISTS ( SELECT 1 FROM recipients r_to WHERE r_to.recipient_address = ? AND r_to.email_id = e.id ) """, ) params.append(to_addr) # 5. Optional Sent After filter if sent_after: # Assumes date format 'YYYY-MM-DD' # Compare against the start of the day where_clauses.append("e.date >= ?") params.append(f"{sent_after} 00:00:00") # 6. Optional Sent Before filter if sent_before: # Assumes date format 'YYYY-MM-DD' # Compare against the start of the day (exclusive) where_clauses.append("e.date < ?") params.append(f"{sent_before} 00:00:00") # --- Construct Final Query --- # snippet(