# -*- coding: utf-8 -*-
import asyncio
import csv
import hashlib
import json
import os
import re
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
from typing import List, Literal, Optional, Tuple

# Third-party imports
import dotenv
import tqdm
from agentscope.embedding import DashScopeTextEmbedding
from agentscope.message import TextBlock
from agentscope.rag import (
    Document,
    QdrantStore,
    ReaderBase,
    SimpleKnowledge,
    TextReader,
)
from agentscope.tool import ToolResponse

dotenv.load_dotenv()

class CSVReader(ReaderBase):
    """CSV reader that splits table data into chunks by fixed chunk size."""

    def __init__(
        self,
        chunk_size: int = 512,
        split_by: Literal["char", "sentence", "paragraph"] = "paragraph",
        delimiter: str = ",",
        encoding: str = "utf-8",
    ) -> None:
        """Initialize the CSV reader.

        Args:
            chunk_size (`int`, default to 512):
                The size of each chunk, in number of characters.
            split_by (`Literal["char", "sentence", "paragraph"]`, default to \
            "paragraph"):
                The unit to split the text, can be "char", "sentence", or
                "paragraph". The "sentence" option is implemented using the
                "nltk" library, which only supports English text.
            delimiter (`str`, default to ","):
                The delimiter used in the CSV file.
            encoding (`str`, default to "utf-8"):
                The encoding of the CSV file.
        """
        if chunk_size <= 0:
            raise ValueError(
                f"The chunk_size must be positive, got {chunk_size}",
            )

        if split_by not in ["char", "sentence", "paragraph"]:
            raise ValueError(
                "The split_by must be one of 'char', 'sentence' or "
                f"'paragraph', got {split_by}",
            )

        self.chunk_size = chunk_size
        self.split_by = split_by
        self.delimiter = delimiter
        self.encoding = encoding

        # To avoid code duplication, we use TextReader to do the chunking.
        self._text_reader = TextReader(
            self.chunk_size,
            self.split_by,
        )

    def _read_csv(
        self,
        csv_path: str,
        sample_rows: int = 5,
        delimiter: str = ",",
        encoding: str = "utf-8",
        output_delimiter: str = " | ",
    ) -> str:
        """
        Read CSV header and first N rows with formatted output.

        Args:
            csv_path (`str`):
                The path to the CSV file.
            sample_rows (`int`, default to 5):
                Number of data rows to sample.
            delimiter (`str`, default to ","):
                The input CSV delimiter.
            encoding (`str`, default to "utf-8"):
                The encoding of the CSV file.
            output_delimiter (`str`, default to " | "):
                The delimiter for output formatting.

        Returns:
            str: Formatted plain text of the CSV header and sample rows.
        """

        try:
            lines = []

            with open(
                csv_path,
                "r",
                encoding=encoding,
                newline="",
                errors="ignore",
            ) as csvfile:
                csv_reader = csv.reader(csvfile, delimiter=delimiter)

                # Read header
                header = next(csv_reader, None)
                if header:
                    lines.append(output_delimiter.join(header))

                # Read first N rows
                for i, row in enumerate(csv_reader):
                    if i >= sample_rows:
                        break
                    if row:  # Skip empty rows
                        lines.append(output_delimiter.join(row))

            return " ".join(lines)

        except FileNotFoundError as exc:
            raise FileNotFoundError(
                f"CSV file not found: {csv_path}",
            ) from exc
        except Exception as e:
            raise RuntimeError(
                f"Error reading CSV file: {csv_path}. Error: {str(e)}",
            ) from e
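
    # Worked example (illustrative): given a CSV whose first two lines are
    #   name,price
    #   apple,3
    # _read_csv() returns "name | price apple | 3": fields are joined with
    # " | " within a line, and lines are joined with a single space.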

    async def __call__(
        self,
        csv_path: str,
    ) -> list[Document]:
        """
        Read a CSV file, split it into chunks, and return Document objects.

        Args:
            csv_path (`str`):
                The input CSV file path.
        """
        sample_content = self._read_csv(
            csv_path,
            delimiter=self.delimiter,
            encoding=self.encoding,
        )

        doc_id = hashlib.sha256(csv_path.encode("utf-8")).hexdigest()

        docs = await self._text_reader(sample_content)
        for doc in docs:
            # Tag each chunk with the path-derived doc_id so that
            # get_doc_id() and the doc_id-to-file-path mapping built in
            # RAGFilterTool stay consistent.
            doc.metadata.doc_id = doc_id

        return docs

    def get_doc_id(self, csv_path: str) -> str:
        """Get the document ID. This function can be used to check if the
        doc_id already exists in the knowledge base."""
        return hashlib.sha256(csv_path.encode("utf-8")).hexdigest()
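

# Usage sketch (illustrative; defined but never called): reading a CSV file
# into chunked Document objects. The path below is a hypothetical placeholder.
async def _demo_csv_reader() -> None:
    reader = CSVReader(chunk_size=256, split_by="paragraph")
    docs = await reader(csv_path="data/sample.csv")
    for doc in docs:
        # Each chunk carries the same path-derived doc_id, matching
        # CSVReader.get_doc_id("data/sample.csv").
        print(doc.metadata.doc_id)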


file_extensions_white_list_grep = {
    ".arff",
    ".csv",
    ".dat",
    ".data",
    ".db",
    ".docx",
    ".geojson",
    ".gz",
    ".html",
    ".json",
    ".jsonl",
    ".md",
    ".names",
    ".noext",
    ".pbix",
    ".pdf",
    ".png",
    ".py",
    ".r",
    ".sq",
    ".sql",
    ".sqlite",
    ".tex",
    ".tsv",
    ".txt",
    ".xls",
    ".xlsx",
    ".yaml",
    ".yml",
    ".zip",
}

file_extensions_white_list_rag = {
    ".csv",
}


class GrepFilterTool:
    """File filtering tool based on the grep command (hybrid parallel version)."""

    def __init__(
        self,
        max_workers: Optional[int] = None,
    ) -> None:
        self.white_list = file_extensions_white_list_grep
        self.max_workers = max_workers or min(
            32,
            (os.cpu_count() or 1) * 2,
        )

    def _extract_keywords(
        self,
        query: str,
    ) -> List[str]:
        """Extract keywords from the user query by dropping English
        stopwords and tokens shorter than three characters."""
        stopwords_en = {
            "i",
            "me",
            "my",
            "myself",
            "we",
            "our",
            "ours",
            "ourselves",
            "you",
            "you're",
            "you've",
            "you'll",
            "you'd",
            "your",
            "yours",
            "yourself",
            "yourselves",
            "he",
            "him",
            "his",
            "himself",
            "she",
            "she's",
            "her",
            "hers",
            "herself",
            "it",
            "it's",
            "its",
            "itself",
            "they",
            "them",
            "their",
            "theirs",
            "themselves",
            "what",
            "which",
            "who",
            "whom",
            "this",
            "that",
            "that'll",
            "these",
            "those",
            "am",
            "is",
            "are",
            "was",
            "were",
            "be",
            "been",
            "being",
            "have",
            "has",
            "had",
            "having",
            "do",
            "does",
            "did",
            "doing",
            "a",
            "an",
            "the",
            "and",
            "but",
            "if",
            "or",
            "because",
            "as",
            "until",
            "while",
            "of",
            "at",
            "by",
            "for",
            "with",
            "about",
            "against",
            "between",
            "into",
            "through",
            "during",
            "before",
            "after",
            "above",
            "below",
            "to",
            "from",
            "up",
            "down",
            "in",
            "out",
            "on",
            "off",
            "over",
            "under",
            "again",
            "further",
            "then",
            "once",
            "here",
            "there",
            "when",
            "where",
            "why",
            "how",
            "all",
            "both",
            "each",
            "few",
            "more",
            "most",
            "other",
            "some",
            "such",
            "no",
            "nor",
            "not",
            "only",
            "own",
            "same",
            "so",
            "than",
            "too",
            "very",
            "s",
            "t",
            "can",
            "will",
            "just",
            "don",
            "don't",
            "should",
            "should've",
            "now",
            "d",
            "ll",
            "m",
            "o",
            "re",
            "ve",
            "y",
            "ain",
            "aren",
            "aren't",
            "couldn",
            "couldn't",
            "didn",
            "didn't",
            "doesn",
            "doesn't",
            "hadn",
            "hadn't",
            "hasn",
            "hasn't",
            "haven",
            "haven't",
            "isn",
            "isn't",
            "ma",
            "mightn",
            "mightn't",
            "mustn",
            "mustn't",
            "needn",
            "needn't",
            "shan",
            "shan't",
            "shouldn",
            "shouldn't",
            "wasn",
            "won",
            "wasn't",
            "weren",
            "weren't",
            "won't",
            "wouldn",
            "wouldn't",
        }

        # stopwords_en is already a set, so membership checks are O(1).
        keywords = []
        words = re.findall(r"\b\w+\b", query.lower())
        for word in words:
            if len(word) > 2 and word not in stopwords_en:
                keywords.append(word)

        return keywords
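
    # Worked example (illustrative): for the query
    #   "Analyze the sales data for Q1 2023"
    # the regex yields ["analyze", "the", "sales", "data", "for", "q1",
    # "2023"]; stopwords ("the", "for") and short tokens ("q1") are then
    # dropped, so the method returns ["analyze", "sales", "data", "2023"].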

    def _check_file_match_all(
        self,
        file_path: str,
        keywords: List[str],
        case_sensitive: bool,
    ) -> Tuple[bool, str]:
        """
        Check if a single file content matches all keywords (AND mode)
        """
        try:
            if not os.path.isfile(file_path) or not os.access(file_path, os.R_OK):
                return False, file_path

            grep_opts = ["-q"]
            if not case_sensitive:
                grep_opts.append("-i")

            grep_opts.extend(["-I"])

            for keyword in keywords:
                cmd = ["grep"] + grep_opts + ["--", keyword, file_path]
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=False,
                    timeout=10,
                    check=False,
                )
                if result.returncode != 0:
                    return False, file_path
            return True, file_path
        except Exception as e:  # includes subprocess.TimeoutExpired
            print(f"Error searching file {file_path}: {e}")
            return False, file_path
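
    # Worked example (illustrative): with case_sensitive=False, each keyword
    # is tested with a command of the form
    #   grep -q -i -I -- <keyword> <file_path>
    # and the file matches only if every keyword's grep exits with status 0.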

    def _check_filename_match(
        self,
        file_path: str,
        keywords: List[str],
        case_sensitive: bool,
    ) -> bool:
        """
        Check if filename matches any keyword (OR mode)
        """
        filename = os.path.basename(file_path)
        search_name = filename if case_sensitive else filename.lower()

        for keyword in keywords:
            search_keyword = keyword if case_sensitive else keyword.lower()
            if search_keyword in search_name:
                return True
        return False

    def _filter_by_whitelist(self, file_list: List[str]) -> List[str]:
        """
        Filter files by extension whitelist
        """
        if not self.white_list:
            return file_list

        # Normalize entries to a lowercase, leading-dot form and use a set
        # for O(1) membership checks.
        normalized_whitelist = {
            (ext if ext.startswith(".") else "." + ext).lower()
            for ext in self.white_list
        }

        filtered = []
        for f in file_list:
            _, ext = os.path.splitext(f)
            # Compare case-insensitively so e.g. ".CSV" also matches.
            if ext and ext.lower() in normalized_whitelist:
                filtered.append(f)

        return filtered

    def _grep_files_parallel(
        self,
        keywords: List[str],
        file_list: List[str],
        case_sensitive: bool = False,
        match_all: bool = False,
    ) -> List[str]:
        """
        Parallel search files
        """
        file_list = self._filter_by_whitelist(file_list)

        if not keywords or not file_list:
            return []

        # Results are consumed on the main thread via as_completed(),
        # so no lock is needed.
        matched_files = []

        if match_all:
            # AND mode: search file content, all keywords must match
            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                check_func = partial(
                    self._check_file_match_all,
                    keywords=keywords,
                    case_sensitive=case_sensitive,
                )

                futures = [executor.submit(check_func, fp) for fp in file_list]

                desc = "Searching files (matching all keywords in content)"
                for future in tqdm.tqdm(
                    as_completed(futures),
                    total=len(futures),
                    desc=desc,
                    unit="files",
                ):
                    is_match, file_path = future.result()
                    if is_match:
                        matched_files.append(file_path)

        else:
            # OR mode: search file name, any keyword match
            def check_file(file_path):
                if self._check_filename_match(file_path, keywords, case_sensitive):
                    return file_path
                return None

            with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = [executor.submit(check_file, fp) for fp in file_list]

                desc = "Searching files (matching any keyword in filename)"
                for future in tqdm.tqdm(
                    as_completed(futures),
                    total=len(futures),
                    desc=desc,
                    unit="files",
                ):
                    result = future.result()
                    if result:
                        matched_files.append(result)

        return sorted(matched_files)

    def search_files_by_grep(
        self,
        query: str,
        file_list: List[str],
        case_sensitive: bool = False,
        match_all: bool = False,
    ) -> List[str]:
        """
        Extract keywords from user query and search relevant files

        Args:
            query: user input query
            file_list: list of file paths (MUST be full paths)
            case_sensitive: whether to match case sensitive
            match_all: True=AND mode (search content), False=OR mode (search filename)

        Returns:
            list of matched file paths
        """
        keywords = self._extract_keywords(query)

        if not keywords:
            print("No keywords extracted from user query")
            return []

        print(f"Extracted keywords: {keywords}")
        print(f"Search mode: {'AND (content)' if match_all else 'OR (filename)'}")

        return self._grep_files_parallel(
            keywords,
            file_list,
            case_sensitive,
            match_all,
        )
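

# Usage sketch (illustrative; defined but never called): filename-based OR
# matching over a hypothetical file list. OR mode only inspects file names,
# so the files do not need to exist on disk.
def _demo_grep_filter() -> None:
    grep_filter = GrepFilterTool(max_workers=4)
    matched = grep_filter.search_files_by_grep(
        query="quarterly sales figures",
        file_list=[
            "/workspace/data/sales_q1.csv",
            "/workspace/data/hr_policies.pdf",
        ],
        match_all=False,  # OR mode: any keyword may match the filename
    )
    print(matched)  # ["/workspace/data/sales_q1.csv"]
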

class RAGFilterTool:
    """File filtering tool based on RAG"""

    def __init__(
        self,
        file_list: List[str],
        api_key: Optional[str],
    ) -> None:
        self.white_list = file_extensions_white_list_rag
        white_list_tuple = tuple(self.white_list)
        self.file_list = [
            file_path
            for file_path in file_list
            if file_path.endswith(white_list_tuple)
        ]
        self.api_key = api_key
        if not self.api_key:
            raise ValueError("DASHSCOPE_API_KEY is not set")

        # Mapping from doc_id to file path, so retrieval results can be
        # translated back into file names.
        self.doc_id_to_file_path = {}
        self.knowledge = None

    async def build_knowledge_base(
        self,
        force_rebuild: bool = False,
        collection_name: str = "file_collection",
    ):
        """Build vector index"""

        if self.knowledge is not None and not force_rebuild:
            print("Knowledge base already exists")
            return

        print("=" * 60)
        print("Starting to build knowledge base...")
        print("=" * 60)

        documents = []

        # loop through all files, reusing a single reader instance
        reader = CSVReader()
        for csv_file in self.file_list:
            temp_docs = await reader(csv_path=csv_file)

            for doc in temp_docs:
                self.doc_id_to_file_path[doc.metadata.doc_id] = csv_file
            documents.extend(temp_docs)

        if not documents:
            print("No documents to process!")
            return

        print(f"Documents processed, {len(documents)} documents")
        print(documents[0].metadata)

        # create knowledge base
        print("\nCreating vector storage...")
        self.knowledge = SimpleKnowledge(
            embedding_model=DashScopeTextEmbedding(
                api_key=self.api_key,
                model_name="text-embedding-v4",
                dimensions=1024,
            ),
            embedding_store=QdrantStore(
                location=":memory:",
                collection_name=collection_name,
                dimensions=1024,
            ),
        )

        # add documents to knowledge base
        num_docs = len(documents)
        print(f"\nVectorizing and storing {num_docs} documents...")
        await self.knowledge.add_documents(documents)
        print("=" * 60)
        print("✓ Knowledge base built")
        print(f"  number of documents: {num_docs}")
        print(f"  collection name: {collection_name}")
        print("=" * 60)

    async def search(
        self,
        query: str,
        top_k: int = 10,
        score_threshold: float = 0.55,
    ) -> List[str]:
        """Use knowledge base for retrieval"""
        if self.knowledge is None:
            msg = (
                "Knowledge base not initialized, "
                "please call build_knowledge_base()"
            )
            raise ValueError(msg)

        # use knowledge base for retrieval
        docs = await self.knowledge.retrieve(
            query=query,
            limit=top_k,
            score_threshold=score_threshold,
        )

        # format results
        formatted_results = []
        for doc in docs:
            doc_id = doc.metadata.doc_id
            file_path = self.doc_id_to_file_path.get(doc_id)
            if file_path is not None:
                formatted_results.append(file_path)

        return list(dict.fromkeys(formatted_results))
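

# Usage sketch (illustrative; defined but never called): building an
# in-memory index over CSV files and retrieving by query. The path is a
# hypothetical placeholder, and a valid DashScope API key is required.
async def _demo_rag_filter() -> None:
    rag_filter = RAGFilterTool(
        file_list=["/workspace/data/sales_q1.csv"],
        api_key=os.getenv("DASHSCOPE_API_KEY"),
    )
    await rag_filter.build_knowledge_base(collection_name="demo_collection")
    hits = await rag_filter.search("Q1 sales", top_k=3, score_threshold=0.3)
    print(hits)  # matching file paths, deduplicated
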

async def _eval_rag_filter(
    file_list: List[str],
    query: str,
    api_key: Optional[str] = None,
) -> List[str]:
    rag_filter = RAGFilterTool(
        file_list,
        api_key,
    )

    await rag_filter.build_knowledge_base(
        force_rebuild=True,
        collection_name="file_index",
    )

    # search() already deduplicates its results, so no set conversion is
    # needed; returning a list matches the annotated return type.
    relevant_files = await rag_filter.search(
        query,
        top_k=5,
        score_threshold=0.3,
    )

    print(relevant_files)
    return relevant_files


async def files_filter_backup(
    query: str,
    files_list: List[str],
    api_key: Optional[str],
) -> ToolResponse:
    """
    Filter the uploaded files based on the user's query.
    If no files were uploaded, the (empty) list is returned as-is.
    Otherwise, RAG and grep filtering are used to select relevant files.

    Args:
        query (str): The user's query.
        files_list (List[str]): List of uploaded file paths.
        api_key (str): The DashScope API key used for embeddings.

    Example:
        query = "Analyze the sales data for Q1 2023."
        files_list = [
            "/workspace/data/sales_january.csv",
            "/workspace/data/sales_february.csv",
            "/workspace/data/sales_march.csv",
            "/workspace/data/marketing_report.pdf",
            "/workspace/data/employee_list.xlsx",
        ]
    """

    try:
        if not files_list:
            selected_files = files_list
            files_json = json.dumps(selected_files, ensure_ascii=False)
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=(
                            "The tool has determined that the number of "
                            "uploaded files is small enough to process all of them. "
                            f"Total files: {len(selected_files)}. "
                            f"Selected files:\n```\n{files_json}\n```"
                        ),
                    ),
                ]
            )
        else:
            print("Starting RAG and Grep filtering...")
            rag_filter_result = await _eval_rag_filter(files_list, query, api_key)
            print("RAG filter result:", rag_filter_result)
            grep_filter_result = GrepFilterTool().search_files_by_grep(
                query,
                files_list,
                match_all=False,
            )

            print("Grep filter result:", grep_filter_result)
            rag_set = set(rag_filter_result)
            grep_set = set(grep_filter_result)
            combined_set = rag_set | grep_set  # union
            if not combined_set:  # if both methods return empty, fall back to all files
                combined_set = set(files_list)
            selected_files = list(combined_set)
            files_json = json.dumps(selected_files, ensure_ascii=False)

            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=(
                            "The tool has filtered the user's uploaded files and "
                            f"selected {len(selected_files)} relevant file(s) for processing. "
                            f"Selected files:\n```\n{files_json}\n```"
                        ),
                    ),
                ]
            )

    except Exception as e:
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=(
                        f"File filtering failed ({e}). "
                        "Please identify the relevant files yourself."
                    ),
                ),
            ]
        )

async def files_filter(
    query: str,
    files_list: List[str],
    api_key: Optional[str],
) -> List[str]:
    """Run RAG and grep filtering over the uploaded files and return the
    selected file paths, falling back to all files if nothing matches."""
    print("Starting RAG and Grep filtering...")
    rag_filter_result = await _eval_rag_filter(files_list, query, api_key)
    print("RAG filter result:", rag_filter_result)
    grep_filter_result = GrepFilterTool().search_files_by_grep(
        query,
        files_list,
        match_all=False,
    )
    print("Grep filter result:", grep_filter_result)
    rag_set = set(rag_filter_result)
    grep_set = set(grep_filter_result)
    combined_set = rag_set | grep_set  # union
    if not combined_set:  # if both methods return empty, fall back to all files
        combined_set = set(files_list)
    selected_files = list(combined_set)
    files_json = json.dumps(selected_files, ensure_ascii=False)
    print(f"Found files relevant to the query:```\n{files_json}\n```")
    return selected_files

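
if __name__ == "__main__":
    # Manual smoke test (illustrative): the query and file paths below are
    # hypothetical placeholders taken from the docstring example above;
    # point them at real files before running.
    asyncio.run(
        files_filter(
            query="Analyze the sales data for Q1 2023.",
            files_list=[
                "/workspace/data/sales_january.csv",
                "/workspace/data/sales_february.csv",
            ],
            api_key=os.getenv("DASHSCOPE_API_KEY"),
        ),
    )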