import json
import requests
from typing import  List, Dict
from .input_scanners import Scanner, ScannerResult
import os
import asyncio
import inspect
from typing import Any, Awaitable, Callable, Dict, Optional, Union
import httpx

# Signature of user callbacks for async scans: receives the ScannerResult and
# may either return nothing or return an awaitable that will be awaited.
Callback = Callable[[ScannerResult], Optional[Awaitable[None]]]

# Default base URL of the TestSavant API; Guard instances accept a different
# one through their ``remote_addr`` argument.
REMOTE_TS_API_ADDRESS = 'https://api.testsavant.ai'

# Request-body template flattened onto a single line. The placeholders
# {PROMPT}, {PROJECT_ID} and {SCANNERS} are meant to be replaced textually.
# NOTE(review): Guard/InputGuard/OutputGuard below do not use this template;
# they build the body as a dict and serialise it with json.dumps. Textual
# substitution does not escape quotes or backslashes in the prompt, so the
# json.dumps path is preferable for real input.
REQUEST_JSON = """
{
    "prompt": "{PROMPT}",
    "config": {
        "project_id": "{PROJECT_ID}",
        "fail_fast": true,
        "cache": {
            "enabled": true,
            "ttl": 3600
        }
    },
    "use": [{SCANNERS}]
}
""".replace('\n', ' ')


