import concurrent
import concurrent.futures
import time

import requests

from adam.commands.command import Command
from adam.config import Config
from adam.repl_state import ReplState
from adam.utils import log, log2
from adam.utils_athena import audit_query, run_audit_query

class AuditRepairTables(Command):
    """REPL command ``audit repair``: runs ``MSCK REPAIR TABLE`` on audit tables.

    Besides the explicit command, the first completion request on the
    ``ReplState.L`` device schedules a one-time background auto-repair. The
    auto-repair only repairs when the last check-in read from the ``meta``
    table is older than ``audit.athena.auto-repair.elapsed_hours`` hours.
    """
    COMMAND = 'audit repair'

    # the singleton pattern
    def __new__(cls, *args, **kwargs):
        if not hasattr(cls, 'instance'):
            cls.instance = super(AuditRepairTables, cls).__new__(cls)

        return cls.instance

    def __init__(self, successor: Command=None):
        super().__init__(successor)
        # __new__ hands back the shared instance, but __init__ still runs on every
        # AuditRepairTables() call; keep an existing flag so re-construction does
        # not re-trigger the auto-repair.
        self.auto_repaired = getattr(self, 'auto_repaired', False)

    def command(self):
        return AuditRepairTables.COMMAND

    def run(self, cmd: str, state: ReplState):
        """Repair the tables named in *cmd*, or the configured tables if none are given.

        When ``self.args(cmd)`` yields nothing, *cmd* is passed to ``super().run``
        (presumably the next command in the chain).

        Returns:
            The (possibly updated) REPL state.
        """
        if not (args := self.args(cmd)):
            return super().run(cmd, state)

        state, args = self.apply_state(args, state)

        # NOTE(review): this reads 'audit.tables' while auto_repair reads
        # 'audit.athena.tables' -- confirm which key is intended.
        tables = args if args else Config().get('audit.tables', 'audit').split(',')
        self.repair(tables)

        return state

    def completion(self, state: ReplState):
        """Return completions for this command; may kick off the one-time auto-repair.

        Only the ``ReplState.L`` device gets completions (and auto-repair); other
        devices get ``{}``.
        """
        if state.device != ReplState.L:
            return {}

        if not self.auto_repaired:
            if hours := Config().get('audit.athena.auto-repair.elapsed_hours', 12):
                # Mark before scheduling so completion requests arriving while the
                # repair is still running do not schedule another one.
                self.auto_repaired = True
                executor = concurrent.futures.ThreadPoolExecutor(
                    max_workers=1, thread_name_prefix='audit-auto-repair')
                future = executor.submit(self.auto_repair, hours)
                future.add_done_callback(AuditRepairTables._log_auto_repair_failure)
                # Don't wait: the repair finishes in the background while the REPL
                # keeps serving completions. (A `with` block here would block until
                # the Athena queries finished.)
                # NOTE(review): log output from the background thread may
                # interleave with the interactive prompt.
                executor.shutdown(wait=False)

        return super().completion(state)

    @staticmethod
    def _log_auto_repair_failure(future: concurrent.futures.Future):
        # Surface exceptions from the background auto-repair instead of dropping them.
        if not future.cancelled() and (e := future.exception()) is not None:
            log2(f'Audit auto-repair failed: {e}')

    def auto_repair(self, hours: int):
        """Repair the audit tables if the last check-in is more than *hours* hours old.

        The check-in time is read from the ``meta`` table. If that query does not
        succeed, nothing is done. If it succeeds but the timestamp is missing or
        unreadable, the repair is performed anyway.
        """
        self.auto_repaired = True

        query_state, _, rs = audit_query('select * from meta')
        if query_state != 'SUCCEEDED':
            return

        do_repair = True
        # rs[0] is presumably the header row; the first column of the first data
        # row holds the last check-in as epoch seconds (compared to time.time()).
        if len(rs) > 1:
            try:
                checked_in = float(rs[1]['Data'][0]['VarCharValue'])
                do_repair = checked_in + float(hours) * 60 * 60 < time.time()
            except (KeyError, IndexError, TypeError, ValueError):
                # Missing or malformed timestamp (or hours): fall back to repairing.
                pass

        if do_repair:
            tables = Config().get('audit.athena.tables', 'audit').split(',')
            self.repair(tables, show_sql=True)
            log2('Audit tables have been auto-repaired.')

    def repair(self, tables: list[str], show_sql = False):
        """Run ``MSCK REPAIR TABLE`` for each table in parallel, then check in.

        Args:
            tables: Table names; blank entries are skipped.
            show_sql: Log each statement before running it.

        A failed repair query is logged and does not stop the others. The check-in
        runs only after every repair query has finished.
        """
        names = [t.strip() for t in tables if t and t.strip()]

        with concurrent.futures.ThreadPoolExecutor(max_workers=Config().get('audit.workers', 3)) as executor:
            pending = {}
            for table in names:
                sql = f'MSCK REPAIR TABLE {table}'
                if show_sql:
                    log(sql)
                pending[executor.submit(run_audit_query, sql, None)] = table

            for future in concurrent.futures.as_completed(pending):
                if (e := future.exception()) is not None:
                    log2(f'MSCK REPAIR TABLE {pending[future]} failed: {e}')

        # Check in after the repairs, so the check-in never precedes the work it records.
        self.check_in()

    def help(self, _: ReplState):
        return f"{AuditRepairTables.COMMAND} \t run MSCK REPAIR command for new partition discovery"

    def check_in(self):
        """POST a ``check-in`` action to the audit endpoint.

        Network failures and non-200/201 responses are logged via ``log2``, not raised.
        """
        payload = {
            'action': 'check-in'
        }
        audit_endpoint = Config().get("audit.endpoint", "https://4psvtaxlcb.execute-api.us-west-2.amazonaws.com/prod/")
        try:
            response = requests.post(audit_endpoint, json=payload, timeout=Config().get("audit.timeout", 10))
        except requests.exceptions.Timeout as e:
            log2(f"Timeout occurred: {e}")
            return
        except requests.exceptions.RequestException as e:
            # Connection errors etc. were previously uncaught.
            log2(f"Check-in failed: {e}")
            return

        if response.status_code in [200, 201]:
            Config().debug(response.text)
        else:
            log2(f"Error: {response.status_code} {response.text}")