"""
Connects to PostgreSQL database
"""

import os
import psycopg2
import sys
import platform
from dotenv import load_dotenv
from psycopg2.extras import execute_values
from cryptography.fernet import Fernet


# Optional per-user settings file (e.g. PG_HOST, PG_PASSWD) loaded into the
# process environment so Connect() can pick the values up via os.getenv.
# os.path.expanduser resolves the home directory portably (USERPROFILE or
# HOMEDRIVE+HOMEPATH on Windows, HOME or the pwd database on POSIX). HOMEPATH
# alone lacks the drive letter, and an unset variable made os.path.join raise
# TypeError at import time. If no home can be resolved, '~' is returned
# unchanged and the isfile check below simply fails.
home_dir = os.path.expanduser('~')

env_file = os.path.join(home_dir, '.pgenv')
if os.path.isfile(env_file):
    load_dotenv(env_file)



class Connect:
    """
    Convenience wrapper around one psycopg2 connection and its cursor.

    Most helpers operate on ``{schema}.{table}`` given at construction time.
    Schema, table and column names (and the raw ``where``/``condition``
    clauses some methods accept) are inserted into the SQL text verbatim, so
    they must come from trusted code. Values are sent as bound parameters
    wherever the method signature separates them from the SQL.
    """

    def __init__(self, schema: str = None, table: str = None,
                 host: str = None,
                 port: int = None,
                 db: str = None,
                 user: str = None,
                 password: str = None
                 ) -> None:
        """
        Establish connection to PostgreSQL server.

        :param schema: optional. Schema name; required by the table helpers.
        :param table: optional. Table name; required by the table helpers.
        :param host: server hostname, or set the PG_HOST env variable.
        :param port: server port, or set the PG_PORT env variable.
        :param db: database name, or set the PG_DB env variable.
        :param user: username, or set the PG_USER env variable.
        :param password: password, or set the PG_PASSWD env variable. A
            PG_PASSWD value containing ``=@`` is treated as
            ``<fernet key>@<token>`` and decrypted with Fernet.

        If the connection fails, the psycopg2 error is printed to stderr and
        the process exits with status 1.

        Example:
            db = pyconnpg.Connect(
                host='hostname',
                db='database',
                user='username',
                password='Password',
            )

        Connecting to a specific table:
            db = pyconnpg.Connect(schema='schema_name', table='table_name')
        """
        self.schema = schema
        self.table = table

        # Explicit arguments win; anything left as None falls back to the
        # environment (which may have been populated from ~/.pgenv).
        self.host = host if host is not None else os.getenv('PG_HOST')
        self.port = port if port is not None else os.getenv('PG_PORT')
        self.db = db if db is not None else os.getenv('PG_DB')
        self.user = user if user is not None else os.getenv('PG_USER')
        if password is None:
            password = self._resolve_env_password(os.getenv('PG_PASSWD'))
        self.password = password

        try:
            self.conn = psycopg2.connect(host=self.host,
                                         port=self.port,
                                         database=self.db,
                                         user=self.user,
                                         password=self.password)
            self.cursor = self.conn.cursor()
        except psycopg2.Error as e:
            # NOTE(review): exiting from a library constructor is harsh, but
            # existing callers may rely on it, so it is kept as-is.
            print(e, file=sys.stderr)
            print('ERROR: PGSQL connection failed.', file=sys.stderr)
            sys.exit(1)

    @staticmethod
    def _resolve_env_password(raw):
        """
        Turn a PG_PASSWD value into the password passed to psycopg2.

        Returns None when the variable is unset; the connection is then
        attempted with ``password=None``. A value containing ``=@`` is
        treated as ``<fernet key>@<token>`` (a Fernet key is url-safe base64
        ending in ``=``) and is decrypted. Any other value is returned
        unchanged.
        """
        if raw is None:
            return None
        if '=@' in raw:
            # Split on the first '@' only; url-safe base64 keys never contain '@'.
            key, token = raw.split('@', 1)
            return Fernet(key).decrypt(token).decode()
        return raw

    def __enter__(self) -> "Connect":
        """Allow ``with Connect(...) as db:``; the connection closes on exit."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def _execute(self, query: str, params=None, commit: bool = False) -> None:
        """
        Run one statement on the shared cursor, optionally committing.

        If the statement fails with a database error, the transaction is
        rolled back before re-raising. Otherwise the connection stays in an
        aborted transaction and every later statement on it fails.
        ``params=None`` means psycopg2 performs no placeholder substitution,
        so a literal '%' in ``query`` is sent unchanged.
        """
        try:
            self.cursor.execute(query, params)
            if commit:
                self.conn.commit()
        except psycopg2.Error:
            self.conn.rollback()
            raise

    @property
    def version(self) -> str:
        """
        Returns:
            str: the server's ``SELECT VERSION()`` string.
        """
        self._execute('SELECT VERSION()')
        return self.cursor.fetchone()[0]

    def check_exists(self, column_name: str, find) -> bool:
        """
        Return True if any row has ``column_name`` equal to ``find``, else False.

        column_name: column to search (trusted identifier).
        find: value to look for. Its ``str()`` text is bound as a parameter,
            so quotes inside the value cannot break the query.
        """
        query = (f"select exists(select 1 from {self.schema}.{self.table} "
                 f"where {column_name}=%s)")
        self._execute(query, (str(find),))
        return self.cursor.fetchone()[0]

    def check_exists_long(self, where: str) -> bool:
        """
        Return True if any row matches ``where``, else False.

        where: raw SQL condition (may combine several AND/OR clauses). It is
            inserted verbatim, so it must come from trusted code.
        """
        query = f"select exists(select 1 from {self.schema}.{self.table} where {where})"
        self._execute(query)
        return self.cursor.fetchone()[0]

    def insert(self, **kwargs):
        """
        Insert one row given as ``column_name=value`` keyword arguments, then commit.

        Column names must be plain identifiers (not quoted); values are bound
        as parameters.

        Ex. db.insert(product='ProductName', version=3)

        Raises ValueError if no columns are given.
        """
        if not kwargs:
            raise ValueError('insert() needs at least one column=value pair')
        columns = ', '.join(kwargs)
        placeholders = ', '.join(['%s'] * len(kwargs))
        self._execute(
            f"INSERT INTO {self.schema}.{self.table} ({columns}) VALUES ({placeholders})",
            tuple(kwargs.values()),
            commit=True,
        )

    def batch_insert(self, data: list[dict]) -> bool:
        """
        Batch insert for PostgreSQL using ``psycopg2.extras.execute_values``.

        Expects a list of dicts sharing the same keys (the column names):
            [
                {"id": "2025-10-21_1", "date": "2025-10-21", "users": "user1", "usage": 3},
                {"id": "2025-10-21_2", "date": "2025-10-21", "users": "user2", "usage": 5},
            ]
        Columns are taken from the first record. A later record missing one
        of them raises KeyError, and extra keys are ignored.

        Example: add dict records to a list
            records_to_insert = []
            for count, (user, usage) in enumerate(checkout_users.items(), start=1):
                records_to_insert.append({
                    "id": f"{DT_UTC}_{count}",
                    "date": DT_UTC,
                    "users": user,
                    "usage": usage,
                })
            db.batch_insert(records_to_insert)

        Returns True after committing, or False when ``data`` is empty.
        """
        if not data:
            return False
        columns = list(data[0])
        col_names = ', '.join(columns)
        # list of dicts -> list of value tuples in column order
        values = [tuple(record[col] for col in columns) for record in data]
        # execute_values expands the single %s into the VALUES list.
        sql = f"INSERT INTO {self.schema}.{self.table} ({col_names}) VALUES %s"
        try:
            execute_values(self.cursor, sql, values)
            self.conn.commit()
        except psycopg2.Error:
            self.conn.rollback()
            raise
        return True

    def insert_many(self, data: list[dict]) -> bool:
        """
        Batch insert for PostgreSQL using ``cursor.executemany``.

        Expects a list of dicts sharing the same keys (the column names):
            [
                {"id": "2025-10-21_1", "date": "2025-10-21", "users": "user1", "usage": 3},
                {"id": "2025-10-21_2", "date": "2025-10-21", "users": "user2", "usage": 5},
            ]
        Columns are taken from the first record. A later record missing one
        of them raises KeyError, and extra keys are ignored.

        Returns True after committing, or False when ``data`` is empty.
        """
        if not data:
            return False
        columns = list(data[0])
        col_names = ', '.join(columns)
        placeholders = ', '.join(['%s'] * len(columns))
        values = [tuple(record[col] for col in columns) for record in data]
        sql = f"INSERT INTO {self.schema}.{self.table} ({col_names}) VALUES ({placeholders})"
        try:
            self.cursor.executemany(sql, values)
            self.conn.commit()
        except psycopg2.Error:
            self.conn.rollback()
            raise
        return True

    def insert_with_conflict_check(self, check_conflict, **kwargs):
        """
        Insert one row unless it conflicts on ``check_conflict``, then commit.

        check_conflict: conflict target column(s), typically the primary key.
        kwargs: ``column_name=value`` pairs. Column names must not be quoted;
            values are bound as parameters.

        Ex. db.insert_with_conflict_check(check_conflict='column_pkey', product='Versionvault', version=3)

        Raises ValueError if no columns are given.
        """
        if not kwargs:
            raise ValueError('insert_with_conflict_check() needs at least one column=value pair')
        columns = ', '.join(kwargs)
        placeholders = ', '.join(['%s'] * len(kwargs))
        query = (
            f"INSERT INTO {self.schema}.{self.table} ({columns}) "
            f"VALUES ({placeholders}) ON CONFLICT ({check_conflict}) DO NOTHING"
        )
        self._execute(query, tuple(kwargs.values()), commit=True)

    def update_row(self, column_name: str, value, where: str):
        """
        Set ``column_name`` to ``value`` on the rows matching ``where``, then commit.

        column_name: column to update (trusted identifier).
        value: new value. Its ``str()`` text is bound as a parameter, so do
            NOT wrap strings in quotes yourself.
        where: raw SQL condition selecting the rows, e.g. ``id=1``. It is
            inserted verbatim, so it must come from trusted code.
        """
        # The statement carries a bound parameter, so psycopg2 treats '%' as
        # a placeholder marker. Double any '%' in the raw WHERE text
        # (e.g. LIKE 'a%') so it reaches the server unchanged.
        safe_where = where.replace('%', '%%')
        self._execute(
            f"UPDATE {self.schema}.{self.table} SET {column_name}=%s WHERE {safe_where}",
            (str(value),),
            commit=True,
        )

    def add_column(self, column_name: str, col_type: str):
        """
        Add a column to the table if it does not already exist, then commit.

        column_name: column name to add (trusted identifier).
        col_type: SQL type of the column, e.g. varchar, integer or date.
        """
        self._execute(
            f"ALTER TABLE {self.schema}.{self.table} ADD COLUMN IF NOT EXISTS {column_name} {col_type}",
            commit=True,
        )

    def get_value_match(self, column_name: str, where_column: str, match: str) -> str:
        """
        Return ``column_name`` from the first row whose ``where_column`` equals ``match``.

        ``match`` is compared as its ``str()`` text, bound as a parameter.
        Returns None when no row matches.
        """
        self._execute(
            f"Select {column_name} from {self.schema}.{self.table} where {where_column}=%s",
            (str(match),),
        )
        row = self.cursor.fetchone()
        return row[0] if row is not None else None

    def delete_row(self, condition: str):
        """
        Delete the rows matching ``condition``, then commit.

        Runs ``DELETE FROM {schema}.{table} WHERE {condition}``. The condition
        is inserted verbatim, so it must come from trusted code.
        """
        self._execute(f"DELETE FROM {self.schema}.{self.table} WHERE {condition}", commit=True)

    def close(self) -> None:
        """
        Close the cursor and then the connection.
        """
        self.cursor.close()
        self.conn.close()
        