class Guard:
    """Base class for prompt/response scanning against the remote TestSavant API.

    A Guard holds a list of Scanner instances and serialises them into a
    request body. It sends that body to the remote scanning service in one of
    two ways:

    - a JSON request for text-only scans;
    - a multipart request that also carries image files.

    Requests are made synchronously with ``requests``, or asynchronously with
    ``httpx``. In the async case the caller must await the returned coroutine.

    Parameters
    ----------
    API_KEY : str
        API key sent in the ``x-api-key`` header of every request.
    PROJECT_ID : str
        Project identifier sent as ``config.project_id`` in scan requests and
        as ``project_id`` when fetching image results.
    remote_addr : str, optional
        Base URL for the remote API endpoint. Defaults to REMOTE_TS_API_ADDRESS.
    fail_fast : bool, optional
        Sent as ``config.fail_fast`` in every scan request. It presumably asks
        the service to stop at the first failing scanner. Defaults to True.

    Attributes
    ----------
    API_KEY : str
        The API key used for authentication.
    PROJECT_ID : str
        The project identifier.
    fail_fast : bool
        Value forwarded as ``config.fail_fast``.
    scanners : list of Scanner or None
        Scanners added via ``add_scanner``; ``None`` until the first one is added.
    remote_addr : str
        The remote API base URL.

    Methods
    -------
    add_scanner(scanner)
        Append a Scanner instance.
    remove_all_scanners()
        Reset the scanner list to ``None``.
    _scanners_to_dict(scanners, request_only=False, multimodal=False)
        Serialise scanners for the ``use`` field of the request body.
    _prepare_request_json(prompt, project_id, scanners, output=None, multimodal=False)
        Build the request body as a dict.
    _make_request(data, url, files=None, async_mode=False, callback=None)
        Dispatch to the text or multimodal request path.
    _make_text_request(data, url, async_mode=False, callback=None)
        Send a JSON-only request.
    _make_multimodal_request(data, url, files, async_mode=False, callback=None)
        Send a multipart request carrying image files.
    _request_api_async_mode(url, *, method="POST", data=None, files=None,
                            headers=None, timeout=10, callback=None)
        Build a coroutine that performs the request with httpx.
    fetch_image_results(image_file_names, download_dir)
        Download result images from the service into a directory.

    Raises
    ------
    ValueError
        When no scanners have been added, or when files are missing or of an
        unsupported type.
    Exception
        When a synchronous request returns a non-200 status, or when saving a
        downloaded file fails.
    httpx.HTTPStatusError
        When an awaited async request returns a 4xx/5xx status.
    """

    def __init__(self, API_KEY, PROJECT_ID, remote_addr=REMOTE_TS_API_ADDRESS, fail_fast=True):
        """Store credentials and configuration; no scanners are registered yet.

        Args:
            API_KEY (str): API key used for the ``x-api-key`` header.
            PROJECT_ID (str): Project identifier sent with every scan.
            remote_addr (str, optional): Base URL of the remote API.
                Default is https://api.testsavant.ai.
            fail_fast (bool, optional): Value sent as ``config.fail_fast``.
                Defaults to True.
        """
        self.API_KEY = API_KEY
        self.PROJECT_ID = PROJECT_ID
        self.fail_fast = fail_fast
        self.scanners: Optional[List[Scanner]] = None
        self.remote_addr = remote_addr

    def remove_all_scanners(self):
        """Remove all scanners by resetting ``self.scanners`` to ``None``."""
        self.scanners = None

    def add_scanner(self, scanner: Scanner):
        """Append a Scanner instance, creating the scanner list on first use.

        Args:
            scanner (Scanner): The scanner instance to be added.

        Raises:
            AssertionError: If ``scanner`` is not a Scanner instance.
        """
        assert isinstance(scanner, Scanner), "Scanner must be an instance of Scanner class."
        if self.scanners is None:
            self.scanners = []
        self.scanners.append(scanner)

    def _scanners_to_dict(self, scanners: List[Scanner], request_only=False, multimodal=False):
        """Serialise scanners for the ``use`` field of a request body.

        Args:
            scanners: Scanners to serialise.
            request_only: Forwarded to each ``Scanner.to_dict`` call.
            multimodal: When False, scanners whose class name contains
                "Image" are skipped.

        Returns:
            tuple: ``(serialised, requires_input)``, where:
                - ``serialised`` is the list of ``to_dict`` results;
                - ``requires_input`` is the set of class names of the
                  included scanners whose ``_requires_input_prompt``
                  attribute is truthy.

        Raises:
            ValueError: If ``scanners`` is None or empty.
        """
        if not scanners:
            raise ValueError("No scanners have been added.")
        scanners_dict = []
        requires_input = set()
        for scanner in scanners:
            name = scanner.__class__.__name__
            if not multimodal and "Image" in name:
                continue
            scanners_dict.append(scanner.to_dict(request_only=request_only))
            if getattr(scanner, '_requires_input_prompt', False):
                requires_input.add(name)
        return scanners_dict, requires_input

    def _prepare_request_json(self, prompt, project_id, scanners: Dict, output=None, multimodal=False):
        """Build the request body dict sent to the scanning endpoints.

        Args:
            prompt: Prompt text; it may be None.
            project_id: Sent as ``config.project_id``.
            scanners: Serialised scanners (see ``_scanners_to_dict``); sent as ``use``.
            output: LLM output. The ``output`` key is added only when this is truthy.
            multimodal: Unused; accepted for interface compatibility.

        Returns:
            dict: The request body, with a fixed cache config
            (``enabled=True``, ``ttl=3600``).
        """
        req_dict = {
            "prompt": prompt,
            "config": {
                "project_id": project_id,
                "fail_fast": self.fail_fast,
                "cache": {
                    "enabled": True,
                    "ttl": 3600
                }
            },
            "use": scanners
        }
        if output:
            req_dict["output"] = output
        return req_dict

    def _headers(self, json_body: bool) -> Dict[str, str]:
        """Return request headers; ``json_body`` adds a JSON Content-Type."""
        headers = {'x-api-key': self.API_KEY}
        if json_body:
            headers['Content-Type'] = 'application/json'
        return headers

    @staticmethod
    def _read_image_files(files: List[str]) -> List[Any]:
        """Validate image paths and load them as multipart ``images`` parts.

        Each file is read fully into memory inside a ``with`` block, so no file
        handle outlives this call. This holds on both the sync and async paths,
        and also when a later path fails validation.

        Raises:
            ValueError: If a path does not exist or is not .png/.jpg/.jpeg.
        """
        payload_files = []
        for file_path in files:
            if not os.path.exists(file_path):
                raise ValueError(f"File {file_path} does not exist.")
            lowered = file_path.lower()
            if not lowered.endswith(('.png', '.jpg', '.jpeg')):
                raise ValueError(f"File {file_path} is not a valid image type.")
            image_type = 'image/jpeg' if lowered.endswith(('.jpg', '.jpeg')) else 'image/png'
            with open(file_path, 'rb') as fh:
                content = fh.read()
            payload_files.append(('images', (os.path.basename(file_path), content, image_type)))
        return payload_files

    def _make_request(self, data, url: str, files: Optional[List[str]] = None, async_mode: bool = False, callback: Optional[Callback] = None):
        """Send ``data`` to ``url``, as multipart when ``files`` is non-empty.

        Args:
            data: JSON-encoded request body (string).
            url: Full endpoint URL.
            files: Optional image paths; a non-empty list selects the multipart path.
            async_mode: If True, return an un-awaited coroutine instead of a result.
            callback: Invoked with the result, in async mode only.

        Returns:
            ScannerResult in sync mode; a coroutine in async mode.
        """
        if files:
            return self._make_multimodal_request(data, url, files, async_mode=async_mode, callback=callback)
        return self._make_text_request(data, url, async_mode=async_mode, callback=callback)

    def _make_text_request(self, data, url: str, async_mode: bool = False, callback: Optional[Callback] = None):
        """POST a JSON body to ``url``.

        Returns:
            ScannerResult built from the response JSON in sync mode, or a
            coroutine (to be awaited by the caller) in async mode.

        Raises:
            Exception: If the synchronous request does not return HTTP 200.
        """
        headers = self._headers(json_body=True)
        if async_mode:
            return self._request_api_async_mode(url, data=data, headers=headers, callback=callback)
        response = requests.post(url, headers=headers, data=data)
        if response.status_code != 200:
            raise Exception(f"Request failed with status code {response.status_code}")
        return ScannerResult(**response.json())

    def _make_multimodal_request(self, data, url: str, files: List[str], async_mode: bool = False, callback: Optional[Callback] = None):
        """POST ``data`` plus image ``files`` to ``url`` as a multipart request.

        The JSON body travels as the ``metadata`` form field, and each image as
        an ``images`` part.

        Returns:
            ScannerResult in sync mode; a coroutine in async mode.

        Raises:
            ValueError: If ``files`` is empty, or a file is missing or not a
                supported image type.
            Exception: If the synchronous request does not return HTTP 200.
        """
        if not files:
            raise ValueError("Files must be provided for multi-modal scanning.")
        payload_files = self._read_image_files(files)
        headers = self._headers(json_body=False)

        if async_mode:
            return self._request_api_async_mode(
                url, data=data, files=payload_files, headers=headers, callback=callback
            )
        response = requests.post(url, headers=headers, data={'metadata': data}, files=payload_files)
        if response.status_code != 200:
            raise Exception(f"Request failed with status code {response.status_code}")
        return ScannerResult(**response.json())

    def _request_api_async_mode(self,
        url: str,
        *,
        method: str = "POST",
        data: Optional[Union[Dict[str, Any], str]] = None,
        files: Optional[List[Any]] = None,
        headers: Optional[Dict[str, str]] = None,
        timeout: Optional[float] = 10,
        callback: Optional[Callback] = None,
    ) -> Awaitable[Any]:
        """Build a coroutine that sends the request with httpx.

        Nothing is sent until the caller awaits the returned coroutine.

        Args:
            url: Full endpoint URL.
            method: HTTP method; it is upper-cased before use.
            data: Request body. A str/bytes body is sent verbatim. A dict is
                form-encoded. With ``files``, the body becomes the
                ``metadata`` multipart field.
            files: Multipart parts, as produced by ``_read_image_files``.
            headers: Request headers.
            timeout: httpx timeout in seconds.
            callback: If given, it is called with the result before the
                coroutine returns. A callback that returns an awaitable
                (including coroutine functions) is awaited.

        Returns:
            A coroutine. When awaited, it resolves to a ScannerResult if the
            response Content-Type is JSON, and to the raw response text
            otherwise.

        Raises (when awaited):
            httpx.HTTPStatusError: On a 4xx/5xx response.
        """

        async def _invoke_callback(result: Any) -> None:
            # Handle plain functions, coroutine functions, and callables that
            # merely return an awaitable, as allowed by the Callback alias.
            returned = callback(result)
            if inspect.isawaitable(returned):
                await returned

        async def _send() -> Any:
            verb = method.upper()
            async with httpx.AsyncClient() as client:
                if not files:
                    if isinstance(data, (str, bytes)):
                        # httpx deprecates raw text/bytes via data=;
                        # content= sends the same body unchanged.
                        response = await client.request(verb, url, content=data, headers=headers, timeout=timeout)
                    else:
                        response = await client.request(verb, url, data=data, headers=headers, timeout=timeout)
                else:
                    # A filename of None makes "metadata" a plain form field
                    # next to the image parts.
                    multipart = [('metadata', (None, data))] + list(files)
                    response = await client.request(verb, url, files=multipart, headers=headers, timeout=timeout)

                response.raise_for_status()
                if response.headers.get("Content-Type", "").startswith("application/json"):
                    result = ScannerResult(**response.json())
                else:
                    result = response.text
                if callback:
                    await _invoke_callback(result)
                return result

        return _send()  # caller must await

    def fetch_image_results(self, image_file_names: List[str], download_dir: str):
        """Download result images from the service and save them into ``download_dir``.

        Each name is fetched with its own POST to ``{remote_addr}/guard/files``.
        The request body carries ``project_id`` and ``file_name``. The response
        body is written to ``download_dir/file_name``, and the directory is
        created if needed.

        Args:
            image_file_names (List[str]): Image file names to fetch.
            download_dir (str): Directory where the images are saved.

        Raises:
            AssertionError: If ``image_file_names`` is not a list of strings.
            ValueError: If ``image_file_names`` is empty.
            Exception: If a response is not HTTP 200, or if saving an image fails.
        """
        assert isinstance(image_file_names, list), "image_file_names must be a list of file names."
        if not image_file_names:
            raise ValueError("No image file names provided.")
        assert all(isinstance(name, str) for name in image_file_names), "All file names must be strings."

        os.makedirs(download_dir, exist_ok=True)

        url = f"{self.remote_addr}/guard/files"
        for file_name in image_file_names:
            response = requests.post(
                url,
                headers=self._headers(json_body=False),
                json={
                    "project_id": self.PROJECT_ID,
                    "file_name": file_name
                }
            )
            if response.status_code != 200:
                raise Exception(f"Failed to fetch image results: {response.status_code}")

            # NOTE(review): file_name is joined unchecked. A name containing ".."
            # or an absolute path would write outside download_dir; consider
            # sanitising it if names can come from untrusted sources.
            file_path = os.path.join(download_dir, file_name)
            try:
                with open(file_path, 'wb') as f:
                    f.write(response.content)
            except Exception as e:
                raise Exception(f"Failed to save image {file_name}: {str(e)}") from e
            
