"""
Message Bus implementation with pluggable backends.

Provides RedisBus (production), MemoryBus (testing/dev) and KafkaBus (experimental) implementations.
All agent communication and task distribution happens through this abstraction.
"""

from __future__ import annotations

import asyncio
from dataclasses import dataclass, field
from datetime import datetime
import json
import logging
import time
from typing import Any
import uuid


try:
    import redis.asyncio as aioredis
except ImportError:
    aioredis = None

try:
    from aiokafka import AIOKafkaConsumer, AIOKafkaProducer  # type: ignore
except ImportError:  # optional dependency
    AIOKafkaProducer = None  # type: ignore
    AIOKafkaConsumer = None  # type: ignore


@dataclass
class TaskMessage:
    """A unit of work addressed to a single agent's task queue."""

    # Identifier used to correlate the task with its eventual response.
    task_id: str
    # Name of the agent whose queue the task is placed on.
    agent_name: str
    # Task body supplied by the publisher.
    payload: dict[str, Any]
    # Creation time as a Unix timestamp (seconds); defaults to "now".
    created_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict:
        """Return the message as a plain dict with one entry per field.

        The payload is referenced, not copied.
        """
        return dict(
            task_id=self.task_id,
            agent_name=self.agent_name,
            payload=self.payload,
            created_at=self.created_at,
        )

    @classmethod
    def from_dict(cls, data: dict) -> TaskMessage:
        """Build a message from a dict produced by :meth:`to_dict`.

        Raises:
            KeyError: if ``task_id``, ``agent_name`` or ``payload`` is missing.
        """
        task_id = data["task_id"]
        agent_name = data["agent_name"]
        payload = data["payload"]
        # A missing timestamp falls back to the time of deserialization.
        created_at = data.get("created_at", time.time())
        return cls(task_id, agent_name, payload, created_at)


@dataclass
class ResponseMessage:
    """The outcome an agent reports after handling a task."""

    # Identifier of the task this response belongs to.
    task_id: str
    status: str  # "success" | "error" | "pending"
    # Result data, if any.
    result: dict[str, Any] | None = None
    # Error description, if any.
    error: str | None = None
    # Creation time as a Unix timestamp (seconds); defaults to "now".
    created_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict:
        """Return the response as a plain dict with one entry per field."""
        return dict(
            task_id=self.task_id,
            status=self.status,
            result=self.result,
            error=self.error,
            created_at=self.created_at,
        )

    @classmethod
    def from_dict(cls, data: dict) -> ResponseMessage:
        """Build a response from a dict produced by :meth:`to_dict`.

        ``result`` and ``error`` default to ``None`` when absent.

        Raises:
            KeyError: if ``task_id`` or ``status`` is missing.
        """
        task_id = data["task_id"]
        status = data["status"]
        result = data.get("result")
        error = data.get("error")
        # A missing timestamp falls back to the time of deserialization.
        created_at = data.get("created_at", time.time())
        return cls(task_id, status, result, error, created_at)


