import matplotlib
matplotlib.use('Agg')  # Must be set before importing pyplot
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import io
from typing import Set, List, Dict
import hashlib
import functools
from datetime import datetime
import numpy as np
import pandas as pd


class ChartSummarizer:
    """
    Automatic analysis utility for Matplotlib charts,
    supporting detailed information extraction for multiple chart types.

    The generated report contains the axes metadata (title, axis labels,
    axis limits) followed by the data of every line plot, scatter plot,
    bar chart or histogram found on the axes. Long series are thinned by
    uniform sampling so the report stays readable.
    """

    # strftime pattern applied to every datetime-like value in the report.
    _DATETIME_FORMAT = '%Y-%m-%d %H:%M'

    def __init__(self, max_points=20):
        """
        Initialize the chart summarizer.

        Parameters:
            max_points (int): Maximum number of points to display per data series
            (uniformly sampled, including first and last).
        """
        self.max_points = max_points

    def summarize(self, ax=None, max_points=None):
        """
        Automatically analyze detailed information of the current Matplotlib chart.

        Parameters:
            ax: Matplotlib axes object. If None, uses plt.gca().
            max_points (int): Override the max_points setting from initialization.

        Returns:
            str: Formatted string containing detailed chart information.
        """
        if ax is None:
            ax = plt.gca()
        if max_points is None:
            max_points = self.max_points

        # Kept as function-level imports, as in the original implementation.
        from matplotlib.collections import PathCollection
        from matplotlib.patches import Rectangle

        x_min, x_max = ax.get_xlim()
        y_min, y_max = ax.get_ylim()

        # Header and basic metadata.
        lines = [
            "\nChart Information Summary",
            "-" * 70,
            f"Title: {ax.get_title() or 'No Title'}",
            f"X-axis Label: {ax.get_xlabel() or 'X-axis Label Not Set'}",
            f"Y-axis Label: {ax.get_ylabel() or 'Y-axis Label Not Set'}",
            f"X Range: [{x_min:.4g}, {x_max:.4g}]",
            f"Y Range: [{y_min:.4g}, {y_max:.4g}]",
        ]

        has_content = False

        # 1. Line plots and marker-only lines (ax.lines).
        line_objs = ax.get_lines()
        if line_objs:
            has_content = True
            lines.append(f"\nLine Plots/Marker Lines ({len(line_objs)} lines):")
            for i, line in enumerate(line_objs, start=1):
                label = self._artist_label(line, f"Line {i}")
                x = np.array(line.get_xdata())
                y = np.array(line.get_ydata())
                lines.extend(self._get_sampled_data_lines(label, x, y, max_points))

        # 2. Scatter plots (PathCollection objects in ax.collections).
        scatter_collections = [c for c in ax.collections if isinstance(c, PathCollection)]
        if scatter_collections:
            has_content = True
            lines.append(f"\nScatter Plots ({len(scatter_collections)} groups):")
            for i, coll in enumerate(scatter_collections, start=1):
                offsets = np.array(coll.get_offsets())
                if offsets.size == 0:
                    continue
                # Use the user-supplied label when there is one, like lines do.
                label = self._artist_label(coll, f"Scatter {i}")
                lines.extend(
                    self._get_sampled_data_lines(label, offsets[:, 0], offsets[:, 1], max_points)
                )

        # 3. Bar charts and histograms (Rectangle objects in ax.patches).
        rectangles = [p for p in ax.patches if isinstance(p, Rectangle)]
        if rectangles:
            has_content = True
            if self._is_histogram_like(rectangles, ax):
                lines.append(f"\nHistogram ({len(rectangles)} bins):")
                lines.extend(self._get_histogram_info_lines(rectangles, max_points))
            else:
                lines.append(f"\nBar Chart ({len(rectangles)} bars):")
                # Report each bar as (horizontal centre, height), left to right.
                bars = sorted(
                    (
                        (rect.get_x() + rect.get_width() / 2, rect.get_height())
                        for rect in rectangles
                    ),
                    key=lambda bar: bar[0],
                )
                x_vals = np.array([center for center, _ in bars])
                y_vals = np.array([height for _, height in bars])
                lines.extend(self._get_sampled_data_lines("Bar Chart", x_vals, y_vals, max_points))

        # 4. Fallback message when nothing recognizable was found.
        if not has_content:
            lines.append("\nNo standard chart type recognized (may be empty or using special plotting methods)")
            lines.append(f"Note: Detected {len(ax.patches)} patches, {len(ax.collections)} collections")

        lines.append("-" * 70 + "\n")

        return "\n".join(lines)

    @staticmethod
    def _artist_label(artist, fallback):
        """
        Return the user-facing label of an artist, or a fallback name.

        Labels that are missing, empty, or start with an underscore are
        treated as "not set": Matplotlib assigns underscore-prefixed labels
        (e.g. '_child0', '_nolegend_') to artists the user did not label.

        Parameters:
            artist: Object exposing get_label().
            fallback (str): Name to use when no meaningful label is set.

        Returns:
            str: The artist label or the fallback.
        """
        label = artist.get_label()
        if label is None:
            return fallback
        label = str(label)
        if not label or label.startswith('_'):
            return fallback
        return label

    @classmethod
    def _format_value(cls, val):
        """
        Format a single data value for the report.

        Datetime-like values are rendered with _DATETIME_FORMAT, numbers with
        four significant digits, and anything that supports neither (e.g.
        category strings, NaT) falls back to str().

        Returns:
            str: Formatted value.
        """
        try:
            if isinstance(val, np.datetime64):
                val = pd.Timestamp(val)
            if hasattr(val, 'strftime') and not isinstance(val, (int, float)):
                return val.strftime(cls._DATETIME_FORMAT)
            return f"{val:.4g}"
        except (TypeError, ValueError):
            return str(val)

    def _get_sampled_data_lines(self, label, x, y, max_points):
        """
        Return a list of strings for uniformly sampled (x, y) data points.

        Parameters:
            label (str): Name of the data series, used in the header line.
            x: Sequence of X values (numeric, datetime-like or categorical).
            y: Sequence of Y values, same length as x.
            max_points (int): Maximum number of points to list.

        Returns:
            list: List of formatted data point strings.
        """
        x = np.asarray(x)
        y = np.asarray(y)
        n = len(x)
        if n == 0:
            return [f"  {label}: No data"]

        # Uniform sampling (including first and last)
        if n <= max_points:
            indices = np.arange(n)
            sampling_note = ""
        else:
            indices = np.linspace(0, n - 1, max_points, dtype=int)
            sampling_note = f" [Uniformly sampled {max_points}/{n} points, including first and last]"

        result = [f"  {label}{sampling_note}:"]
        # X and Y share one formatter so non-numeric Y data cannot crash the report.
        result.extend(
            f"    ({self._format_value(x_val)}, {self._format_value(y_val)})"
            for x_val, y_val in zip(x[indices], y[indices])
        )
        return result

    def _get_histogram_info_lines(self, rectangles, max_points):
        """
        Return a list of strings describing histogram information.

        Parameters:
            rectangles: Bar patches exposing get_x(), get_width(), get_height().
            max_points (int): Maximum number of bins to list.

        Returns:
            list: List of histogram info strings (empty if there are no bins).
        """
        if len(rectangles) == 0:
            return []

        heights = [rect.get_height() for rect in rectangles]
        widths = [rect.get_width() for rect in rectangles]
        lefts = [rect.get_x() for rect in rectangles]

        # Sort by left edge
        order = np.argsort(lefts)
        lefts = np.array(lefts)[order]
        counts = np.array(heights)[order]
        widths = np.array(widths)[order]

        # Bin edges: every left edge plus the right edge of the last bin.
        bin_edges = np.concatenate([lefts, [lefts[-1] + widths[-1]]])

        total = counts.sum()
        result = [
            f"  Total: {total:.0f} data points",
            "  Bins (left-closed, right-open) -> Count:",
        ]

        n_bins = len(counts)

        # Uniformly sample bins (including first and last)
        if n_bins <= max_points:
            show_indices = np.arange(n_bins)
        else:
            show_indices = np.linspace(0, n_bins - 1, max_points, dtype=int)
            sampling_note = f" [Uniformly sampled {max_points}/{n_bins} bins, including first and last]"
            result.append(f"  {sampling_note}")

        for i in show_indices:
            left, right = bin_edges[i], bin_edges[i + 1]
            count = counts[i]
            pct = count / total * 100 if total > 0 else 0
            result.append(f"    [{left:.4g}, {right:.4g}) -> {count:.0f} ({pct:.1f}%)")

        return result

    def _is_histogram_like(self, rectangles, ax):
        """
        Determine if the rectangles represent a histogram:
        1. Bar widths are nearly identical (relative spread of at most 1%).
        2. Bars are tightly packed: the largest gap or overlap between
           neighbours is below 5% of the bar width. plt.bar's default
           width of 0.8 leaves a gap of about 0.2, which fails this test.

        Parameters:
            rectangles: Bar patches exposing get_x() and get_width().
            ax: Unused; kept for interface compatibility.

        Returns:
            bool: True if the rectangles look like a histogram.
        """
        if len(rectangles) < 2:
            return False

        widths = [r.get_width() for r in rectangles]

        # Check width consistency
        mean_width = np.mean(widths)
        if mean_width == 0:
            return False
        if np.std(widths) / mean_width > 0.01:
            return False

        # Check continuity after sorting by left edge
        sorted_lefts = np.sort([r.get_x() for r in rectangles])
        width = widths[0]

        # Gap between each bar's right edge and the next bar's left edge.
        gaps = sorted_lefts[1:] - (sorted_lefts[:-1] + width)
        max_gap = np.abs(gaps).max()

        # Consider it a histogram if max gap is less than 5% of bar width
        return bool(max_gap < width * 0.05)