class InputGuard(Guard):
    """Scan user prompts, and optionally image files, with the configured scanners.

    Text-only scans are sent to ``{remote_addr}/guard/prompt-input``. Scans
    that include files are sent as multipart requests to
    ``{remote_addr}/guard/image-input``.

    Args:
        API_KEY (str): The API key used for authentication.
        PROJECT_ID (str): The project identifier.
        remote_addr (str, optional): Base URL of the TS API. Defaults to REMOTE_TS_API_ADDRESS.
        fail_fast (bool, optional): Sent as ``config.fail_fast`` with each scan. Defaults to True.

    Methods:
        scan(prompt, files=None, is_async=False, callback=None) -> Union[ScannerResult, Any]:
            Scan the prompt and/or files. Returns a ScannerResult, or a
            coroutine to await when ``is_async`` is True.
    """

    def __init__(self, API_KEY, PROJECT_ID, remote_addr=REMOTE_TS_API_ADDRESS, fail_fast=True):
        """Initialise the guard; all state is set by ``Guard.__init__``.

        Args:
            API_KEY (str): The API key used for authentication.
            PROJECT_ID (str): The identifier for the project.
            remote_addr (str, optional): Base URL of the TS API. Defaults to REMOTE_TS_API_ADDRESS.
            fail_fast (bool, optional): Sent as ``config.fail_fast`` with each scan. Defaults to True.
        """
        super().__init__(API_KEY, PROJECT_ID, remote_addr, fail_fast=fail_fast)

    def scan(self, prompt: str,
             files: Optional[List[str]] = None,
             is_async=False,
             callback: Callback = None) -> Union[ScannerResult, Any]:
        """Scan the provided prompt and/or files using the configured scanners.

        Duplicate file paths are sent once, in their first-seen order. With
        files, image scanners are included; without files, scanners whose
        class name contains "Image" are skipped.

        Args:
            prompt (str): The input text prompt to scan.
            files (List[str], optional): Image file paths to scan. Defaults to None.
            is_async (bool, optional): If True, return a coroutine that the
                caller must await. Defaults to False.
            callback (Callback, optional): Called with the result. Used only
                when ``is_async`` is True.

        Returns:
            Union[ScannerResult, Any]: A ScannerResult for synchronous calls,
            or an awaitable when ``is_async`` is True.

        Raises:
            ValueError: If no scanners have been added, or if neither a string
                prompt nor a non-empty file list is provided.
            AssertionError: If ``files`` is not a list of strings.
        """
        if not self.scanners:
            raise ValueError("No scanners have been added.")

        has_files = isinstance(files, list) and len(files) > 0
        if not isinstance(prompt, str) and not has_files:
            num_files = len(files) if files else 0
            raise ValueError(f"No prompt or files provided for scanning. prompt: {type(prompt)}, files: {type(files)}, num_files: {num_files}")

        assert files is None or (isinstance(files, list) and all(isinstance(file, str) for file in files)), "Files must be a list of file paths."

        if files:
            # Drop duplicate paths but keep the caller's order; list(set(...))
            # would reorder them differently from run to run.
            files = list(dict.fromkeys(files))
            multimodal, endpoint = True, 'image-input'
        else:
            multimodal, endpoint = False, 'prompt-input'

        scanners_dict, _ = self._scanners_to_dict(self.scanners, request_only=True, multimodal=multimodal)
        url = f'{self.remote_addr}/guard/{endpoint}'

        request_body = self._prepare_request_json(prompt, self.PROJECT_ID, scanners_dict)
        return self._make_request(json.dumps(request_body), url, files=files, async_mode=is_async, callback=callback)
    