class RedisBus:
    """
    Redis-based message bus for production use.

    Tasks are appended to one Redis Stream per agent
    (``evenage:tasks:<agent_name>``). Responses are stored under
    ``evenage:response:<task_id>`` with a TTL and also announced on a pub/sub
    channel of the same name. Agents are registered in the
    ``evenage:agents:registry`` hash as JSON-encoded metadata.

    Optional large-response offload is driven by three attributes that are
    read with ``getattr`` and may be set on the instance by the owner:
    ``_storage`` (object exposing async ``ensure_bucket`` / ``put_object``),
    ``_storage_bucket`` and ``_storage_threshold_kb``.
    """

    _logger = logging.getLogger(__name__)

    # Seconds a stored response stays readable for late waiters.
    _RESPONSE_TTL_SEC = 300

    def __init__(self, redis_url: str):
        self.redis_url = redis_url
        self._client = None

    async def _get_client(self) -> Any:
        """Return the Redis client, creating it on first use.

        Raises:
            RuntimeError: if the ``redis`` package is not installed.
        """
        if self._client is None:
            if aioredis is None:
                raise RuntimeError("redis package not installed. Install with: pip install redis")
            self._client = await aioredis.from_url(self.redis_url, decode_responses=True)
        return self._client

    async def register_agent(self, name: str, metadata: dict) -> bool:
        """
        Register an agent with metadata.

        Metadata should include: role, goal, tools, status, host_url (optional).
        ``last_heartbeat`` is stamped with the current time and ``status``
        defaults to ``"active"``. The caller's dict is not modified.
        """
        client = await self._get_client()

        # Work on a copy so registration has no side effect on the caller.
        record = dict(metadata)
        record["last_heartbeat"] = time.time()
        record.setdefault("status", "active")

        await client.hset("evenage:agents:registry", name, json.dumps(record))
        return True

    async def publish_task(self, agent_name: str, task: dict) -> str:
        """
        Publish a task to an agent's queue.

        Returns the generated task_id for tracking.
        """
        client = await self._get_client()
        task_id = str(uuid.uuid4())
        message = TaskMessage(task_id=task_id, agent_name=agent_name, payload=task)

        # Add to agent's stream
        await client.xadd(
            f"evenage:tasks:{agent_name}",
            {"data": json.dumps(message.to_dict())},
        )
        return task_id

    async def _offload_if_large(self, task_id: str, response: dict) -> dict:
        """Return ``response``, or a small pointer to it if it was offloaded.

        Offload happens only when storage and a threshold are configured and
        the JSON-encoded response exceeds ``threshold_kb * 1024`` bytes. Any
        failure is logged and the original response is returned so it is
        published inline (best effort).
        """
        storage = getattr(self, "_storage", None)
        threshold_kb = getattr(self, "_storage_threshold_kb", None)
        if not storage or not threshold_kb:
            return response

        bucket = getattr(self, "_storage_bucket", None) or "evenage"
        try:
            payload_bytes = json.dumps(response).encode("utf-8")
            if len(payload_bytes) <= threshold_kb * 1024:
                return response

            # Build a key with a UTC date prefix for grouping.
            date_prefix = time.strftime("%Y/%m/%d", time.gmtime())
            key = f"responses/{date_prefix}/{task_id}.json"
            await storage.ensure_bucket(bucket)
            await storage.put_object(
                bucket, key, payload_bytes, metadata={"content-type": "application/json"}
            )
        except Exception:
            # Fallback: keep inline if storage fails.
            self._logger.warning(
                "Offloading response for task %s failed; publishing inline",
                task_id,
                exc_info=True,
            )
            return response

        return {
            "task_id": task_id,
            "offloaded": True,
            "bucket": bucket,
            "key": key,
            "size_bytes": len(payload_bytes),
        }

    async def publish_response(self, task_id: str, response: dict) -> bool:
        """
        Publish a response for a task.

        Stores the response under a key with a 5 minute TTL and publishes it
        on the response channel for waiters. If large-response offload is
        configured, the payload is written to storage and a small pointer is
        published instead.
        """
        client = await self._get_client()
        response = await self._offload_if_large(task_id, response)

        channel = f"evenage:response:{task_id}"
        encoded = json.dumps(response)  # serialize once for both writes

        # Store response, then notify waiters via pub/sub.
        await client.setex(channel, self._RESPONSE_TTL_SEC, encoded)
        await client.publish(channel, encoded)
        return True

    async def consume_tasks(self, agent_name: str, block_ms: int, count: int) -> list[dict]:
        """
        Consume tasks from agent's queue.

        Blocks up to block_ms milliseconds waiting for tasks and returns up
        to count tasks. Each entry read is deleted from the stream, so an
        entry is handed out at most once per successful delete.

        Entries that cannot be decoded are logged and deleted so that they
        cannot block the queue. Errors talking to Redis are logged and
        whatever was decoded so far is returned (best effort).
        """
        client = await self._get_client()
        stream_name = f"evenage:tasks:{agent_name}"
        tasks: list[dict] = []

        try:
            # Read from the start of the stream rather than "$": "$" only
            # yields entries added while this call is blocked, which drops
            # every task published between two polls. Consumed entries are
            # deleted below, so whatever remains in the stream is pending.
            result = await client.xread(
                {stream_name: "0-0"},
                count=count,
                block=block_ms,
            )

            for _stream, messages in result or []:
                for msg_id, msg_data in messages:
                    try:
                        tasks.append(json.loads(msg_data["data"]))
                    except (KeyError, TypeError, ValueError):
                        self._logger.warning(
                            "Dropping undecodable entry %s from %s", msg_id, stream_name
                        )

                    # ACK by deleting (simple approach)
                    await client.xdel(stream_name, msg_id)
        except Exception:
            self._logger.warning(
                "Consuming tasks from %s failed", stream_name, exc_info=True
            )

        return tasks

    async def wait_for_response(self, task_id: str, timeout_sec: int) -> dict | None:
        """
        Wait for a response to a task (for synchronous chat endpoints).

        Uses pub/sub to listen for the response, with a fallback that polls
        the stored response key. Returns ``None`` if nothing arrives within
        timeout_sec seconds.
        """
        client = await self._get_client()
        channel = f"evenage:response:{task_id}"

        # First check if response already exists
        existing = await client.get(channel)
        if existing:
            return json.loads(existing)

        pubsub = client.pubsub()
        try:
            await pubsub.subscribe(channel)

            # Monotonic clock: the deadline must not move with wall-clock changes.
            deadline = time.monotonic() + timeout_sec
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return None

                # Never wait past the deadline in a single poll.
                message = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=min(1.0, remaining)
                )
                if message and message["type"] == "message":
                    return json.loads(message["data"])

                # Fallback: check storage, covering a publish that happened
                # before the subscription became active.
                existing = await client.get(channel)
                if existing:
                    return json.loads(existing)
        finally:
            try:
                await pubsub.unsubscribe(channel)
            finally:
                await pubsub.close()

    async def get_registered_agents(self) -> dict[str, dict]:
        """Get all registered agents with their decoded metadata."""
        client = await self._get_client()
        agents_raw = await client.hgetall("evenage:agents:registry")
        return {name: json.loads(raw) for name, raw in agents_raw.items()}

    async def get_queue_depth(self, agent_name: str) -> int:
        """Get number of pending tasks for an agent (0 if it cannot be read)."""
        client = await self._get_client()
        try:
            return await client.xlen(f"evenage:tasks:{agent_name}")
        except Exception:
            return 0

    async def health_check(self) -> bool:
        """Check if Redis is reachable."""
        try:
            client = await self._get_client()
            await client.ping()
            return True
        except Exception:
            return False


