import os
import threading
import time
import warnings

import duckdb

from pyulysses.connector.pyarrow_client import (
    connect_to_dremio_flight_server_endpoint,
)
from pyulysses.env_var.ulysses_env_var import get_string_env_var
from pyulysses.logger.ulysses_logger import UlyssesLogger, set_logging_level

# Suppress Arrow buffer alignment warnings: ignore every UserWarning raised
# from a module whose fully qualified name starts with 'pyarrow'.
warnings.filterwarnings('ignore', category=UserWarning, module='pyarrow')
# Ask Arrow to use the system allocator for its default memory pool.
# NOTE(review): Arrow presumably reads this variable when its default pool is
# first created; the pyulysses connector imported above may already have
# loaded pyarrow, so confirm this assignment still takes effect here.
os.environ['ARROW_DEFAULT_MEMORY_POOL'] = 'system'

# Module logger. The level comes from the environment variable named by
# UlyssesLogger.LOGGING_LEVEL when it is set (and non-empty), otherwise from
# UlyssesLogger.DEFAULT_LOGGING_LEVEL.
LOGGER = set_logging_level(
    level_string=(
        get_string_env_var(UlyssesLogger.LOGGING_LEVEL, is_mandatory=False)
        or UlyssesLogger.DEFAULT_LOGGING_LEVEL
    ),
    module_name=__name__,
)


