import json
import threading
from typing import Any, Callable, Dict, List, Optional, Tuple

import psycopg2
from psycopg2 import ProgrammingError
from psycopg2.extras import LogicalReplicationConnection, StopReplication

from dbqueue.utils import log


class PostgresWalSource:
    """
    Reads WAL changes via PostgreSQL logical replication
    and produces normalized events.

    Pluggable decoder: wal2json / pgoutput.

    ``start()`` blocks while streaming, reconnecting with exponential backoff
    on errors. ``stop()`` therefore has to be invoked from elsewhere (e.g.
    another thread), or Ctrl-C is used. Every decoded event is passed to
    ``on_event``.
    """

    # Output plugins that have a decoder in dbqueue.sources.decoders.
    _SUPPORTED_PLUGINS = ("wal2json", "pgoutput")
    # Upper bound for the exponential reconnect backoff, in seconds.
    _MAX_BACKOFF_SECONDS = 60
    # SQLSTATE duplicate_object: reported when the replication slot exists.
    _DUPLICATE_OBJECT = "42710"

    def __init__(
        self,
        dsn: str,
        slot_name: str = "dbqueue_slot",
        publication: str = "dbqueue_pub",
        output_plugin: str = "wal2json",
        on_event: Optional[Callable[[Dict[str, Any]], None]] = None,
    ) -> None:
        """
        Args:
            dsn: libpq connection string for the replication connection.
            slot_name: logical replication slot to create/reuse.
            publication: publication name (used by the pgoutput plugin only).
            output_plugin: "wal2json" or "pgoutput".
            on_event: callback invoked once per decoded event dict.
        """
        self.dsn = dsn
        self.slot_name = slot_name
        self.publication = publication
        self.output_plugin = output_plugin
        self.on_event = on_event

        self._conn: Optional[LogicalReplicationConnection] = None
        self._cur = None
        self._stopped = False
        # Set by stop(); lets the reconnect backoff wake up immediately.
        self._stop_event = threading.Event()
        self.decoder = None  # (re)created by start() for every connection

    # ------------------------------------------------------------------

    def start(self) -> None:
        """
        Stream WAL changes until stop() is called.

        Connects, ensures the replication slot exists, starts logical
        replication and hands every message to ``_consume``. Errors while
        connecting or streaming are retried with exponential backoff
        (2, 4, 8, ... seconds, capped at ``_MAX_BACKOFF_SECONDS``).

        Raises:
            ValueError: if ``output_plugin`` is unsupported. Raised before
                any connection is opened, because retrying cannot fix a
                configuration error.
        """
        # Validate configuration up front so it is not retried forever.
        decode_flag, options = self._replication_options()
        decoder_cls = self._decoder_class()
        retry_count = 0

        try:
            while not self._stopped:
                try:
                    self._connect()
                    self._ensure_slot()
                    # A new decoder per connection, so no decoder state is
                    # carried over from a previous replication session.
                    self.decoder = decoder_cls()

                    log(
                        f"📡 Starting WAL streaming: slot={self.slot_name}, "
                        f"publication={self.publication}, plugin={self.output_plugin}"
                    )

                    self._cur.start_replication(
                        slot_name=self.slot_name,
                        decode=decode_flag,
                        options=options,
                    )

                    # Reset counter on successful connection
                    retry_count = 0
                    self._cur.consume_stream(self._consume)

                except StopReplication:
                    # Raised by _consume() after stop() was requested.
                    break
                except KeyboardInterrupt:
                    self.stop()
                    break
                except Exception as e:
                    if self._stopped:
                        # stop() closed the connection under us: not an error.
                        break

                    retry_count += 1
                    wait_time = min(2 ** retry_count, self._MAX_BACKOFF_SECONDS)
                    log(f"❌ Postgres connection lost: {e}. Retrying in {wait_time}s...")

                    self._close()
                    # Like time.sleep(), but returns early when stop() is called.
                    self._stop_event.wait(wait_time)
        finally:
            self._close()

    def stop(self) -> None:
        """Request shutdown: flag the loop, wake any backoff wait, close resources."""
        self._stopped = True
        self._stop_event.set()
        self._close()
        log("⏹ PostgresWalSource stopped.")

    # ------------------------------------------------------------------

    def _replication_options(self) -> Tuple[bool, Dict[str, str]]:
        """
        Return ``(decode, options)`` for ``start_replication``.

        ``decode`` is True for wal2json (textual output) and False for
        pgoutput (binary protocol).

        Raises:
            ValueError: if ``output_plugin`` is not supported.
        """
        if self.output_plugin == "wal2json":
            return True, {
                "format-version": "2",
                "actions": "insert,update,delete",
                "include-pk": "1",
            }
        if self.output_plugin == "pgoutput":
            return False, {
                "proto_version": "1",
                "publication_names": self.publication,
            }
        raise ValueError(f"Unsupported output_plugin: {self.output_plugin}")

    def _decoder_class(self):
        """
        Import and return the decoder class for ``output_plugin``.

        Raises:
            ValueError: if ``output_plugin`` is not supported.
        """
        if self.output_plugin == "wal2json":
            from dbqueue.sources.decoders.wal2json import Wal2JsonDecoder
            return Wal2JsonDecoder
        if self.output_plugin == "pgoutput":
            from dbqueue.sources.decoders.pgoutput import PgOutputDecoder
            return PgOutputDecoder
        raise ValueError(f"Unsupported output_plugin: {self.output_plugin}")

    def _connect(self) -> None:
        """
        Open a fresh logical-replication connection and cursor.

        Any previous connection is closed first, so reconnecting never leaks one.
        """
        self._close()
        self._conn = psycopg2.connect(
            self.dsn,
            connection_factory=LogicalReplicationConnection,
        )
        self._cur = self._conn.cursor()
        log("🔌 Connected to PostgreSQL (logical replication).")

    def _close(self) -> None:
        """
        Best-effort close of the cursor and the connection.

        Each resource is closed independently, so a failing cursor close does
        not leave the connection open. Close errors are ignored because the
        connection may already be broken at this point.
        """
        cur, self._cur = self._cur, None
        conn, self._conn = self._conn, None
        for resource in (cur, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception:
                pass

    def _ensure_slot(self) -> None:
        """Create the replication slot, tolerating one that already exists."""
        try:
            self._cur.create_replication_slot(
                self.slot_name,
                output_plugin=self.output_plugin,
            )
            log(f"🧩 Created replication slot: {self.slot_name}")
        except ProgrammingError as e:
            self._conn.rollback()
            # Prefer the SQLSTATE: the message text is localized by the server.
            already_exists = (
                getattr(e, "pgcode", None) == self._DUPLICATE_OBJECT
                or "already exists" in str(e)
            )
            if already_exists:
                log(f"ℹ️ Replication slot already exists: {self.slot_name}")
            else:
                raise

    # ------------------------------------------------------------------
    # 🔥 Calling the decoder here
    # ------------------------------------------------------------------

    def _consume(self, msg) -> None:
        """
        ``consume_stream`` callback: decode one message, dispatch, then ack it.

        Raises:
            StopReplication: once stop() has been requested, to leave
                ``consume_stream``.
        """
        if self._stopped:
            # Do NOT ack here: this message was never delivered to on_event,
            # so it must remain un-flushed and be replayed after a restart.
            raise StopReplication()

        try:
            events: List[Dict[str, Any]] = self.decoder.decode(msg)
            for evt in events:
                if self.on_event:
                    self.on_event(evt)

        except Exception as e:
            # Best-effort: a message that fails to decode (or whose handler
            # raises) is logged and still acknowledged below, so it is not
            # redelivered.
            log(f"❌ Error decoding WAL message: {e}")

        # Ack the LSN
        msg.cursor.send_feedback(flush_lsn=msg.data_start)

        if self._stopped:
            raise StopReplication()