class MemoryBus:
    """
    In-memory message bus for testing and quick local dev.

    Not suitable for production or distributed systems: all state lives in
    plain dicts on the instance and responses are kept until the instance is
    discarded.

    Optional large-response offload mirrors RedisBus and is driven by the
    ``_storage``, ``_storage_bucket`` and ``_storage_threshold_kb``
    attributes, which are read with ``getattr`` and may be set by the owner.
    """

    _logger = logging.getLogger(__name__)

    # How often an empty queue is re-checked while consume_tasks is blocking.
    _POLL_INTERVAL_SEC = 0.01

    def __init__(self):
        self._agents: dict[str, dict] = {}
        self._tasks: dict[str, list[dict]] = {}
        self._responses: dict[str, dict] = {}
        self._waiters: dict[str, list[asyncio.Future]] = {}

    async def register_agent(self, name: str, metadata: dict) -> bool:
        """Register an agent with metadata.

        ``last_heartbeat`` is stamped with the current time and ``status``
        defaults to ``"active"``. A copy is stored; the caller's dict is not
        modified.
        """
        record = dict(metadata)
        record["last_heartbeat"] = time.time()
        record.setdefault("status", "active")
        self._agents[name] = record

        # Make sure the agent has a queue even before any task is published.
        self._tasks.setdefault(name, [])
        return True

    async def publish_task(self, agent_name: str, task: dict) -> str:
        """Publish a task to an agent's in-memory queue and return its task_id."""
        task_id = str(uuid.uuid4())
        message = TaskMessage(
            task_id=task_id, agent_name=agent_name, payload=task
        ).to_dict()
        self._tasks.setdefault(agent_name, []).append(message)
        return task_id

    async def _offload_if_large(self, task_id: str, response: dict) -> dict:
        """Return ``response``, or a small pointer to it if it was offloaded.

        Offload happens only when storage and a threshold are configured and
        the JSON-encoded response exceeds ``threshold_kb * 1024`` bytes. Any
        failure is logged and the original response is returned (best effort).
        """
        storage = getattr(self, "_storage", None)
        threshold_kb = getattr(self, "_storage_threshold_kb", None)
        if not storage or not threshold_kb:
            return response

        bucket = getattr(self, "_storage_bucket", None) or "evenage"
        try:
            payload_bytes = json.dumps(response).encode("utf-8")
            if len(payload_bytes) <= threshold_kb * 1024:
                return response

            # Build a key with a UTC date prefix for grouping.
            date_prefix = time.strftime("%Y/%m/%d", time.gmtime())
            key = f"responses/{date_prefix}/{task_id}.json"
            await storage.ensure_bucket(bucket)
            await storage.put_object(
                bucket, key, payload_bytes, metadata={"content-type": "application/json"}
            )
        except Exception:
            self._logger.warning(
                "Offloading response for task %s failed; keeping it inline",
                task_id,
                exc_info=True,
            )
            return response

        return {
            "task_id": task_id,
            "offloaded": True,
            "bucket": bucket,
            "key": key,
            "size_bytes": len(payload_bytes),
        }

    async def publish_response(self, task_id: str, response: dict) -> bool:
        """Publish a response for a task, notifying any waiters.

        When offload is configured and the response is large, a pointer dict
        is stored and delivered instead of the full response.
        """
        response = await self._offload_if_large(task_id, response)
        self._responses[task_id] = response

        # Notify waiters; futures already cancelled by a timeout are skipped.
        for fut in self._waiters.pop(task_id, []):
            if not fut.done():
                fut.set_result(response)
        return True

    async def consume_tasks(self, agent_name: str, block_ms: int, count: int) -> list[dict]:
        """Consume up to count tasks from an agent's queue.

        If the queue is empty, waits up to block_ms milliseconds and returns
        as soon as a task is published; returns ``[]`` if none arrives.
        """
        deadline = time.monotonic() + max(block_ms, 0) / 1000.0
        while True:
            queue = self._tasks.get(agent_name)
            if queue:
                batch = queue[:count]
                self._tasks[agent_name] = queue[count:]
                return batch

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return []
            # Poll instead of sleeping for the whole window so a task
            # published during the wait is picked up promptly.
            await asyncio.sleep(min(remaining, self._POLL_INTERVAL_SEC))

    def _discard_waiter(self, task_id: str, fut: asyncio.Future) -> None:
        """Remove ``fut`` from the waiter list of ``task_id`` if still present."""
        pending = self._waiters.get(task_id)
        if pending is None:
            return
        if fut in pending:
            pending.remove(fut)
        if not pending:
            del self._waiters[task_id]

    async def wait_for_response(self, task_id: str, timeout_sec: int) -> dict | None:
        """Wait for a response published via publish_response.

        Returns the response immediately if it was already published,
        otherwise waits up to timeout_sec seconds and returns ``None`` on
        timeout.
        """
        if task_id in self._responses:
            return self._responses[task_id]

        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._waiters.setdefault(task_id, []).append(fut)
        try:
            return await asyncio.wait_for(fut, timeout=timeout_sec)
        except asyncio.TimeoutError:
            return None
        finally:
            # Without this, a timed-out waiter would stay registered forever.
            self._discard_waiter(task_id, fut)

    async def get_registered_agents(self) -> dict[str, dict]:
        """Return a shallow copy of the agent registry."""
        return dict(self._agents)

    async def get_queue_depth(self, agent_name: str) -> int:
        """Return the number of tasks currently queued for an agent."""
        return len(self._tasks.get(agent_name, []))

    async def health_check(self) -> bool:
        """Always healthy: there is no external dependency to reach."""
        return True


