"""
A utility that executes test cases generated by `dorieh.platform.dbt.create_test.py`
tool.
"""

#  Copyright (c) 2023.  Harvard University
#
#   Developed by Research Software Engineering,
#   Harvard University Research Computing and Data (RCD) Services.
#
#   Author: Michael A Bouzinier
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#          http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
#

import logging
import os.path
from typing import List
from abc import abstractmethod, ABC

from dorieh.platform import init_logging
from dorieh.platform.db import Connection
from dorieh.platform.dbt.dbt_config import DBTConfig


class TestFailedError(Exception):
    """Raised by ``DBTRunner.test`` when one or more test cases failed."""


class DBTRunner(ABC):
    """
    Base class for runners that execute DBT test scripts and report results.

    Subclasses implement :meth:`run`, which executes every script listed in
    ``self.scripts`` and passes each result set to :meth:`analyze_results`.
    That method logs a table of test cases and accumulates the
    ``runs``/``successes``/``failures`` counters that :meth:`test` checks.
    """

    def __init__(self, context: DBTConfig = None):
        """
        :param context: runner configuration. When omitted, one is created
            with ``DBTConfig(None, __doc__).instantiate()`` (presumably from
            command-line arguments -- confirm against DBTConfig).
        """
        if not context:
            context = DBTConfig(None, __doc__).instantiate()
        self.context = context
        self.scripts = self.context.script
        # Script file names without directory or extension; used to name
        # the log produced by this run
        self.test_names = [
            os.path.splitext(os.path.basename(t))[0] for t in self.scripts
        ]
        init_logging(name="run-tests-" + "-".join(self.test_names))
        self.runs = 0
        self.successes = 0
        self.failures = 0

    def reset(self):
        """Zero the run/success/failure counters."""
        self.runs = 0
        self.successes = 0
        self.failures = 0

    def analyze_results(self, columns: List, rows: List):
        """
        Log a table of test-case results and update the counters.

        :param columns: column names of the result set; must contain
            ``"passed"``
        :param rows: result rows, each indexable by column position; a row
            counts as passed when its ``passed`` value is truthy
        :raises ValueError: if ``columns`` contains no ``"passed"`` column
        """
        pi = columns.index("passed")
        n = len(columns)
        # Start from the header widths so the header lines up with the rows
        lengths = [len(str(c)) for c in columns]
        passes = 0
        failures = 0
        test_cases = []
        for row in rows:
            values = [row[i] for i in range(n)]
            if row[pi]:
                passes += 1
                values[pi] = "passed"
            else:
                failures += 1
                values[pi] = "failed"
            for i, value in enumerate(values):
                # Cells may be numbers, None, dates, ...: measure the string
                # form, which is exactly what report_row prints
                lengths[i] = max(lengths[i], len(str(value)))
            test_cases.append(values)
        lengths = [l + 1 for l in lengths]
        logging.info(self.report_row(columns, lengths))
        for values in test_cases:
            s = self.report_row(values, lengths)
            # The "passed" cell was normalized above, so it is always
            # either "passed" or "failed"
            if values[pi] == "passed":
                logging.info(s)
            else:
                logging.error(s)
        logging.info("Passed: {:d}; Failed: {:d}".format(passes, failures))
        self.runs      += len(test_cases)
        self.successes += passes
        self.failures  += failures

    @abstractmethod
    def run(self):
        """Execute all configured scripts, feeding results to analyze_results."""
        pass

    @classmethod
    def report_row(cls, row: List, lengths: List[int]) -> str:
        """
        Format one table row.

        Each cell is converted to ``str``, left-justified to its column
        width and followed by a tab. Only the first ``len(lengths)`` cells
        are printed.
        """
        return "".join(
            str(cell).ljust(width) + '\t' for cell, width in zip(row, lengths)
        )

    def test(self):
        """
        Reset the counters, execute :meth:`run` and check the outcome.

        :raises TestFailedError: if any test case failed
        """
        self.reset()
        self.run()
        if self.failures > 0:
            err = TestFailedError(f"There are {self.failures} failures")
            logging.error(f"Tests FAILED: {err}")
            raise err
        logging.info("All tests succeeded")

    @classmethod
    def form_query(cls, script_file) -> str:
        """Return the full text of the SQL script at ``script_file``."""
        with open(script_file) as script:
            return script.read()


class PGDBTRunner(DBTRunner):
    """
    DBT test runner that executes scripts through
    ``dorieh.platform.db.Connection``.
    """

    def run(self):
        """Execute every configured script over one shared connection."""
        db, section = self.context.db, self.context.connection
        with Connection(db, section) as conn:
            for path in self.scripts:
                sql = self.form_query(path)
                self.run_script(sql, conn)

    def run_script(self, query, cnxn):
        """
        Execute ``query`` and pass its result set to ``analyze_results``.

        :param query: SQL text of one test script
        :param cnxn: open connection object providing ``cursor()``
        """
        with cnxn.cursor() as cursor:
            cursor.execute(query)
            # The first element of each DB-API description entry is the name
            names = [entry[0] for entry in cursor.description]
            self.analyze_results(names, list(cursor))


if __name__ == '__main__':
    # Build the runner from the default configuration and run all tests
    PGDBTRunner().test()