class MatplotlibMonitor:
    """Monitors Matplotlib chart creation and automatically generates summaries.

    While monitoring is active, ``plt.show`` is replaced by a wrapper that
    summarizes every axes of every open figure, stores the results in
    ``self.summaries``, and then delegates to the original ``plt.show``.
    Figures are deduplicated by a hash of their rendered PNG image.
    """

    def __init__(self):
        """Initialize the monitor in the un-hooked state with empty storage."""
        self.chart_summarizer = ChartSummarizer()

        # Called with no arguments; it analyzes whichever axes is current.
        self.summarize_func = self.chart_summarizer.summarize
        self.original_show = None
        self.processed_figures: Set[str] = set()  # Store hashes of processed figures
        self.is_hooked = False
        self.summaries: List[Dict] = []  # Store all summary information

    def start(self):
        """Start monitoring by hooking plt.show(). Does nothing if already hooked."""
        if self.is_hooked:
            return

        self.original_show = plt.show

        # Create a wrapped function and preserve the original function's signature and attributes
        @functools.wraps(self.original_show)
        def wrapped_show(*args, **kwargs):
            return self._custom_show(*args, **kwargs)

        # Copy __signature__ attribute if it exists
        if hasattr(self.original_show, '__signature__'):
            wrapped_show.__signature__ = self.original_show.__signature__

        plt.show = wrapped_show
        self.is_hooked = True

    def stop(self):
        """Stop monitoring and restore the original plt.show().

        Also clears the processed-figure cache. Stored summaries are kept.
        """
        if not self.is_hooked:
            return

        plt.show = self.original_show
        self.is_hooked = False
        self.processed_figures.clear()

    def _get_figure_hash(self, fig: Figure) -> str:
        """
        Compute a hash for a figure to enable deduplication.

        The figure is rendered to an in-memory PNG and the bytes are hashed.

        Args:
            fig: Matplotlib Figure object.

        Returns:
            MD5 hash of the figure.
        """
        with io.BytesIO() as buf:
            fig.savefig(buf, format='png', bbox_inches='tight')
            return hashlib.md5(buf.getvalue()).hexdigest()

    @staticmethod
    def _make_error_entry(fig_num, error, axes_index=None, figure_hash=None, total_axes=None):
        """Build a summary-list entry describing a failure.

        The entry carries the same keys as a successful entry (with
        ``summary`` set to None) plus an ``error`` message.
        """
        return {
            'timestamp': datetime.now().isoformat(),
            'figure_number': fig_num,
            'figure_hash': figure_hash,
            'axes_index': axes_index,
            'total_axes': total_axes,
            'error': str(error),
            'summary': None,
        }

    def _process_figure(self, fig: Figure, fig_num: int):
        """
        Process a single figure.

        Every axes of the figure yields one entry in ``self.summaries``:
        either its summary or, if summarizing that axes failed, an error
        entry. A failure on one axes does not discard the others, and the
        figure is still marked as processed so that a later ``plt.show()``
        does not append the same entries again.

        Args:
            fig: Matplotlib Figure object.
            fig_num: Figure number.
        """
        try:
            # Compute hash to check if already processed
            fig_hash = self._get_figure_hash(fig)
            if fig_hash in self.processed_figures:
                return

            # Set this figure as active so summarize_func can access it
            plt.figure(fig.number)
            axes = fig.get_axes()
        except Exception as e:
            # Figure-level failure: no axes index is available.
            self.summaries.append(self._make_error_entry(fig_num, e))
            return

        if not axes:
            return

        total_axes = len(axes)
        for idx, ax in enumerate(axes):
            try:
                plt.sca(ax)  # Set current axes
                # Call summary function (no arguments)
                summary = self.summarize_func()
            except Exception as e:
                self.summaries.append(
                    self._make_error_entry(
                        fig_num,
                        e,
                        axes_index=idx,
                        figure_hash=fig_hash,
                        total_axes=total_axes,
                    )
                )
                continue

            self.summaries.append({
                'timestamp': datetime.now().isoformat(),
                'figure_number': fig_num,
                'figure_hash': fig_hash,
                'axes_index': idx,
                'total_axes': total_axes,
                'summary': summary,
            })

        # Mark as processed
        self.processed_figures.add(fig_hash)

    def _get_all_figures(self):
        """
        Get all currently open figure objects.

        Returns:
            List of (fig_num, Figure) tuples; empty if the figure managers
            cannot be queried.
        """
        figures = []

        # Get all figure managers via _pylab_helpers
        try:
            from matplotlib._pylab_helpers import Gcf
            for manager in Gcf.get_all_fig_managers():
                figures.append((manager.num, manager.canvas.figure))
        except Exception:
            # Best effort: this relies on a private Matplotlib module, and
            # monitoring must never prevent plt.show() from running.
            pass

        return figures

    def _custom_show(self, *args, **kwargs):
        """
        Custom show function that processes all figures before calling the original show.

        All arguments are forwarded unchanged to the original ``plt.show``,
        whose return value is returned.
        """
        # Get all currently open figures
        figures = self._get_all_figures()

        if figures:
            # Save current figure and axes
            current_fig = plt.gcf()
            current_ax = plt.gca() if current_fig.get_axes() else None

            # Process each figure
            for fig_num, fig in figures:
                self._process_figure(fig, fig_num)

            # Restore previous figure and axes (if possible)
            try:
                if any(fig is current_fig for _, fig in figures):
                    plt.figure(current_fig.number)
                    if current_ax is not None and current_ax in current_fig.get_axes():
                        plt.sca(current_ax)
            except Exception:
                # Restoring pyplot state is best effort only; unlike a bare
                # except, this lets KeyboardInterrupt/SystemExit propagate.
                pass

        # Call the original show function
        return self.original_show(*args, **kwargs)

    def get_all_summaries(self) -> List[Dict]:
        """
        Retrieve all stored summary information.

        Returns:
            A shallow copy of the list of summary dictionaries, each containing:
            - timestamp: ISO format timestamp
            - figure_number: Figure number
            - figure_hash: Figure hash (None if hashing failed)
            - axes_index: Subplot index (None for figure-level errors)
            - total_axes: Total number of subplots (None for figure-level errors)
            - summary: Summary content (None for error entries)
            - error: Error message (present only on error entries)
        """
        return self.summaries.copy()

    def clear_all_summaries(self):
        """Clear all stored summary information."""
        self.summaries.clear()

    def clear_cache(self):
        """Clear the cache of processed figures."""
        self.processed_figures.clear()

# Module-level monitor instance, created when this module is imported.
monitor = MatplotlibMonitor()

# Start monitoring: from this point on, plt.show() is replaced by the
# monitor's wrapper until monitor.stop() is called.
monitor.start()