"""MCP tools for compliance and knowledge research."""

import base64
import logging
from datetime import datetime
from typing import Any

import httpx

from wistx_mcp.tools.lib.api_client import WISTXAPIClient, get_api_client
from wistx_mcp.tools.lib.web_search_client import WebSearchClient
from wistx_mcp.tools.lib.mongodb_client import MongoDBClient
from wistx_mcp.tools import visualize_infra_flow
from wistx_mcp.tools import generate_documentation
from wistx_mcp.config import settings
from wistx_mcp.tools import pricing

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Shared API client instance, constructed once at import time and used by
# the tool coroutines in this module that call the WISTX API.
api_client = WISTXAPIClient()


async def get_compliance_requirements(
    resource_types: list[str],
    standards: list[str] | None = None,
    severity: str | None = None,
    include_remediation: bool = True,
    include_verification: bool = True,
    api_key: str = "",
    generate_report: bool = True,
) -> dict[str, Any]:
    """Get compliance requirements for infrastructure resources.

    Args:
        resource_types: List of resource types (RDS, S3, EC2, etc.)
        standards: List of compliance standards (PCI-DSS, HIPAA, etc.)
        severity: Filter by severity level
        include_remediation: Include remediation guidance
        include_verification: Include verification procedures
        api_key: WISTX API key (required for authentication)
        generate_report: Whether to automatically generate and store a compliance report

    Returns:
        Dictionary with compliance controls and summary.
        If generate_report=True and api_key provided, also includes:
        - report_id: Generated report ID
        - report_download_url: URL to download report
        - report_view_url: URL to view report

    Raises:
        ValueError: If input validation fails
        RuntimeError: If API call fails
        ConnectionError: If network connection fails
        TimeoutError: If request times out
    """
    if not resource_types:
        raise ValueError("At least one resource type is required")

    if not isinstance(resource_types, list):
        raise ValueError("resource_types must be a list")

    if len(resource_types) > 50:
        raise ValueError("Maximum 50 resource types allowed")

    sanitized_resource_types = []
    for rt in resource_types:
        if not isinstance(rt, str):
            raise ValueError(f"Resource type must be string, got {type(rt)}")
        rt_clean = rt.strip().upper()[:50]
        if not rt_clean or len(rt_clean) < 2:
            raise ValueError(f"Invalid resource type: {rt}")
        sanitized_resource_types.append(rt_clean)

    resource_types = sanitized_resource_types

    if standards is not None:
        if not isinstance(standards, list):
            raise ValueError("standards must be a list")
        if len(standards) > 20:
            raise ValueError("Maximum 20 standards allowed")

        sanitized_standards = []
        for std in standards:
            if not isinstance(std, str):
                raise ValueError(f"Standard must be string, got {type(std)}")
            std_clean = std.strip().upper()[:50]
            if not std_clean or len(std_clean) < 2:
                raise ValueError(f"Invalid standard: {std}")
            sanitized_standards.append(std_clean)

        standards = sanitized_standards

    if severity is not None:
        valid_severities = ["CRITICAL", "HIGH", "MEDIUM", "LOW"]
        if severity not in valid_severities:
            raise ValueError(f"severity must be one of {valid_severities}")
        severity = severity.upper()

    from wistx_mcp.tools.lib.auth_context import validate_api_key_and_get_user_id

    try:
        user_id = await validate_api_key_and_get_user_id(api_key)
    except (ValueError, RuntimeError) as e:
        raise

    try:
        result = await api_client.get_compliance_requirements(
            resource_types=resource_types,
            standards=standards or [],
            severity=severity,
            include_remediation=include_remediation,
            include_verification=include_verification,
        )

        if not isinstance(result, dict):
            logger.error("Invalid response type from API: %s", type(result))
            raise RuntimeError("Invalid response format from API")

        if "data" not in result and "controls" not in result:
            logger.error("Response missing required fields: %s", list(result.keys()))
            raise RuntimeError("Invalid response structure: missing 'data' or 'controls'")

        if "controls" in result:
            if not isinstance(result["controls"], list):
                logger.error("Controls field is not a list: %s", type(result["controls"]))
                raise RuntimeError("Invalid controls structure: expected list")

            for i, control in enumerate(result["controls"]):
                if not isinstance(control, dict):
                    logger.error("Control %d is not a dict: %s", i, type(control))
                    raise RuntimeError(f"Invalid control structure at index {i}")

                required_fields = ["control_id", "standard"]
                missing = [f for f in required_fields if f not in control]
                if missing:
                    logger.warning("Control %d missing fields: %s", i, missing)

        if "data" in result and isinstance(result["data"], dict):
            if "controls" in result["data"]:
                if not isinstance(result["data"]["controls"], list):
                    logger.error("Data.controls field is not a list: %s", type(result["data"]["controls"]))
                    raise RuntimeError("Invalid data.controls structure: expected list")

        if generate_report and api_key:
            try:
                if user_id:
                    controls = result.get("controls") or result.get("data", {}).get("controls", [])
                    if controls:
                        subject = f"Compliance Report: {', '.join(resource_types)}"
                        if standards:
                            subject += f" ({', '.join(standards)})"

                        logger.info("Generating compliance report for user %s", user_id)

                        report_result = await generate_documentation.generate_documentation(
                            document_type="compliance_report",
                            subject=subject,
                            resource_types=resource_types,
                            compliance_standards=standards or [],
                            format="markdown",
                            include_compliance=True,
                            include_security=True,
                            include_cost=False,
                            include_best_practices=True,
                        )

                        report_id = f"report-{datetime.now().strftime('%Y%m%d%H%M%S')}-{user_id[:8]}"

                        content = report_result.get("content", "")
                        output_format = report_result.get("format", "markdown")

                        content_type_map = {
                            "markdown": "text/markdown",
                            "html": "text/html",
                            "pdf": "application/pdf",
                            "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
                            "json": "application/json",
                        }

                        content_type = content_type_map.get(output_format, "text/plain")

                        if isinstance(content, bytes):
                            content_b64 = base64.b64encode(content).decode("utf-8")
                        else:
                            content_b64 = base64.b64encode(content.encode("utf-8")).decode("utf-8")

                        try:
                            from wistx_mcp.tools.lib.mongodb_client import execute_mongodb_operation
                            from wistx_mcp.tools.lib.constants import API_TIMEOUT_SECONDS

                            async with MongoDBClient() as mongodb_client:
                                db = await mongodb_client.get_database()
                                reports_collection = db.reports

                                async def _insert_report() -> None:
                                    await reports_collection.insert_one({
                                        "report_id": report_id,
                                        "user_id": user_id,
                                        "document_type": "compliance_report",
                                        "subject": subject,
                                        "format": output_format,
                                        "content": content_b64,
                                        "content_type": content_type,
                                        "sections": report_result.get("sections", []),
                                        "metadata": report_result.get("metadata", {}),
                                        "created_at": datetime.utcnow(),
                                    })

                                await execute_mongodb_operation(
                                    _insert_report,
                                    timeout=API_TIMEOUT_SECONDS,
                                    max_retries=3,
                                )

                                base_url = settings.api_url.rstrip("/") if hasattr(settings, "api_url") else ""
                                download_url = f"{base_url}/v1/reports/{report_id}/download?format={output_format}" if base_url else ""
                                view_url = f"{base_url}/v1/reports/{report_id}/view?format={output_format}" if base_url else ""

                                result["report_id"] = report_id
                                result["report_download_url"] = download_url
                                result["report_view_url"] = view_url

                                logger.info("Compliance report generated and stored: %s", report_id)
                        except Exception as e:
                            logger.warning("Error storing compliance report: %s", e)
                    else:
                        logger.warning("No controls found, skipping report generation")
                else:
                    logger.warning("User ID not found from API key, skipping report generation")
            except Exception as e:
                logger.warning("Failed to generate compliance report: %s", e, exc_info=True)

        return result
    except ValueError as e:
        logger.error("Validation error in get_compliance_requirements: %s", e, exc_info=True)
        raise
    except httpx.HTTPStatusError as e:
        status_code = e.response.status_code if e.response else None
        logger.error("HTTP status error: %s (status: %s)", e, status_code)
        if status_code == 401:
            raise ValueError("Invalid API key") from e
        elif status_code == 429:
            raise RuntimeError("Rate limit exceeded") from e
        elif status_code >= 500:
            raise RuntimeError(f"Server error: {status_code}") from e
        raise RuntimeError(f"HTTP error: {status_code}") from e
    except httpx.TimeoutException as e:
        logger.error("Request timeout: %s", e)
        raise TimeoutError("Request timeout") from e
    except httpx.NetworkError as e:
        logger.error("Network error: %s", e)
        raise ConnectionError("Network connection failed") from e
    except httpx.HTTPError as e:
        logger.error("HTTP error in get_compliance_requirements: %s", e, exc_info=True)
        raise RuntimeError(f"HTTP error: {e}") from e
    except (RuntimeError, ConnectionError, TimeoutError) as e:
        logger.error("Error in get_compliance_requirements: %s", e, exc_info=True)
        raise
    except Exception as e:
        logger.error("Unexpected error in get_compliance_requirements: %s", e, exc_info=True)
        raise


