"""
Exploit Engine Orchestrator for LogicPwn
- Executes exploit chains step-by-step
- Manages session state, retries, validation, and logging
- Extensible, maintainable, and interoperable
"""
import time
from typing import Any, Dict, Optional
from logicpwn.core.exploit_engine.models import ExploitChain, ExploitStep, ExploitResult, ExploitStatus
from logicpwn.core.exploit_engine.payload_generator import inject_payload
from logicpwn.core.exploit_engine.validation_engine import validate_step_success
from logicpwn.core.exploit_engine.exploit_logger import ExploitLogger
from logicpwn.models.request_config import RequestConfig
import requests
from logicpwn.core.performance import monitor_performance, PerformanceMonitor
from logicpwn.core.cache.session_cache import SessionCache
from logicpwn.core.cache.response_cache import ResponseCache
from logicpwn.core.reliability import CircuitBreaker, CircuitBreakerConfig, CircuitBreakerOpenException, circuit_breaker_registry
from logicpwn.core.reliability import record_security_event, SecurityEventType, SecuritySeverity
import os

try:
    import yaml
except ImportError:
    yaml = None
import json
import asyncio
import functools
import copy
import re

# Module-level cache instances, created once at import time.
# response_cache is consulted by the chain runners below for GET requests.
# NOTE(review): session_cache is not referenced anywhere else in the visible code — confirm it is still needed.
session_cache = SessionCache()
response_cache = ResponseCache()

# Module-level logger used by extract_data_from_response; the chain runners create their own local ExploitLogger.
logger = ExploitLogger()

def _cleanup_session_state(session_state: Dict[str, Any], max_size: int, max_keys: int) -> None:
    """
    Bound the session state in place to prevent unbounded memory growth.

    Two limits are enforced:

    - Key count: when there are more than ``max_keys`` entries, the earliest
      inserted entries are dropped (dicts preserve insertion order) so only
      the last ``max_keys`` remain.
    - Value size: a string longer than ``max_size`` characters, or a list/dict
      whose ``str()`` form is longer than ``max_size``, is replaced by the
      first ``max_size`` characters followed by ``"...[TRUNCATED]"``. A
      truncated list/dict is therefore stored as a string afterwards.

    Args:
        session_state: The session state dictionary to clean (mutated in place)
        max_size: Maximum size in characters for string values
        max_keys: Maximum number of keys to keep; values below 1 drop every key
    """
    # Compute the surplus explicitly: slicing with [:-max_keys] removes
    # nothing when max_keys is 0 because [:-0] is the empty slice [:0].
    excess = len(session_state) - max(max_keys, 0)
    if excess > 0:
        for key in list(session_state)[:excess]:
            del session_state[key]

    # Truncate oversized values. Re-assigning an existing key does not change
    # the dict's size, so it is safe to do while iterating.
    for key, value in session_state.items():
        if isinstance(value, str):
            text = value
        elif isinstance(value, (list, dict)):
            text = str(value)
        else:
            continue
        if len(text) > max_size:
            session_state[key] = text[:max_size] + "...[TRUNCATED]"

def _replace_template_variables(request_config: RequestConfig, session_state: Dict[str, Any]) -> RequestConfig:
    """
    Replace template variables like {{variable_name}} in request configuration.

    Works on a deep copy, so ``request_config`` itself is never modified.
    Substitution is applied to the URL, headers, JSON data, form data, raw
    body and params. Strings are substituted wherever they occur as values,
    including inside arbitrarily nested dicts and lists (e.g. a list of
    objects in a JSON body). Dictionary keys and non-string scalars are left
    untouched.

    A placeholder whose (whitespace-stripped) name is not present in
    ``session_state`` is left in the text unchanged. Found values are
    converted with ``str()`` before insertion.

    Args:
        request_config: The request configuration to process
        session_state: Dictionary containing variable values

    Returns:
        RequestConfig with template variables replaced
    """
    config = copy.deepcopy(request_config)

    # Matches {{ name }}; the captured name may not contain a closing brace.
    placeholder = re.compile(r'\{\{([^}]+)\}\}')

    def substitute(match):
        var_name = match.group(1).strip()
        # Fall back to the original placeholder text when the variable is unknown
        return str(session_state.get(var_name, match.group(0)))

    def render(value: Any) -> Any:
        """Return ``value`` with placeholders replaced in every nested string."""
        if isinstance(value, str):
            return placeholder.sub(substitute, value)
        if isinstance(value, dict):
            return {key: render(item) for key, item in value.items()}
        if isinstance(value, list):
            return [render(item) for item in value]
        return value

    # The URL is always processed; the optional parts only when they are set.
    config.url = render(config.url)

    if config.headers:
        config.headers = render(config.headers)

    if config.json_data:
        config.json_data = render(config.json_data)

    if config.data:
        config.data = render(config.data)

    if config.raw_body:
        config.raw_body = render(config.raw_body)

    if config.params:
        config.params = render(config.params)

    return config

