"""
Jupyter Notebook Executor for Pipeline Replay.

This module provides validation and execution capabilities for Jupyter notebooks
generated by NotebookExporter, using Papermill for parameter injection and
execution control.
"""

import ast
import logging
import time
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import anndata
import nbformat

try:
    import papermill

    PAPERMILL_AVAILABLE = True
except ImportError:
    papermill = None
    PAPERMILL_AVAILABLE = False

from lobster.core.data_manager_v2 import DataManagerV2

logger = logging.getLogger(__name__)


@dataclass
class ValidationResult:
    """
    Outcome of checking an input dataset against a notebook's expectations.

    Attributes:
        errors: Blocking problems; any entry makes the result invalid.
        warnings: Non-blocking observations reported alongside the result.
    """

    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    @property
    def has_errors(self) -> bool:
        """True when at least one blocking error was recorded."""
        return bool(self.errors)

    @property
    def has_warnings(self) -> bool:
        """True when at least one warning was recorded."""
        return bool(self.warnings)

    @property
    def is_valid(self) -> bool:
        """True when no blocking errors were recorded (warnings are allowed)."""
        return not self.errors

    def add_error(self, error: str) -> None:
        """Record a blocking error message."""
        self.errors.append(error)

    def add_warning(self, warning: str) -> None:
        """Record a non-blocking warning message."""
        self.warnings.append(warning)

    def __str__(self) -> str:
        """Summarize errors then warnings, or report a clean pass."""
        # Each non-empty category becomes "<Label> (<count>): a, b, ...".
        sections = [
            f"{label} ({len(messages)}): {', '.join(messages)}"
            for label, messages in (("Errors", self.errors), ("Warnings", self.warnings))
            if messages
        ]
        return "; ".join(sections) or "Validation passed"