async def research_knowledge_base(
    query: str,
    domains: list[str] | None = None,
    content_types: list[str] | None = None,
    include_cross_domain: bool = True,
    include_web_search: bool = True,
    format: str = "structured",
    max_results: int = 1000,
    api_key: str = "",
) -> dict[str, Any]:
    """Research knowledge base across all domains with optional web search.

    Deep research tool that searches internal knowledge base and optionally
    includes real-time web search results for comprehensive coverage.

    Args:
        query: Research query in natural language
        domains: Filter by domains (compliance, finops, devops, infrastructure, security, etc.)
        content_types: Filter by content types (guide, pattern, strategy, etc.)
        include_cross_domain: Include cross-domain relationships
        include_web_search: Include web search results (Tavily) for real-time information
        format: Response format (structured, markdown, executive_summary)
        max_results: Maximum number of results
        api_key: WISTX API key (required for authentication)

    Returns:
        Dictionary with research results and summary:
        - results: Knowledge articles from internal database
        - web_results: Web search results (if include_web_search=True)
        - research_summary: Summary of findings

    Raises:
        ValueError: If query validation fails
        RuntimeError: If API call fails
        ConnectionError: If network connection fails
        TimeoutError: If request times out
    """
    if not query or not isinstance(query, str):
        raise ValueError("Query must be a non-empty string")
    
    query = query.strip()
    if len(query) < 10:
        raise ValueError("Query must be at least 10 characters")
    
    if len(query) > 10000:
        raise ValueError("Query must be less than 10000 characters")
    
    if max_results < 1 or max_results > 50000:
        raise ValueError("max_results must be between 1 and 50000")
    
    if format not in ["structured", "markdown", "executive_summary"]:
        raise ValueError(f"Invalid format: {format}. Must be one of: structured, markdown, executive_summary")

    from wistx_mcp.tools.lib.input_sanitizer import validate_query_input

    validate_query_input(query)

    from wistx_mcp.tools.lib.auth_context import validate_api_key_and_get_user_id

    try:
        user_id = await validate_api_key_and_get_user_id(api_key)
    except (ValueError, RuntimeError) as e:
        raise

    web_search_client = None
    web_results = None

    try:
        result = await api_client.research_knowledge_base(
            query=query,
            domains=domains or [],
            content_types=content_types or [],
            include_cross_domain=include_cross_domain,
            response_format=format,
            max_results=max_results,
        )

        if include_web_search and settings.tavily_api_key:
            try:
                web_search_client = WebSearchClient(api_key=settings.tavily_api_key)
                
                if domains:
                    web_search_data = await web_search_client.search_by_domain(
                        query=query,
                        domains=domains,
                        max_results=max_results,
                        max_age_days=None,
                    )
                else:
                    web_search_data = await web_search_client.search_devops(
                        query=query,
                        max_results=max_results,
                        max_age_days=90,
                    )

                web_results = {
                    "answer": web_search_data.get("answer"),
                    "results": web_search_data.get("results", []),
                    "domains_searched": domains if domains else ["devops", "infrastructure"],
                    "freshness_info": web_search_data.get("freshness_info", {}),
                }

                logger.info(
                    "Added web search results to knowledge research: %d web results for domains %s",
                    len(web_search_data.get("results", [])),
                    domains if domains else ["devops", "infrastructure"],
                )

                try:
                    from api.services.web_search_storage_service import web_search_storage_service

                    storage_stats = await web_search_storage_service.store_web_search_results(
                        web_results=web_results,
                        query=query,
                        domains_searched=domains if domains else ["devops", "infrastructure"],
                        store_in_background=True,
                    )
                    logger.info(
                        "Web search results storage initiated: %d results, %d will be stored",
                        storage_stats["total_results"],
                        storage_stats.get("stored", 0) + storage_stats.get("converted", 0),
                    )
                except Exception as e:
                    logger.warning("Failed to store web search results: %s", e, exc_info=True)
            except (ValueError, RuntimeError, ConnectionError, TimeoutError) as e:
                logger.warning("Failed to include web search in research: %s", e)
            except Exception as e:
                logger.error("Unexpected error in web search: %s", e, exc_info=True)

        if web_results:
            result["web_results"] = web_results

        return result
    except (ValueError, RuntimeError, ConnectionError, TimeoutError) as e:
        logger.error("Error in research_knowledge_base: %s", e, exc_info=True)
        raise
    except Exception as e:
        logger.error("Unexpected error in research_knowledge_base: %s", e, exc_info=True)
        raise
    finally:
        if web_search_client:
            await web_search_client.close()