class KafkaBus:
    """
    Kafka-based message bus (experimental / enterprise option).

    Uses topics per agent: evenage.tasks.<agent_name>
    Responses published to: evenage.responses
    Agent registry is memory-only in this minimal implementation.

    Responses are also kept in memory for a short time so that a waiter in
    the same process receives a response published before it started waiting.

    Optional large-response offload mirrors the other buses and is driven by
    the ``_storage``, ``_storage_bucket`` and ``_storage_threshold_kb``
    attributes, which are read with ``getattr`` and may be set by the owner.
    """

    _logger = logging.getLogger(__name__)

    # Seconds a published response is remembered for late local waiters
    # (same retention RedisBus uses for its stored responses).
    _RESPONSE_TTL_SEC = 300

    def __init__(self, bootstrap_servers: str):
        if AIOKafkaProducer is None or AIOKafkaConsumer is None:
            raise RuntimeError("aiokafka is not installed. Install with: pip install aiokafka")
        self.bootstrap_servers = bootstrap_servers
        self._producer = None
        # One long-lived consumer per agent, created on first consume_tasks.
        self._consumers: dict[str, Any] = {}
        self._agents: dict[str, dict] = {}
        # task_id -> (monotonic expiry, response); insertion-ordered by expiry.
        self._responses: dict[str, tuple[float, dict]] = {}
        self._waiters: dict[str, list[asyncio.Future]] = {}

    async def _get_producer(self) -> Any:
        """Return the started producer, creating and starting it on first use."""
        if self._producer is None:
            producer = AIOKafkaProducer(bootstrap_servers=self.bootstrap_servers)
            await producer.start()
            # Cache only after a successful start, so a failed start is
            # retried next time instead of handing out an unstarted producer.
            self._producer = producer
        return self._producer

    async def _get_consumer(self, agent_name: str) -> Any:
        """Return the started consumer for an agent's task topic.

        The consumer is created once and reused so its position carries over
        between polls. It joins a per-agent consumer group so offsets can be
        committed. Default client timeouts are used.
        NOTE(review): a request timeout as low as 1s may be too short for
        consumer-group joins -- confirm against aiokafka.
        """
        consumer = self._consumers.get(agent_name)
        if consumer is None:
            consumer = AIOKafkaConsumer(  # type: ignore[name-defined]
                f"evenage.tasks.{agent_name}",
                bootstrap_servers=self.bootstrap_servers,
                group_id=f"evenage.{agent_name}",
                enable_auto_commit=False,
                # Only applies when the group has no committed offset yet.
                auto_offset_reset="latest",
            )
            await consumer.start()
            self._consumers[agent_name] = consumer
        return consumer

    async def register_agent(self, name: str, metadata: dict) -> bool:
        """Register an agent in the in-memory registry.

        ``last_heartbeat`` is stamped with the current time and ``status``
        defaults to ``"active"``. A copy is stored; the caller's dict is not
        modified.
        """
        record = dict(metadata)
        record["last_heartbeat"] = time.time()
        record.setdefault("status", "active")
        self._agents[name] = record
        return True

    async def publish_task(self, agent_name: str, task: dict) -> str:
        """Send a task to the agent's topic and return the generated task_id."""
        producer = await self._get_producer()
        task_id = str(uuid.uuid4())
        message = TaskMessage(task_id=task_id, agent_name=agent_name, payload=task).to_dict()
        data = json.dumps(message).encode("utf-8")
        await producer.send_and_wait(f"evenage.tasks.{agent_name}", data)
        return task_id

    async def _offload_if_large(self, task_id: str, response: dict) -> dict:
        """Return ``response``, or a small pointer to it if it was offloaded.

        Offload happens only when storage and a threshold are configured and
        the JSON-encoded response exceeds ``threshold_kb * 1024`` bytes. Any
        failure is logged and the original response is returned (best effort).
        """
        storage = getattr(self, "_storage", None)
        threshold_kb = getattr(self, "_storage_threshold_kb", None)
        if not storage or not threshold_kb:
            return response

        bucket = getattr(self, "_storage_bucket", None) or "evenage"
        try:
            payload_bytes = json.dumps(response).encode("utf-8")
            if len(payload_bytes) <= threshold_kb * 1024:
                return response

            # Build a key with a UTC date prefix for grouping.
            date_prefix = time.strftime("%Y/%m/%d", time.gmtime())
            key = f"responses/{date_prefix}/{task_id}.json"
            await storage.ensure_bucket(bucket)
            await storage.put_object(
                bucket, key, payload_bytes, metadata={"content-type": "application/json"}
            )
        except Exception:
            self._logger.warning(
                "Offloading response for task %s failed; publishing inline",
                task_id,
                exc_info=True,
            )
            return response

        return {
            "task_id": task_id,
            "offloaded": True,
            "bucket": bucket,
            "key": key,
            "size_bytes": len(payload_bytes),
        }

    def _remember_response(self, task_id: str, response: dict) -> None:
        """Keep ``response`` for late waiters and drop expired entries."""
        now = time.monotonic()
        # Re-insert so the dict stays ordered by expiry time.
        self._responses.pop(task_id, None)
        self._responses[task_id] = (now + self._RESPONSE_TTL_SEC, response)

        expired = []
        for known_id, (expires_at, _) in self._responses.items():
            if expires_at > now:
                break
            expired.append(known_id)
        for known_id in expired:
            del self._responses[known_id]

    def _recall_response(self, task_id: str) -> dict | None:
        """Return the remembered response for ``task_id`` if not expired."""
        entry = self._responses.get(task_id)
        if entry is None:
            return None
        expires_at, response = entry
        if expires_at <= time.monotonic():
            del self._responses[task_id]
            return None
        return response

    async def publish_response(self, task_id: str, response: dict) -> bool:
        """Publish a response for a task.

        The response (or an offload pointer to it) is remembered for late
        local waiters, handed to current in-process waiters, and sent to the
        ``evenage.responses`` topic. A failure to send to Kafka is logged and
        does not fail the call.
        """
        response = await self._offload_if_large(task_id, response)
        self._remember_response(task_id, response)

        # Notify in-memory waiters; futures cancelled by a timeout are skipped.
        for fut in self._waiters.pop(task_id, []):
            if not fut.done():
                fut.set_result(response)

        # Best-effort publish to responses topic for external consumers.
        try:
            producer = await self._get_producer()
            await producer.send_and_wait("evenage.responses", json.dumps(response).encode("utf-8"))
        except Exception:
            self._logger.warning(
                "Publishing response for task %s to Kafka failed", task_id, exc_info=True
            )
        return True

    async def consume_tasks(self, agent_name: str, block_ms: int, count: int) -> list[dict]:
        """Consume up to count tasks from the agent's topic.

        Waits up to block_ms milliseconds for records. Records that cannot be
        decoded as JSON are logged and skipped. Offsets are committed once
        records have been read.

        Intended to be called from a single polling loop per agent: the
        underlying consumer is shared between calls for the same agent.
        """
        consumer = await self._get_consumer(agent_name)
        tasks: list[dict] = []

        result = await consumer.getmany(timeout_ms=block_ms, max_records=count)
        for _partition, messages in result.items():
            for msg in messages:
                try:
                    tasks.append(json.loads(msg.value.decode("utf-8")))
                except (AttributeError, ValueError):
                    self._logger.warning(
                        "Skipping undecodable task record for agent %s", agent_name
                    )

        # Commit offsets for consumed messages (including skipped ones, so
        # they are not re-read after a restart).
        if result:
            await consumer.commit()
        return tasks

    def _discard_waiter(self, task_id: str, fut: asyncio.Future) -> None:
        """Remove ``fut`` from the waiter list of ``task_id`` if still present."""
        pending = self._waiters.get(task_id)
        if pending is None:
            return
        if fut in pending:
            pending.remove(fut)
        if not pending:
            del self._waiters[task_id]

    async def wait_for_response(self, task_id: str, timeout_sec: int) -> dict | None:
        """Wait for a response published by this process.

        Minimal in-memory waiter; external consumers can listen on the
        responses topic. Returns a recently published response immediately,
        otherwise waits up to timeout_sec seconds and returns ``None`` on
        timeout.
        """
        cached = self._recall_response(task_id)
        if cached is not None:
            return cached

        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._waiters.setdefault(task_id, []).append(fut)
        try:
            return await asyncio.wait_for(fut, timeout=timeout_sec)
        except asyncio.TimeoutError:
            return None
        finally:
            # Without this, a timed-out waiter would stay registered forever.
            self._discard_waiter(task_id, fut)

    async def get_registered_agents(self) -> dict[str, dict]:
        """Return a shallow copy of the in-memory agent registry."""
        return dict(self._agents)

    async def get_queue_depth(self, agent_name: str) -> int:
        """Return 0; queue depth is not tracked in this minimal implementation."""
        return 0

    async def health_check(self) -> bool:
        """Report whether a producer could be created and started."""
        try:
            await self._get_producer()
            return True
        except Exception:
            return False

    async def close(self) -> None:
        """Stop every consumer and the producer created by this bus."""
        consumers, self._consumers = self._consumers, {}
        for consumer in consumers.values():
            await consumer.stop()
        if self._producer is not None:
            producer, self._producer = self._producer, None
            await producer.stop()