class NotebookExecutor:
    """
    Execute Jupyter notebooks with validation and error handling.

    This class provides validation, dry-run simulation, and execution
    capabilities for Jupyter notebooks using Papermill.

    Attributes:
        data_manager: DataManagerV2 instance for data access
    """

    # Rough per-code-cell runtime used by ``dry_run`` for its estimate.
    ESTIMATED_MINUTES_PER_CELL = 2

    def __init__(self, data_manager: DataManagerV2) -> None:
        """
        Initialize notebook executor.

        Args:
            data_manager: DataManagerV2 for data access. It is stored on the
                instance; no method of this class reads it directly.
        """
        self.data_manager = data_manager
        logger.debug("Initialized NotebookExecutor")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _read_notebook(notebook_path: Union[str, Path]) -> Any:
        """Read a notebook file as an nbformat v4 ``NotebookNode``."""
        # .ipynb files are UTF-8 JSON; the platform default encoding (e.g.
        # cp1252 on Windows) would fail on non-ASCII content.
        with open(notebook_path, encoding="utf-8") as f:
            return nbformat.read(f, as_version=4)

    @staticmethod
    def _default_output_path(notebook_path: Path) -> Path:
        """Return ``<stem>_output.ipynb`` in the same directory as the notebook."""
        return notebook_path.parent / f"{notebook_path.stem}_output.ipynb"

    @staticmethod
    def _cell_duration(cell: Any) -> float:
        """Return a code cell's recorded execution time in seconds (0.0 if unknown)."""
        # Papermill stores each cell's wall time in metadata.papermill.duration.
        duration = cell.metadata.get("papermill", {}).get("duration")
        if isinstance(duration, (int, float)) and not isinstance(duration, bool):
            return float(duration)

        # Fallback: nbclient timing timestamps (ISO 8601, possibly "Z"-suffixed,
        # which datetime.fromisoformat only accepts from Python 3.11).
        timing = cell.metadata.get("execution", {})
        start = timing.get("iopub.execute_input")
        end = timing.get("shell.execute_reply")
        if not (start and end):
            return 0.0
        try:
            started = datetime.fromisoformat(start.replace("Z", "+00:00"))
            finished = datetime.fromisoformat(end.replace("Z", "+00:00"))
            return max((finished - started).total_seconds(), 0.0)
        except (AttributeError, TypeError, ValueError):
            # Malformed or mixed naive/aware timestamps: treat as unrecorded.
            return 0.0

    @staticmethod
    def _parse_parameter_lines(source: str) -> Dict[str, str]:
        """Naive line scan: map ``name = value`` lines to their stripped value text."""
        parameters: Dict[str, str] = {}
        for raw_line in source.split("\n"):
            line = raw_line.strip()
            if line and not line.startswith("#") and "=" in line:
                name, _, value = line.partition("=")
                parameters[name.strip()] = value.strip()
        return parameters

    @classmethod
    def _parse_parameter_source(cls, source: str) -> Dict[str, str]:
        """Map each top-level simple assignment in *source* to its value's source text."""
        try:
            tree = ast.parse(source)
        except SyntaxError:
            # Cells containing IPython magics etc. are not valid Python.
            return cls._parse_parameter_lines(source)

        parameters: Dict[str, str] = {}
        for node in tree.body:
            if isinstance(node, ast.Assign):
                targets, value = node.targets, node.value
            elif isinstance(node, ast.AnnAssign) and node.value is not None:
                targets, value = [node.target], node.value
            else:
                continue
            # Source text of the whole value expression: excludes trailing
            # comments and includes values spanning several lines.
            value_text = ast.get_source_segment(source, value)
            for target in targets:
                if isinstance(target, ast.Name):
                    parameters[target.id] = value_text
        return parameters

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def validate_input(
        self, notebook_path: Path, input_data: Union[str, Path]
    ) -> ValidationResult:
        """
        Validate input data against the notebook's ``lobster`` metadata.

        Checks performed, using ``notebook.metadata["lobster"]``:
        - ``min_cells`` (default 100): fewer observations is a warning.
        - ``required_obs_columns`` (default none): any column absent from
          ``adata.obs`` is an error.
        - A sparse ``adata.X`` produces a warning.

        Args:
            notebook_path: Path to .ipynb file
            input_data: Path to input H5AD file

        Returns:
            ValidationResult with errors and warnings. An unreadable notebook
            or input file is reported as an error rather than raised.
        """
        result = ValidationResult()

        try:
            nb = self._read_notebook(notebook_path)
        except Exception as e:
            result.add_error(f"Cannot read notebook: {e}")
            return result

        metadata = nb.metadata.get("lobster", {})

        try:
            adata = anndata.read_h5ad(input_data)
        except Exception as e:
            result.add_error(f"Cannot read input data: {e}")
            return result

        expected_min_cells = metadata.get("min_cells", 100)
        if adata.n_obs < expected_min_cells:
            result.add_warning(
                f"Input has {adata.n_obs} cells, notebook expects ≥{expected_min_cells}"
            )

        required_obs = metadata.get("required_obs_columns", [])
        # A single column name given as a bare string would otherwise be
        # split into its characters by set().
        if isinstance(required_obs, str):
            required_obs = [required_obs]
        # Sorted so the error message is deterministic across runs.
        missing = sorted(set(required_obs) - set(adata.obs.columns), key=str)
        if missing:
            result.add_error(f"Missing required columns: {missing}")

        # Sparse matrices expose .toarray(); dense numpy arrays do not.
        if hasattr(adata.X, "toarray"):
            result.add_warning(
                "Input data is sparse - notebook may expect dense arrays"
            )

        logger.debug(f"Validation result: {result}")
        return result

    def dry_run(
        self, notebook_path: Path, input_data: Union[str, Path]
    ) -> Dict[str, Any]:
        """
        Simulate execution without running.

        Args:
            notebook_path: Path to .ipynb file (``str`` is also accepted)
            input_data: Path to input H5AD file

        Returns:
            Dictionary with dry-run results:
                - status: always "dry_run_complete"
                - validation: ValidationResult from ``validate_input``
                - steps_to_execute: number of code cells (0 if unreadable)
                - estimated_duration_minutes: steps_to_execute *
                  ``ESTIMATED_MINUTES_PER_CELL``
                - output_notebook: default output path (``<stem>_output.ipynb``)
                - validation_passed: ``validation.is_valid``
        """
        notebook_path = Path(notebook_path)
        validation = self.validate_input(notebook_path, input_data)

        try:
            nb = self._read_notebook(notebook_path)
        except Exception as e:
            # Report an unreadable notebook through the validation result
            # instead of raising (validate_input normally records it already).
            message = f"Cannot read notebook: {e}"
            if message not in validation.errors:
                validation.add_error(message)
            code_cells = []
        else:
            code_cells = [c for c in nb.cells if c.cell_type == "code"]

        return {
            "status": "dry_run_complete",
            "validation": validation,
            "steps_to_execute": len(code_cells),
            "estimated_duration_minutes": len(code_cells)
            * self.ESTIMATED_MINUTES_PER_CELL,
            "output_notebook": str(self._default_output_path(notebook_path)),
            "validation_passed": validation.is_valid,
        }

    def execute(
        self,
        notebook_path: Path,
        input_data: Union[str, Path],
        parameters: Optional[Dict[str, Any]] = None,
        output_path: Optional[Path] = None,
    ) -> Dict[str, Any]:
        """
        Execute notebook with Papermill.

        Args:
            notebook_path: Path to .ipynb file (``str`` is also accepted)
            input_data: Path to input H5AD file; injected as the
                ``input_data`` parameter (overriding any caller value)
            parameters: Optional parameter overrides. The dict is copied and
                never modified.
            output_path: Where to save output notebook (defaults to
                ``<stem>_output.ipynb`` next to the input notebook)

        Returns:
            Dictionary with execution results, by status:
                - "validation_failed": errors (list), warnings (list)
                - "success": output_notebook, execution_time (seconds),
                  parameters_used
                - "failed": error (str); when execution itself raised, also
                  error_type, failed_cell (Papermill ``exec_count`` or
                  "unknown") and output_notebook (may hold partial results)
        """
        notebook_path = Path(notebook_path)

        validation = self.validate_input(notebook_path, input_data)
        if validation.has_errors:
            return {
                "status": "validation_failed",
                "errors": validation.errors,
                "warnings": validation.warnings,
            }

        # Copy so injecting input_data never mutates the caller's dict.
        params = dict(parameters) if parameters else {}
        params["input_data"] = str(input_data)

        if output_path is None:
            output_path = self._default_output_path(notebook_path)

        if not PAPERMILL_AVAILABLE:
            return {
                "status": "failed",
                "error": "Papermill not installed. Install with: pip install papermill",
            }

        logger.info(f"Executing notebook: {notebook_path}")
        logger.debug(f"Parameters: {params}")

        # Monotonic clock: immune to wall-clock adjustments during long runs.
        start_time = time.perf_counter()
        try:
            papermill.execute_notebook(
                input_path=str(notebook_path),
                output_path=str(output_path),
                parameters=params,
                kernel_name="python3",
            )
        except Exception as e:
            error_type = type(e).__name__
            logger.error(f"Notebook execution failed ({error_type}): {e}")
            return {
                "status": "failed",
                "error": str(e),
                "error_type": error_type,
                # PapermillExecutionError carries exec_count; other errors don't.
                "failed_cell": getattr(e, "exec_count", "unknown"),
                "output_notebook": str(output_path),  # Partial results may be saved
            }

        execution_time = time.perf_counter() - start_time
        logger.info(f"Notebook executed successfully in {execution_time:.1f}s")

        return {
            "status": "success",
            "output_notebook": str(output_path),
            "execution_time": execution_time,
            "parameters_used": params,
        }

    def get_execution_summary(self, output_notebook_path: Path) -> Dict[str, Any]:
        """
        Extract execution summary from output notebook.

        Args:
            output_notebook_path: Path to executed notebook

        Returns:
            Dictionary with execution summary:
                - cells_executed: code cells with a non-null execution_count
                - execution_time: summed per-cell seconds, taken from Papermill's
                  ``metadata.papermill.duration`` or, failing that, nbclient's
                  ``metadata.execution`` timestamps (0.0 where neither exists)
                - has_errors: whether any cell has an "error" output
                - outputs: text of the first 10 stream outputs
            On failure, ``error`` is set, ``cells_executed`` is 0 and
            ``has_errors`` is True.
        """
        try:
            nb = self._read_notebook(output_notebook_path)

            executed_cells = 0
            total_time = 0.0
            has_errors = False
            outputs: List[str] = []

            for cell in nb.cells:
                if cell.cell_type != "code":
                    continue

                # Cells Papermill never reached (e.g. after a failure, or empty
                # cells) keep execution_count == None.
                if cell.get("execution_count") is not None:
                    executed_cells += 1

                total_time += self._cell_duration(cell)

                for output in cell.get("outputs", []):
                    output_type = output.get("output_type")
                    if output_type == "error":
                        has_errors = True
                    elif output_type == "stream":
                        outputs.append(output.get("text", ""))

            return {
                "cells_executed": executed_cells,
                "execution_time": total_time,
                "has_errors": has_errors,
                "outputs": outputs[:10],  # Limit to first 10 outputs
            }

        except Exception as e:
            logger.error(f"Failed to extract execution summary: {e}")
            return {
                "error": str(e),
                "cells_executed": 0,
                "execution_time": 0.0,
                "has_errors": True,
                "outputs": [],
            }

    def list_parameters(self, notebook_path: Path) -> Dict[str, Any]:
        """
        Extract the defaults declared in the notebook's ``parameters`` cell.

        Only the first code cell tagged ``parameters`` is inspected. Its
        top-level assignments (``x = 1``, ``x: int = 1``, ``a = b = 1``) are
        parsed with :mod:`ast`. If the cell is not valid Python, a line-based
        ``name = value`` scan is used instead.

        Args:
            notebook_path: Path to .ipynb file

        Returns:
            Mapping of parameter name to the *source text* of its default
            value (e.g. ``{"n": "5", "species": "'human'"}``). The mapping is
            empty if there is no parameters cell or the notebook is unreadable.
        """
        try:
            nb = self._read_notebook(notebook_path)

            for cell in nb.cells:
                if cell.cell_type == "code" and "parameters" in cell.metadata.get(
                    "tags", []
                ):
                    return self._parse_parameter_source(cell.source)

            return {}

        except Exception as e:
            logger.error(f"Failed to extract parameters: {e}")
            return {}

    def validate_papermill_availability(self) -> bool:
        """
        Check if Papermill is available for execution.

        Returns:
            True if Papermill was importable when this module loaded, False
            otherwise (an error is logged in that case).
        """
        if not PAPERMILL_AVAILABLE:
            logger.error("Papermill not available. Install with: pip install papermill")
        return PAPERMILL_AVAILABLE