async def calculate_infrastructure_cost(
    resources: list[dict[str, Any]],
    api_key: str = "",
) -> dict[str, Any]:
    """Calculate infrastructure costs.

    The environment name passed to the pricing call is taken from the first
    resource that has a non-empty ``environment`` or ``environment_name``
    key; budgets are always checked.

    Args:
        resources: List of resource specifications
            Example: [{"cloud": "aws", "service": "rds", "instance_type": "db.t3.medium", "quantity": 1}]
        api_key: WISTX API key (required for authentication, can be provided via context)

    Returns:
        Dictionary with cost breakdown and optimizations

    Raises:
        ValueError: If ``resources`` is not a list of dicts, if api_key is
            missing or invalid, or if the pricing call raises one (for
            example a "Budget exceeded" error)
    """
    # Validate the shape up front so malformed input surfaces as ValueError
    # rather than a TypeError/AttributeError from the code below.
    if not isinstance(resources, list):
        raise ValueError("resources must be a list")
    for index, resource in enumerate(resources):
        if not isinstance(resource, dict):
            raise ValueError(
                f"Resource at index {index} must be a dict, got {type(resource).__name__}"
            )

    from wistx_mcp.tools.lib.auth_context import validate_api_key_and_get_user_id

    # Authentication errors propagate unchanged to the caller.
    user_id = await validate_api_key_and_get_user_id(api_key)

    # First non-empty environment name found across the resources, if any.
    candidate_names = (
        resource.get("environment") or resource.get("environment_name")
        for resource in resources
    )
    environment_name = next((name for name in candidate_names if name), None)

    try:
        return await pricing.calculate_infrastructure_cost(
            resources,
            user_id=str(user_id),
            check_budgets=True,
            environment_name=environment_name,
        )
    except ValueError as e:
        # "Budget exceeded" errors are re-raised without error-level logging.
        if "Budget exceeded" not in str(e):
            logger.error("Error in calculate_infrastructure_cost: %s", e, exc_info=True)
        raise
    except (RuntimeError, ConnectionError, TimeoutError) as e:
        logger.error("Error in calculate_infrastructure_cost: %s", e, exc_info=True)
        raise
    except Exception as e:
        logger.error("Unexpected error in calculate_infrastructure_cost: %s", e, exc_info=True)
        raise