class OutputGuard(InputGuard):
    """Scan LLM outputs, optionally alongside their prompt, with the configured scanners.

    Requests are sent to ``{remote_addr}/guard/prompt-output``. Scanners whose
    class name contains "Image" are skipped.

    Args:
        API_KEY (str): The API key used for authentication.
        PROJECT_ID (str): The project identifier.
        remote_addr (str, optional): Base URL of the TS API. Defaults to REMOTE_TS_API_ADDRESS.
        fail_fast (bool, optional): Sent as ``config.fail_fast`` with each scan. Defaults to True.

    Methods:
        scan(prompt, output, is_async=False, callback=None) -> Union[ScannerResult, Any]:
            Scan ``output`` (plus ``prompt`` where scanners need it). Returns a
            ScannerResult, or a coroutine to await when ``is_async`` is True.
    """

    def __init__(self, API_KEY, PROJECT_ID, remote_addr=REMOTE_TS_API_ADDRESS, fail_fast=True):
        """Initialise the guard with API credentials and configuration.

        Args:
            API_KEY (str): The API key used for authentication.
            PROJECT_ID (str): The project identifier.
            remote_addr (str, optional): Base URL of the TS API. Defaults to REMOTE_TS_API_ADDRESS.
            fail_fast (bool, optional): Sent as ``config.fail_fast`` with each scan. Defaults to True.
        """
        super().__init__(API_KEY, PROJECT_ID, remote_addr, fail_fast=fail_fast)
        self.remote_addr = remote_addr

    def scan(
        self,
        prompt: Optional[str],
        output: Optional[str],
        is_async: bool = False,
        callback: Callback = None,
    ) -> Union[ScannerResult, Any]:
        """Scan an LLM output, and its prompt when scanners need it.

        Args:
            prompt (Optional[str]): The prompt that produced ``output``.
                Required when any scanner sets ``_requires_input_prompt``.
            output (Optional[str]): The LLM response to scan. It is included
                in the request only when truthy.
            is_async (bool, optional): If True, return a coroutine that the
                caller must await. Defaults to False.
            callback (Callback, optional): Called with the result. Used only
                when ``is_async`` is True.

        Returns:
            Union[ScannerResult, Any]: A ScannerResult for synchronous calls,
            or an awaitable when ``is_async`` is True.

        Raises:
            ValueError: If no scanners have been added, or if a scanner needs
                the prompt and none was given.
        """
        if not self.scanners:
            raise ValueError("No scanners have been added.")

        use, needs_prompt = self._scanners_to_dict(self.scanners, request_only=True, multimodal=False)
        if needs_prompt and not prompt:
            raise ValueError(f"Input scanners {needs_prompt} require a input prompt along with LLM output for output scanning.")

        body = self._prepare_request_json(prompt=prompt, project_id=self.PROJECT_ID, scanners=use, output=output)
        endpoint = f'{self.remote_addr}/guard/prompt-output'
        return self._make_request(json.dumps(body), endpoint, async_mode=is_async, callback=callback)
    