import json
import threading
import pika
from typing import Callable, Dict
from dbqueue.utils import log


class RabbitConsumer:
    """Consume formatted CDC events from RabbitMQ and dispatch them to handlers.

    ``start()`` launches a daemon thread. The thread declares a durable queue,
    binds it to ``exchange`` with the wildcard routing key ``"#"`` and consumes
    from it, reconnecting after any failure.

    A delivery whose routing key has at least three dot-separated parts
    (``schema.table.op``) is JSON-decoded and passed to
    ``handlers[OP][table]``, where ``OP`` is the upper-cased third part.

    Every delivery is acknowledged, including ones whose handler raised or
    whose body failed to decode. A poison message is therefore logged and
    dropped rather than redelivered forever; a dead-letter exchange would be
    needed to keep such messages.

    Args:
        url: AMQP URL, passed to ``pika.URLParameters``.
        exchange: Exchange the consumer queue is bound to.
        routing_key_template: Stored as ``self.routing_key_template``; this
            class does not read it.
        handlers: ``{OP: {table: handler}}``; each handler is called with the
            decoded event.
        queue_name: Durable queue that is declared and consumed from.
        reconnect_delay: Seconds to wait before reconnecting after a failure.
    """

    def __init__(
        self,
        url: str,
        exchange: str,
        routing_key_template: str,
        handlers: Dict[str, Dict[str, Callable]],
        *,
        queue_name: str = "dbqueue.consumer",
        reconnect_delay: float = 5.0,
    ):
        self.url = url
        self.exchange = exchange
        self.routing_key_template = routing_key_template
        self.handlers = handlers
        self.queue_name = queue_name
        self.reconnect_delay = reconnect_delay
        # An Event rather than a bool flag, so stop() also wakes up the
        # reconnect back-off wait in _run() immediately.
        self._stop_event = threading.Event()
        # Set by start(); kept so callers can inspect or join the thread.
        self._thread = None

    def start(self):
        """Start consuming on a background daemon thread."""
        self._thread = threading.Thread(
            target=self._run, name="RabbitConsumer", daemon=True
        )
        self._thread.start()
        log("🟢 Rabbit consumer started")

    def stop(self):
        """Ask the consumer thread to exit; does not wait for it.

        The consume loop re-checks the stop request after each
        ``process_data_events`` poll (1 second time limit, plus any handler
        still running). A pending reconnect wait is interrupted at once.
        """
        self._stop_event.set()
        log("🔴 Rabbit consumer stopping...")

    def _run(self):
        """Thread body: consume until stopped, backing off after each failure."""
        while not self._stop_event.is_set():
            try:
                self._consume_loop()
            except Exception as e:
                log(
                    f"❌ Consumer connection lost: {e}. "
                    f"Reconnecting in {self.reconnect_delay:g}s..."
                )
                # Returns early if stop() is called during the back-off.
                self._stop_event.wait(self.reconnect_delay)

        log("⛔ Rabbit consumer closed")

    def _consume_loop(self):
        """Open one connection, consume until stopped, and always close it.

        Connection and channel errors propagate to ``_run``, which logs them
        and backs off. If the connection closes without raising, a
        ``ConnectionError`` is raised so the same path is taken.
        """
        connection = pika.BlockingConnection(pika.URLParameters(self.url))
        try:
            channel = connection.channel()
            channel.queue_declare(queue=self.queue_name, durable=True)
            channel.queue_bind(
                queue=self.queue_name,
                exchange=self.exchange,
                routing_key="#",
            )
            channel.basic_consume(
                queue=self.queue_name, on_message_callback=self._on_message
            )

            while not self._stop_event.is_set():
                if not connection.is_open:
                    raise ConnectionError("RabbitMQ connection closed")
                # Deliveries are dispatched to _on_message from inside this call.
                connection.process_data_events(time_limit=1)
        finally:
            # Runs on every exit path, so a failure during channel/queue setup
            # does not leave an open connection behind on each retry.
            if connection.is_open:
                try:
                    connection.close()
                except Exception as close_err:
                    # Don't let a close failure mask the original exception.
                    log(f"⚠️ Error while closing Rabbit connection: {close_err}")

    def _on_message(self, ch, method, properties, body):
        """pika delivery callback: dispatch the message, then always ack it."""
        try:
            self._dispatch(method.routing_key, body)
        except Exception as e:
            # Acked below anyway. Not acking (or rejecting with requeue) would
            # redeliver a poison message forever; a DLX could retain it instead.
            log(f"❌ Consumer processing error: {e}")
        ch.basic_ack(delivery_tag=method.delivery_tag)

    def _dispatch(self, routing_key, body):
        """Decode *body* as JSON and call the handler selected by *routing_key*.

        Routing keys are expected as ``schema.table.op``; extra parts are
        ignored. Keys with fewer than three parts, or with no registered
        handler, are skipped. Handler exceptions are logged here and not
        re-raised. JSON decode errors are raised to the caller.
        """
        event = json.loads(body)  # formatted CDC event
        parts = routing_key.split(".")
        if len(parts) < 3:
            return

        table, op = parts[1], parts[2].upper()
        table_handlers = self.handlers.get(op)
        if table_handlers is None or table not in table_handlers:
            return

        try:
            table_handlers[table](event)
        except Exception as handler_err:
            log(f"❌ User handler error for {table}: {handler_err}")