class Client:
    """
    The Client class represents a client for the Dremio server.

    Queries are sent through ``connect_to_dremio_flight_server_endpoint``.
    The result is returned either unchanged (``arrow_format=True``) or
    converted to a DuckDB relation.

    Attributes:
    username (str): The username for authentication.
    token (str): The token for authentication.
    host (str): The hostname of the Dremio server.
    port (int): The port number of the Dremio server.
    """

    def __init__(self, username, token, host=None, port=None):
        """
        The constructor for Client class.

        Parameters:
        username (str): The username for authentication.
        token (str): The token for authentication.
        host (str): The hostname of the Dremio server. If not provided (or
            falsy), it is read from the DREMIO_HOST env var, with
            'dremio.example.com' passed as the default.
        port (int): The port number of the Dremio server. If not provided (or
            falsy), it is read from the DREMIO_PORT env var, with '9047'
            passed as the default, and converted to int.
        """
        self.username = username
        self.token = token
        self.host = host or get_string_env_var(
            'DREMIO_HOST', default='dremio.example.com', is_mandatory=False
        )
        # NOTE(review): 9047 is Dremio's usual web/REST port, whereas Arrow
        # Flight commonly listens on 32010 -- confirm the intended default.
        self.port = port or int(
            get_string_env_var(
                'DREMIO_PORT', default='9047', is_mandatory=False
            )
        )

    def set_host(self, host=None, port=None):
        """
        The function to set a new host and/or port for the Dremio server.

        Arguments left as None keep their current value.

        Parameters:
        host (str): The new hostname of the Dremio server. Defaults to None.
        port (int): The new port number of the Dremio server. Defaults to None.
        """
        if host is not None:
            self.host = host
        if port is not None:
            self.port = port

        LOGGER.info(f'[info] switched to {self.host} on port {self.port}')

    def query(
        self, query, retries=1, delay=2, query_timeout=60, arrow_format=False
    ):
        """
        Query Dremio, retrying failed attempts after a fixed delay.

        Parameters:
        query (str): The query to be executed.
        retries (int): Number of retries after the first attempt; must be in
            the range 0-10. Defaults to 1.
        delay (int): Delay in seconds between attempts. Defaults to 2.
        query_timeout (int): Timeout for each attempt in seconds. Defaults to 60.
        arrow_format (bool): If True, return the object produced by the Arrow
            Flight connector unchanged. If False (default), convert the output
            to a DuckDB relation.

        Returns:
        duckdb.DuckDBPyRelation or the raw Arrow Flight result. Returns None
        when ``retries`` is out of range or every attempt failed.
        """
        if retries > 10 or retries < 0:
            LOGGER.error(
                'Number of retries must not be lower than 0 and must be up to 10.'
            )
            return None

        attempts = retries + 1  # the initial attempt plus the retries
        for attempt in range(attempts):
            try:
                LOGGER.debug(f'[Retry {attempt + 1}] Querying Dremio: {query}')
                return self.query_with_timeout(
                    query, query_timeout, arrow_format
                )
            except Exception as e:
                LOGGER.error(f'Querying Dremio failed. Reason: {e}')

            # Do not sleep after the final failed attempt.
            if attempt < attempts - 1:
                LOGGER.info(f'Retrying in {delay} seconds...')
                time.sleep(delay)

        LOGGER.error('All retry attempts failed.')
        return None

    def query_with_timeout(self, query, timeout, arrow_format):
        """
        Execute a query in a worker thread, waiting at most ``timeout`` seconds.

        The worker stores either its result or the exception it raised. The
        calling thread then re-raises that exception, so callers such as
        ``query`` see the failure and can retry.

        Parameters:
            query (str): The query to be executed.
            timeout (int): The maximum time (in seconds) to wait for the query to complete.
            arrow_format (bool): If True, return the connector's result unchanged.
                If False, call ``read_all()`` on it and convert it with ``duckdb.arrow``.

        Returns:
            Any: The query result. It is None only if the connector itself produced None.

        Raises:
            TimeoutException: If the query is still running after ``timeout`` seconds.
                The worker thread cannot be cancelled and keeps running in the background.
            Exception: Whatever the query raised inside the worker thread.
        """
        result = [None]
        exception = [None]

        def target():
            try:
                reader = connect_to_dremio_flight_server_endpoint(
                    self.host,
                    self.port,
                    self.username,
                    self.token,
                    query,
                    True,
                    False,
                    True,
                    False,
                    False,
                    False,
                )
                if arrow_format:
                    # NOTE(review): the DuckDB path calls .read_all() on this
                    # same object, so it is presumably a reader rather than a
                    # pyarrow.Table -- confirm what callers expect.
                    result[0] = reader
                else:
                    # Convert Arrow table to DuckDB relation
                    result[0] = duckdb.arrow(reader.read_all())
            except Exception as e:
                LOGGER.error(
                    f'Query failed with error. Query: {query}, Error: {e}'
                )
                # Hand the error to the calling thread. Raising here would only
                # reach threading.excepthook, and the caller would get None.
                exception[0] = e

        # Daemon thread, so a hung query does not block interpreter shutdown.
        thread = threading.Thread(target=target, daemon=True)
        thread.start()

        start_time = time.monotonic()
        thread.join(timeout)
        elapsed_time = time.monotonic() - start_time

        if thread.is_alive():
            raise TimeoutException(f'Query timed out after {timeout} seconds')
        if exception[0] is not None:
            raise exception[0]
        if result[0] is not None:
            LOGGER.info(f'Query executed in {elapsed_time:.2f} seconds')
        return result[0]

    def list_tables(self):
        """
        The function to list all tables in the Dremio server.

        Returns:
        duckdb.DuckDBPyRelation: The list of tables, or None if the query failed.
        """
        query = 'SELECT * FROM INFORMATION_SCHEMA."TABLES"'
        return self.query(query)

    def list_columns(self, table):
        """
        The function to list all columns in a table.

        Parameters:
        table (str): The name of the table.

        Returns:
        duckdb.DuckDBPyRelation: The list of columns in the table, or None if
        the query failed.
        """
        # The name is interpolated into a SQL string literal, not bound as a
        # parameter. Double any single quotes so it cannot terminate the
        # literal early.
        escaped_table = str(table).replace("'", "''")
        query = f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME='{escaped_table}'"
        return self.query(query)


class TimeoutException(Exception):
    """Raised when a Dremio query does not finish within its allotted timeout."""