def extract_data_from_response(response: requests.Response, extractors: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract named values from a response using a dictionary of extractors.

    Each extractor is one of:
    - a string: used as a top-level key into the response's parsed JSON body
      (no dotted-path traversal is performed)
    - a callable: called with the response; its return value is stored

    Extraction is best-effort: when an extractor fails or has an unsupported
    type, a warning is logged and that key is omitted from the result. The
    function itself does not raise for extractor failures.

    Args:
        response: The HTTP response to extract from
        extractors: Mapping of result key -> extractor (string or callable)

    Returns:
        Dictionary with one entry per extractor that succeeded
    """
    extracted_data = {}

    # The JSON body is parsed lazily and at most once, however many string
    # extractors there are; a parse failure is remembered and reported per key.
    json_attempted = False
    json_body = None
    json_error = None

    for key, extractor in extractors.items():
        if isinstance(extractor, str):
            if not json_attempted:
                json_attempted = True
                try:
                    json_body = response.json()
                except ValueError as e:
                    # ValueError is the common base of json.JSONDecodeError and
                    # the decode errors raised by requests / simplejson.
                    json_error = e
            if json_error is not None:
                logger.log_warning(f"Failed to extract data for key '{key}' from response: {json_error}")
                continue
            try:
                extracted_data[key] = json_body[extractor]
            except (KeyError, TypeError) as e:
                # TypeError covers bodies that are not objects (list, string, null, number)
                logger.log_warning(f"Failed to extract data for key '{key}' from response: {e}")
        elif callable(extractor):
            # Callable extractor (e.g., a function that takes response and returns data)
            try:
                extracted_data[key] = extractor(response)
            except Exception as e:
                logger.log_warning(f"Failed to extract data for key '{key}' using callable extractor: {e}")
        else:
            logger.log_warning(f"Unsupported extractor type for key '{key}': {type(extractor)}")
    return extracted_data

@monitor_performance("exploit_chain_execution")
def run_exploit_chain(
    session: requests.Session,
    chain: ExploitChain,
    config: Optional[Dict[str, Any]] = None
) -> Any:
    """
    Execute a complete exploit chain with validation and logging.

    For each step, in order:
    - builds the request (payload injection, then {{variable}} substitution
      from ``chain.session_state``)
    - serves GET requests from the module-level response cache when possible,
      otherwise sends the request through a per-chain circuit breaker
    - validates the response and runs the step's data extractors
    - on success merges extracted values into ``chain.session_state`` (bounded
      by the configured size limits) and stops retrying
    - on failure retries until ``step.retry_count`` attempts were made; an open
      circuit breaker stops retrying immediately

    A failed critical step aborts the remaining steps.

    Args:
        session: requests session used to send every request
        chain: chain to execute; ``chain.execution_log`` is reset and
            ``chain.session_state`` is updated in place
        config: optional overrides. Recognised keys (defaults in brackets):
            ``max_session_state_size`` (1000), ``max_session_state_keys`` (100),
            ``circuit_breaker_failure_threshold`` (3),
            ``circuit_breaker_recovery_timeout`` (30.0),
            ``circuit_breaker_success_threshold`` (2)

    Returns:
        ``chain.execution_log``: one ExploitResult per executed step, holding
        the outcome of that step's last attempt
    """
    settings = config or {}

    # Size limits that keep the session state from growing without bound
    max_session_state_size = settings.get("max_session_state_size", 1000)
    max_session_state_keys = settings.get("max_session_state_keys", 100)

    # Initialize circuit breaker for this chain
    circuit_breaker_config = CircuitBreakerConfig(
        failure_threshold=settings.get("circuit_breaker_failure_threshold", 3),
        recovery_timeout=settings.get("circuit_breaker_recovery_timeout", 30.0),
        success_threshold=settings.get("circuit_breaker_success_threshold", 2)
    )
    circuit_breaker = circuit_breaker_registry.get_breaker(
        f"exploit_chain_{chain.name}",
        circuit_breaker_config
    )

    logger = ExploitLogger()
    chain.session_state = chain.session_state or {}
    chain.execution_log = []

    def send_step_request(step: ExploitStep):
        """Build and send one step's request; returns (response, request kwargs)."""
        # Inject payloads if any
        req_cfg = step.request_config
        if step.payload_injection_points:
            req_cfg = inject_payload(req_cfg, step.payload_injection_points, chain.session_state)

        # Handle template variable replacement (e.g., {{extracted_uuid}})
        req_cfg = _replace_template_variables(req_cfg, chain.session_state)

        req_kwargs = req_cfg.model_dump(exclude_none=True)
        method = req_kwargs.pop("method")
        url = req_kwargs.pop("url")

        # Map RequestConfig field names onto the keyword names requests expects
        if 'verify_ssl' in req_kwargs:
            req_kwargs['verify'] = req_kwargs.pop('verify_ssl')
        if 'json_data' in req_kwargs:
            req_kwargs['json'] = req_kwargs.pop('json_data')

        is_get = method.upper() == "GET"

        # Check cache first for GET requests
        if is_get:
            cached_response = response_cache.get_response(
                url, method, req_kwargs.get('params'), req_kwargs.get('headers')
            )
            # NOTE(review): this is a truthiness test; if the cache hands back a
            # requests.Response, a cached non-2xx response is falsy and would be
            # treated as a miss — confirm what get_response returns.
            if cached_response:
                logger.log_step(step.name, "cached", {"cache_hit": True, "url": url})
                return cached_response, req_kwargs

        # Make request through circuit breaker
        response = circuit_breaker.call(session.request, method, url, **req_kwargs)

        # Cache the response for GET requests
        if is_get and hasattr(response, 'status_code'):
            response_cache.set_response(
                url, method, response, req_kwargs.get('params'), req_kwargs.get('headers')
            )

        return response, req_kwargs

    for step in chain.steps:
        result = None
        # retry_count is the total number of attempts. Always make at least one:
        # with zero iterations no result would exist and the final
        # log_chain(...model_dump()) call below would fail on None.
        for attempt in range(max(1, step.retry_count)):
            start = time.time()
            try:
                # Execute request with circuit breaker protection
                response, req_kwargs = send_step_request(step)

                # Validate response
                validation = validate_step_success(response, step, chain.session_state)

                # Extract data if needed
                extracted_data = {}
                if hasattr(step, 'data_extractors') and step.data_extractors:
                    extracted_data = extract_data_from_response(response, step.data_extractors)

                exec_time = time.time() - start
                result = ExploitResult(
                    step_name=step.name,
                    status=ExploitStatus.SUCCESS if validation.is_valid else ExploitStatus.FAILED,
                    response=response,
                    validation_result=validation,
                    execution_time=exec_time,
                    extracted_data=extracted_data
                )

                logger.log_step(step.name, result.status.value, {
                    "execution_time": exec_time,
                    "attempt": attempt + 1,
                    "matched": validation.matched_indicators,
                    "failed": validation.failed_indicators,
                    "circuit_breaker_state": circuit_breaker.state.value
                })

                logger.log_request_response(step.name, req_kwargs, response)

                if validation.is_valid:
                    # Only a validated step may feed values to later steps
                    chain.session_state.update(validation.extracted_data)
                    chain.session_state.update(extracted_data)

                    # Implement memory leak prevention
                    _cleanup_session_state(chain.session_state, max_session_state_size, max_session_state_keys)

                    # Record security event for successful exploit step
                    record_security_event(
                        SecurityEventType.EXPLOIT_SUCCESS,
                        SecuritySeverity.LOW,
                        f"Exploit step '{step.name}' completed successfully",
                        metadata={"step": step.name, "chain": chain.name},
                        source_module="exploit_engine"
                    )

                    break
                else:
                    logger.log_error(step.name, f"Validation failed: {validation.reasons}")

                    # Record security event for failed validation
                    record_security_event(
                        SecurityEventType.EXPLOIT_FAILURE,
                        SecuritySeverity.MEDIUM,
                        f"Exploit step '{step.name}' validation failed: {validation.reasons}",
                        metadata={"step": step.name, "chain": chain.name, "reasons": validation.reasons},
                        source_module="exploit_engine"
                    )

            except CircuitBreakerOpenException as e:
                exec_time = time.time() - start
                result = ExploitResult(
                    step_name=step.name,
                    status=ExploitStatus.FAILED,
                    error_message=f"Circuit breaker open: {str(e)}",
                    execution_time=exec_time
                )
                logger.log_error(step.name, f"Circuit breaker prevented execution: {str(e)}")

                # Record security event for circuit breaker activation
                record_security_event(
                    SecurityEventType.CIRCUIT_BREAKER_OPEN,
                    SecuritySeverity.HIGH,
                    f"Circuit breaker prevented execution for step '{step.name}': {str(e)}",
                    metadata={"step": step.name, "chain": chain.name, "error": str(e)},
                    source_module="exploit_engine"
                )

                break  # Don't retry when circuit breaker is open

            except Exception as e:
                # Any other failure (network error, validation/extraction bug)
                # is recorded as a failed attempt and retried.
                exec_time = time.time() - start
                result = ExploitResult(
                    step_name=step.name,
                    status=ExploitStatus.FAILED,
                    error_message=str(e),
                    execution_time=exec_time
                )
                logger.log_error(step.name, str(e))

            # NOTE(review): a successful attempt breaks out above, so this delay
            # only runs after failed attempts — confirm that is the intent of
            # delay_after_step.
            if step.delay_after_step:
                time.sleep(step.delay_after_step)

        chain.execution_log.append(result)
        if step.critical and (not result or result.status != ExploitStatus.SUCCESS):
            logger.log_error(step.name, "Critical step failed, aborting chain.")
            break

    # Log circuit breaker metrics
    cb_metrics = circuit_breaker.get_metrics()
    logger.log_info(f"Circuit breaker metrics for chain '{chain.name}'", cb_metrics)

    logger.log_chain(chain.name, [r.model_dump() for r in chain.execution_log])
    return chain.execution_log

async def async_run_exploit_chain(
    session: requests.Session,
    chain: ExploitChain,
    config: Optional[Dict[str, Any]] = None
) -> Any:
    """
    Async version of run_exploit_chain.

    Each step's blocking HTTP request is executed in the event loop's default
    executor. Steps run concurrently via ``asyncio.gather`` when
    ``chain.parallel_execution`` is truthy, otherwise one after another; only
    in sequential mode does a failed critical step abort the remaining steps.

    Per-step behaviour matches the synchronous runner: the request is built
    from the step config and ``chain.session_state``, GET requests use the
    module-level response cache, the request goes through a per-chain circuit
    breaker, the response is validated, and extracted values are merged into
    ``chain.session_state`` only when validation succeeds. Failed attempts are
    retried until ``step.retry_count`` attempts were made (critical steps
    included); an open circuit breaker stops retrying immediately.

    Args:
        session: requests session used to send every request
        chain: chain to execute; ``chain.execution_log`` is reset and
            ``chain.session_state`` is updated in place
        config: optional overrides. Recognised keys (defaults in brackets):
            ``max_session_state_size`` (1000), ``max_session_state_keys`` (100),
            ``circuit_breaker_failure_threshold`` (3),
            ``circuit_breaker_recovery_timeout`` (30.0),
            ``circuit_breaker_success_threshold`` (2)

    Returns:
        ``chain.execution_log``: one ExploitResult per executed step, holding
        the outcome of that step's last attempt
    """
    settings = config or {}

    # Size limits that keep the session state from growing without bound
    max_session_state_size = settings.get("max_session_state_size", 1000)
    max_session_state_keys = settings.get("max_session_state_keys", 100)

    # Initialize circuit breaker for this chain
    circuit_breaker_config = CircuitBreakerConfig(
        failure_threshold=settings.get("circuit_breaker_failure_threshold", 3),
        recovery_timeout=settings.get("circuit_breaker_recovery_timeout", 30.0),
        success_threshold=settings.get("circuit_breaker_success_threshold", 2)
    )
    circuit_breaker = circuit_breaker_registry.get_breaker(
        f"async_exploit_chain_{chain.name}",
        circuit_breaker_config
    )

    logger = ExploitLogger()
    chain.session_state = chain.session_state or {}
    chain.execution_log = []

    async def run_step(step: ExploitStep):
        """Run one step with retries and return the ExploitResult of its last attempt."""
        result = None

        def send_step_request():
            """Blocking request builder/sender; runs in the executor thread."""
            # Inject payloads if any
            req_cfg = step.request_config
            if step.payload_injection_points:
                req_cfg = inject_payload(req_cfg, step.payload_injection_points, chain.session_state)

            # Handle template variable replacement (e.g., {{extracted_uuid}})
            req_cfg = _replace_template_variables(req_cfg, chain.session_state)

            req_kwargs = req_cfg.model_dump(exclude_none=True)
            method = req_kwargs.pop("method")
            url = req_kwargs.pop("url")

            # Map RequestConfig field names onto the keyword names requests expects
            if 'verify_ssl' in req_kwargs:
                req_kwargs['verify'] = req_kwargs.pop('verify_ssl')
            if 'json_data' in req_kwargs:
                req_kwargs['json'] = req_kwargs.pop('json_data')

            is_get = method.upper() == "GET"

            # Check cache first for GET requests
            if is_get:
                cached_response = response_cache.get_response(
                    url, method, req_kwargs.get('params'), req_kwargs.get('headers')
                )
                # NOTE(review): this is a truthiness test; if the cache hands back a
                # requests.Response, a cached non-2xx response is falsy and would be
                # treated as a miss — confirm what get_response returns.
                if cached_response:
                    logger.log_step(step.name, "cached", {"cache_hit": True, "url": url})
                    return cached_response, req_kwargs

            # Make request through circuit breaker
            response = circuit_breaker.call(session.request, method, url, **req_kwargs)

            # Cache the response for GET requests
            if is_get and hasattr(response, 'status_code'):
                response_cache.set_response(
                    url, method, response, req_kwargs.get('params'), req_kwargs.get('headers')
                )

            return response, req_kwargs

        # We are inside a coroutine, so a running loop is guaranteed to exist
        loop = asyncio.get_running_loop()

        # retry_count is the total number of attempts. Always make at least one:
        # with zero iterations no result would exist and the final
        # log_chain(...model_dump()) call would fail on None.
        for attempt in range(max(1, step.retry_count)):
            start = time.time()
            try:
                # Execute the blocking request off the event loop thread
                response, req_kwargs = await loop.run_in_executor(None, send_step_request)

                # Validate response
                validation = validate_step_success(response, step, chain.session_state)

                # Extract data if needed
                extracted_data = {}
                if hasattr(step, 'data_extractors') and step.data_extractors:
                    extracted_data = extract_data_from_response(response, step.data_extractors)

                exec_time = time.time() - start
                result = ExploitResult(
                    step_name=step.name,
                    status=ExploitStatus.SUCCESS if validation.is_valid else ExploitStatus.FAILED,
                    response=response,
                    validation_result=validation,
                    execution_time=exec_time,
                    extracted_data=extracted_data
                )

                logger.log_step(step.name, result.status.value, {
                    "execution_time": exec_time,
                    "attempt": attempt + 1,
                    "matched": validation.matched_indicators,
                    "failed": validation.failed_indicators,
                    "circuit_breaker_state": circuit_breaker.state.value
                })

                logger.log_request_response(step.name, req_kwargs, response)

                if validation.is_valid:
                    # Only a validated step may feed values to later steps
                    # (same rule as the synchronous runner).
                    chain.session_state.update(validation.extracted_data)
                    chain.session_state.update(extracted_data)

                    # Implement memory leak prevention
                    _cleanup_session_state(chain.session_state, max_session_state_size, max_session_state_keys)

                    # Record security event for successful exploit step
                    record_security_event(
                        SecurityEventType.EXPLOIT_SUCCESS,
                        SecuritySeverity.LOW,
                        f"Async exploit step '{step.name}' completed successfully",
                        metadata={"step": step.name, "chain": chain.name, "async": True},
                        source_module="exploit_engine"
                    )

                    break
                else:
                    logger.log_error(step.name, f"Validation failed: {validation.reasons}")

                    # Record security event for failed validation
                    record_security_event(
                        SecurityEventType.EXPLOIT_FAILURE,
                        SecuritySeverity.MEDIUM,
                        f"Async exploit step '{step.name}' validation failed: {validation.reasons}",
                        metadata={"step": step.name, "chain": chain.name, "reasons": validation.reasons, "async": True},
                        source_module="exploit_engine"
                    )

            except CircuitBreakerOpenException as e:
                exec_time = time.time() - start
                result = ExploitResult(
                    step_name=step.name,
                    status=ExploitStatus.FAILED,
                    error_message=f"Circuit breaker open: {str(e)}",
                    execution_time=exec_time
                )
                logger.log_error(step.name, f"Circuit breaker prevented execution: {str(e)}")

                # Record security event for circuit breaker activation
                record_security_event(
                    SecurityEventType.CIRCUIT_BREAKER_OPEN,
                    SecuritySeverity.HIGH,
                    f"Circuit breaker prevented async execution for step '{step.name}': {str(e)}",
                    metadata={"step": step.name, "chain": chain.name, "error": str(e), "async": True},
                    source_module="exploit_engine"
                )

                break  # Don't retry when circuit breaker is open

            except Exception as e:
                # Any other failure (network error, validation/extraction bug)
                # is recorded as a failed attempt and retried.
                exec_time = time.time() - start
                result = ExploitResult(
                    step_name=step.name,
                    status=ExploitStatus.FAILED,
                    error_message=str(e),
                    execution_time=exec_time
                )
                logger.log_error(step.name, str(e))

            # A failed critical step is retried like any other step; aborting
            # the chain is decided by the caller once all attempts are used up.
            # NOTE(review): a successful attempt breaks out above, so this delay
            # only runs after failed attempts — confirm that is the intent of
            # delay_after_step.
            if step.delay_after_step:
                await asyncio.sleep(step.delay_after_step)

        return result

    # Execute steps based on parallel execution setting
    if getattr(chain, 'parallel_execution', False):
        # Concurrent steps cannot abort one another, so `critical` has no
        # effect on the other steps in this mode.
        results = await asyncio.gather(*(run_step(step) for step in chain.steps))
        chain.execution_log.extend(results)
    else:
        for step in chain.steps:
            result = await run_step(step)
            chain.execution_log.append(result)
            if step.critical and (not result or result.status != ExploitStatus.SUCCESS):
                logger.log_error(step.name, "Critical step failed, aborting chain.")
                break

    # Log circuit breaker metrics
    cb_metrics = circuit_breaker.get_metrics()
    logger.log_info(f"Async circuit breaker metrics for chain '{chain.name}'", cb_metrics)

    logger.log_chain(chain.name, [r.model_dump() for r in chain.execution_log])
    return chain.execution_log

def load_exploit_chain_from_file(file_path: str) -> ExploitChain:
    """
    Load an exploit chain from a YAML or JSON file.

    The format is chosen from the file extension (matched case-insensitively):
    ``.yaml``/``.yml`` are parsed with ``yaml.safe_load``, ``.json`` with
    ``json.load``. The file is read as UTF-8 and the parsed data is validated
    into an ExploitChain.

    Args:
        file_path: Path of the chain definition file

    Returns:
        The ExploitChain built from the file contents

    Raises:
        FileNotFoundError: ``file_path`` does not exist
        ImportError: a YAML file was given but PyYAML is not installed
        ValueError: the extension is not .yaml, .yml or .json
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Exploit chain file not found: {file_path}")

    # Pick the parser before opening the file so an unsupported format or a
    # missing PyYAML is reported without touching the file.
    lowered = file_path.lower()
    if lowered.endswith(('.yaml', '.yml')):
        if yaml is None:
            raise ImportError("PyYAML is required to load YAML files.")
        parse = yaml.safe_load  # safe_load: never constructs arbitrary Python objects
    elif lowered.endswith('.json'):
        parse = json.load
    else:
        raise ValueError("Unsupported file format. Use .yaml, .yml, or .json")

    # Explicit encoding: the platform default is locale-dependent
    with open(file_path, 'r', encoding='utf-8') as f:
        data = parse(f)

    # model_validate is the pydantic v2 spelling of the deprecated parse_obj
    return ExploitChain.model_validate(data)