Add examples for frozenlake and emailsearch (#94)
This commit is contained in:
328
tuner/email_search/_utils.py
Normal file
328
tuner/email_search/_utils.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
This file defines Dataclass and tool implementations.
|
||||
Modified from https://github.com/OpenPipe/ART/blob/art-e/
|
||||
"""
|
||||
import datetime
|
||||
import os
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from agentscope import logger
|
||||
|
||||
DEFAULT_DB_PATH = os.environ.get("DEFAULT_EMAIL_DB_PATH")
|
||||
conn = None
|
||||
|
||||
|
||||
def get_conn() -> sqlite3.Connection:
|
||||
"""Get or create a database connection."""
|
||||
global conn
|
||||
if conn is None:
|
||||
conn = sqlite3.connect(
|
||||
f"file:{DEFAULT_DB_PATH}?mode=ro",
|
||||
uri=True,
|
||||
check_same_thread=False,
|
||||
)
|
||||
return conn
|
||||
|
||||
|
||||
class QueryModel(BaseModel):
|
||||
"""Model for email search query."""
|
||||
|
||||
id: int
|
||||
question: str
|
||||
answer: str
|
||||
message_ids: List[str] # message_ids (strings) of referenced emails
|
||||
how_realistic: float
|
||||
inbox_address: str
|
||||
query_date: str
|
||||
|
||||
@field_validator("query_date", mode="before")
|
||||
@classmethod
|
||||
def format_date(cls, v: Any) -> str:
|
||||
"""Format date to string if it's a datetime object."""
|
||||
if isinstance(v, datetime.datetime):
|
||||
return v.strftime("%Y-%m-%d")
|
||||
return v
|
||||
|
||||
|
||||
class AnswerModel(BaseModel):
|
||||
"""Model for agent's answer with sources."""
|
||||
|
||||
answer: str = Field(
|
||||
description=(
|
||||
"It should be called with the answer and the sources. "
|
||||
"If you cannot find the answer, you should return "
|
||||
"'I don't know' with an empty list of sources."
|
||||
),
|
||||
)
|
||||
sources: List[str] = Field(
|
||||
description=(
|
||||
"a list of message ids that are relevant to the query. "
|
||||
"Usually there will be only one. If you cannot find the "
|
||||
"answer, you should return an empty list."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Email(BaseModel):
|
||||
"""Model representing an email."""
|
||||
|
||||
message_id: str
|
||||
date: str # ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
|
||||
subject: Optional[str] = None
|
||||
from_address: Optional[str] = None
|
||||
to_addresses: List[str] = Field(default_factory=list)
|
||||
cc_addresses: List[str] = Field(default_factory=list)
|
||||
bcc_addresses: List[str] = Field(default_factory=list)
|
||||
body: Optional[str] = None
|
||||
file_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""Result from email search."""
|
||||
|
||||
message_id: str
|
||||
snippet: str
|
||||
|
||||
|
||||
class FinalRubric(BaseModel):
|
||||
"""Rubric for evaluating agent performance."""
|
||||
|
||||
answer_correct: bool = False
|
||||
sources_correct: bool = False
|
||||
num_turns: int = 0
|
||||
attempted_answer: bool = False
|
||||
ever_found_right_email: bool = False
|
||||
ever_read_right_email: bool = False
|
||||
cant_parse_tool_call: bool = False
|
||||
bad_tool_call_name: bool = False
|
||||
bad_tool_call_args: bool = False
|
||||
ran_out_of_turns: bool = False
|
||||
returned_i_dont_know: bool = False
|
||||
num_sources: int = 0
|
||||
ever_tried_to_read_invalid_email: bool = False
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
|
||||
|
||||
# Define tools for agent
|
||||
|
||||
|
||||
def search_emails_tool(
|
||||
inbox: str,
|
||||
keywords: List[str],
|
||||
from_addr: Optional[str] = None,
|
||||
to_addr: Optional[str] = None,
|
||||
sent_after: Optional[str] = None,
|
||||
sent_before: Optional[str] = None,
|
||||
max_results: int = 10,
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Searches the email database based on keywords, inbox,
|
||||
sender, recipient, and date range.
|
||||
|
||||
Args:
|
||||
inbox: The email address of the user performing the search.
|
||||
Results include emails sent from or to (inc. cc/bcc)
|
||||
this address.
|
||||
keywords: A list of keywords that must all appear in the
|
||||
subject or body.
|
||||
from_addr: Optional email address to filter emails sent *from*.
|
||||
to_addr: Optional email address to filter emails sent *to*
|
||||
(inc. cc/bcc).
|
||||
sent_after: Optional date string 'YYYY-MM-DD'. Filters for
|
||||
emails sent on or after this date.
|
||||
sent_before: Optional date string 'YYYY-MM-DD'. Filters for
|
||||
emails sent before this date.
|
||||
max_results: The maximum number of results to return.
|
||||
Cannot exceed 10.
|
||||
|
||||
Returns:
|
||||
A list of SearchResult objects, each containing 'message_id'
|
||||
and 'snippet'. Returns an empty list if no results are found
|
||||
or an error occurs.
|
||||
"""
|
||||
# Initialize sql and params
|
||||
sql: Optional[str] = None
|
||||
params: List[str | int] = []
|
||||
|
||||
cursor = get_conn().cursor()
|
||||
|
||||
# --- Build Query ---
|
||||
where_clauses: List[str] = []
|
||||
|
||||
# 1. Keywords (FTS)
|
||||
if not keywords:
|
||||
raise ValueError("No keywords provided for search.")
|
||||
|
||||
if max_results > 10:
|
||||
raise ValueError("max_results must be less than or equal to 10.")
|
||||
|
||||
# FTS5 default is AND, so just join keywords. Escape quotes for safety.
|
||||
fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords)
|
||||
where_clauses.append("emails_fts MATCH ?")
|
||||
params.append(fts_query)
|
||||
|
||||
# 2. Inbox filter (must be from OR to/cc/bcc the inbox user)
|
||||
# Use the composite index idx_recipients_address_email here
|
||||
where_clauses.append(
|
||||
"""
|
||||
(e.from_address = ? OR EXISTS (
|
||||
SELECT 1 FROM recipients r_inbox
|
||||
WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.id
|
||||
))
|
||||
""",
|
||||
)
|
||||
params.extend([inbox, inbox])
|
||||
|
||||
# 3. Optional From filter
|
||||
if from_addr:
|
||||
where_clauses.append("e.from_address = ?")
|
||||
params.append(from_addr)
|
||||
|
||||
# 4. Optional To filter (includes to, cc, bcc)
|
||||
# Use composite index idx_recipients_address_email
|
||||
if to_addr:
|
||||
where_clauses.append(
|
||||
"""
|
||||
EXISTS (
|
||||
SELECT 1 FROM recipients r_to
|
||||
WHERE r_to.recipient_address = ? AND r_to.email_id = e.id
|
||||
)
|
||||
""",
|
||||
)
|
||||
params.append(to_addr)
|
||||
|
||||
# 5. Optional Sent After filter
|
||||
if sent_after:
|
||||
# Assumes date format 'YYYY-MM-DD'
|
||||
# Compare against the start of the day
|
||||
where_clauses.append("e.date >= ?")
|
||||
params.append(f"{sent_after} 00:00:00")
|
||||
|
||||
# 6. Optional Sent Before filter
|
||||
if sent_before:
|
||||
# Assumes date format 'YYYY-MM-DD'
|
||||
# Compare against the start of the day (exclusive)
|
||||
where_clauses.append("e.date < ?")
|
||||
params.append(f"{sent_before} 00:00:00")
|
||||
|
||||
# --- Construct Final Query ---
|
||||
# snippet(<table>, <column_index>, <highlight_start>,
|
||||
# <highlight_end>, <ellipsis>, <tokens>)
|
||||
# -1 means highlight across all columns (subject, body)
|
||||
sql = f"""
|
||||
SELECT
|
||||
e.message_id,
|
||||
snippet(emails_fts, -1, '<b>', '</b>', ' ... ', 15) as snippet
|
||||
FROM
|
||||
emails e JOIN emails_fts fts ON e.id = fts.rowid
|
||||
WHERE
|
||||
{" AND ".join(where_clauses)}
|
||||
ORDER BY
|
||||
e.date DESC -- Order by date for relevance
|
||||
LIMIT ?;
|
||||
"""
|
||||
params.append(max_results)
|
||||
|
||||
# --- Execute and Fetch ---
|
||||
logger.debug("Executing SQL: %s", sql)
|
||||
logger.debug("With params: %s", params)
|
||||
cursor.execute(sql, params)
|
||||
results = cursor.fetchall()
|
||||
|
||||
# Format results
|
||||
formatted_results = [
|
||||
SearchResult(message_id=row[0], snippet=row[1]) for row in results
|
||||
]
|
||||
logger.info("Search found %d results.", len(formatted_results))
|
||||
return formatted_results
|
||||
|
||||
|
||||
def read_email_tool(message_id: str) -> Optional[Email]:
|
||||
"""
|
||||
Retrieves a single email by its message_id from the database.
|
||||
|
||||
Args:
|
||||
message_id: The unique identifier of the email to retrieve.
|
||||
|
||||
Returns:
|
||||
An Email object containing the details of the found email,
|
||||
or None if the email is not found or an error occurs.
|
||||
"""
|
||||
cursor = get_conn().cursor()
|
||||
|
||||
# --- Query for Email Core Details ---
|
||||
email_sql = """
|
||||
SELECT id, message_id, date, subject, from_address, body, file_name
|
||||
FROM emails
|
||||
WHERE message_id = ?;
|
||||
"""
|
||||
cursor.execute(email_sql, (message_id,))
|
||||
email_row = cursor.fetchone()
|
||||
|
||||
if not email_row:
|
||||
logger.warning("Email with message_id '%s' not found.", message_id)
|
||||
return None
|
||||
|
||||
email_pk_id, msg_id, date, subject, from_addr, body, file_name = email_row
|
||||
|
||||
# DEBUG
|
||||
logger.info("[read_email_tool] input_message_id=%s", message_id)
|
||||
logger.info(
|
||||
"[read_email_tool] db: id=%s, message_id=%s",
|
||||
email_pk_id,
|
||||
msg_id,
|
||||
)
|
||||
|
||||
# search for recipients by emails.id (rather than message_id)
|
||||
recipients_sql = """
|
||||
SELECT recipient_address, recipient_type
|
||||
FROM recipients
|
||||
WHERE email_id = ?;
|
||||
"""
|
||||
cursor.execute(recipients_sql, (email_pk_id,))
|
||||
recipient_rows = cursor.fetchall()
|
||||
|
||||
to_addresses: List[str] = []
|
||||
cc_addresses: List[str] = []
|
||||
bcc_addresses: List[str] = []
|
||||
|
||||
for addr, rtype in recipient_rows:
|
||||
type_lower = rtype.lower()
|
||||
if type_lower == "to":
|
||||
to_addresses.append(addr)
|
||||
elif type_lower == "cc":
|
||||
cc_addresses.append(addr)
|
||||
elif type_lower == "bcc":
|
||||
bcc_addresses.append(addr)
|
||||
|
||||
# --- Construct Email Object ---
|
||||
email_obj = Email(
|
||||
message_id=msg_id, # Convert to string to match Pydantic model
|
||||
date=date,
|
||||
subject=subject,
|
||||
from_address=from_addr,
|
||||
to_addresses=to_addresses,
|
||||
cc_addresses=cc_addresses,
|
||||
bcc_addresses=bcc_addresses,
|
||||
body=body,
|
||||
file_name=file_name,
|
||||
)
|
||||
|
||||
return email_obj
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QueryModel",
|
||||
"AnswerModel",
|
||||
"FinalRubric",
|
||||
"Email",
|
||||
"SearchResult",
|
||||
"search_emails_tool",
|
||||
"read_email_tool",
|
||||
"get_conn",
|
||||
]
|
||||
Reference in New Issue
Block a user