"""
Queue backends: RedisBus, MemoryBus, KafkaBus.

Moved from evenage.core.message_bus to unified backends module.
"""
from __future__ import annotations

# Content copied from core.message_bus

import asyncio
from dataclasses import dataclass, field
from datetime import datetime
import json
import logging
import time
from typing import Any
import uuid

try:
    import redis.asyncio as aioredis
except ImportError:
    aioredis = None

try:
    from aiokafka import AIOKafkaConsumer, AIOKafkaProducer  # type: ignore
except ImportError:  # optional dependency
    AIOKafkaProducer = None  # type: ignore
    AIOKafkaConsumer = None  # type: ignore


@dataclass
class TaskMessage:
    """A unit of work addressed to one agent.

    ``created_at`` defaults to ``time.time()`` at construction unless supplied.
    """

    task_id: str
    agent_name: str
    payload: dict[str, Any]
    created_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict:
        """Return the message as a plain dict; ``payload`` is shared, not copied."""
        ordered_keys = ("task_id", "agent_name", "payload", "created_at")
        return {key: getattr(self, key) for key in ordered_keys}

    @classmethod
    def from_dict(cls, data: dict) -> "TaskMessage":
        """Rebuild a message from ``to_dict`` output.

        A missing ``created_at`` falls back to the current time.

        Raises:
            KeyError: if ``task_id``, ``agent_name`` or ``payload`` is absent.
        """
        return cls(
            data["task_id"],
            data["agent_name"],
            data["payload"],
            data.get("created_at", time.time()),
        )


@dataclass
class ResponseMessage:
    """Outcome of a task, keyed by the task's id.

    ``result`` and ``error`` default to ``None``; ``created_at`` defaults to
    ``time.time()`` at construction.
    """

    task_id: str
    status: str  # "success" | "error" | "pending"
    result: dict[str, Any] | None = None
    error: str | None = None
    created_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict:
        """Return the response as a plain dict; ``result`` is shared, not copied."""
        fields_in_order = ("task_id", "status", "result", "error", "created_at")
        return {name: getattr(self, name) for name in fields_in_order}

    @classmethod
    def from_dict(cls, data: dict) -> "ResponseMessage":
        """Rebuild a response from ``to_dict`` output.

        ``result`` and ``error`` default to ``None`` and ``created_at`` to the
        current time when absent.

        Raises:
            KeyError: if ``task_id`` or ``status`` is absent.
        """
        return cls(
            data["task_id"],
            data["status"],
            data.get("result"),
            data.get("error"),
            data.get("created_at", time.time()),
        )


class RedisBus:
    """Message bus backed by Redis streams, a hash, and pub/sub.

    Keys used:
      * ``evenage:tasks:<agent_name>`` -- stream of pending tasks for one agent.
      * ``evenage:agents:registry`` -- hash of agent name -> JSON metadata.
      * ``evenage:response:<task_id>`` -- cached response (300 s TTL), and also the
        pub/sub channel the response is announced on.

    The attributes ``_storage``, ``_storage_bucket`` and ``_storage_threshold_kb``
    are not set by this class. If something assigns them, responses larger than the
    threshold are written to object storage and replaced by a reference dict.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, redis_url: str):
        self.redis_url = redis_url
        self._client = None  # created lazily by _get_client()

    async def _get_client(self) -> Any:
        """Return the shared Redis client, creating it on first use.

        Raises:
            RuntimeError: if the ``redis`` package is not installed.
        """
        if self._client is None:
            if aioredis is None:
                raise RuntimeError("redis package not installed. Install with: pip install redis")
            self._client = await aioredis.from_url(self.redis_url, decode_responses=True)
        return self._client

    async def _offload_if_large(self, task_id: str, response: dict) -> dict:
        """Move an oversized *response* to object storage if storage is configured.

        Returns either *response* unchanged or a reference dict with keys
        ``task_id``, ``offloaded``, ``bucket``, ``key`` and ``size_bytes``.
        Offloading is best-effort: serialization or storage failures are logged and
        the original response is returned.
        """
        storage = getattr(self, "_storage", None)
        threshold_kb = getattr(self, "_storage_threshold_kb", None)
        if not storage or not threshold_kb:
            return response
        try:
            payload_bytes = json.dumps(response).encode("utf-8")
        except (TypeError, ValueError):
            return response  # not JSON-serializable; leave it to the caller's encoding
        if len(payload_bytes) <= threshold_kb * 1024:
            return response
        bucket = getattr(self, "_storage_bucket", None) or "evenage"
        date_prefix = time.strftime("%Y/%m/%d", time.gmtime())  # UTC date
        key = f"responses/{date_prefix}/{task_id}.json"
        try:
            await storage.ensure_bucket(bucket)
            await storage.put_object(bucket, key, payload_bytes, metadata={"content-type": "application/json"})
        except Exception:
            self._log.warning("Offloading response %s to %s/%s failed; sending inline", task_id, bucket, key, exc_info=True)
            return response
        return {"task_id": task_id, "offloaded": True, "bucket": bucket, "key": key, "size_bytes": len(payload_bytes)}

    async def register_agent(self, name: str, metadata: dict) -> bool:
        """Write *name*'s metadata to the registry hash with a fresh heartbeat.

        ``status`` defaults to ``"active"``. The caller's ``metadata`` dict is not
        modified.
        """
        client = await self._get_client()
        record = {**metadata, "last_heartbeat": time.time(), "status": metadata.get("status", "active")}
        await client.hset("evenage:agents:registry", name, json.dumps(record))
        return True

    async def publish_task(self, agent_name: str, task: dict) -> str:
        """Append *task* to *agent_name*'s stream and return the generated task id."""
        client = await self._get_client()
        task_id = str(uuid.uuid4())
        message = TaskMessage(task_id=task_id, agent_name=agent_name, payload=task)
        await client.xadd(f"evenage:tasks:{agent_name}", {"data": json.dumps(message.to_dict())})
        return task_id

    async def publish_response(self, task_id: str, response: dict) -> bool:
        """Cache *response* for 300 seconds and announce it on the task's channel."""
        client = await self._get_client()
        response = await self._offload_if_large(task_id, response)
        encoded = json.dumps(response)
        await client.setex(f"evenage:response:{task_id}", 300, encoded)
        await client.publish(f"evenage:response:{task_id}", encoded)
        return True

    async def consume_tasks(self, agent_name: str, block_ms: int, count: int) -> list[dict]:
        """Read up to *count* pending tasks for *agent_name* and remove them from the stream.

        Reads from stream id ``"0"`` rather than ``"$"``. Entries are deleted once
        read, so everything left in the stream is still pending. ``"$"`` would skip
        tasks published while the worker was busy between polls.

        ``block_ms <= 0`` returns immediately. Redis treats ``BLOCK 0`` as "wait
        forever", and immediate return matches MemoryBus. Malformed entries are
        logged, deleted and skipped so they cannot wedge the stream.

        NOTE(review): XREAD delivers the same entries to every reader, so two
        concurrent consumers of one agent can both receive a task. Consumer groups
        (XREADGROUP) would be needed for exclusive delivery.

        Returns ``[]`` if the read fails.
        """
        client = await self._get_client()
        stream_name = f"evenage:tasks:{agent_name}"
        block = block_ms if block_ms and block_ms > 0 else None
        try:
            result = await client.xread({stream_name: "0"}, count=count, block=block)
        except Exception:
            self._log.warning("Reading tasks from %s failed", stream_name, exc_info=True)
            return []
        tasks: list[dict] = []
        read_ids = []
        for _stream, messages in result or []:
            for msg_id, msg_data in messages:
                read_ids.append(msg_id)
                try:
                    tasks.append(json.loads(msg_data["data"]))
                except (KeyError, TypeError, ValueError):
                    self._log.warning("Dropping malformed task %s from %s", msg_id, stream_name)
        if read_ids:
            try:
                await client.xdel(stream_name, *read_ids)
            except Exception:
                # The entries stay in the stream and will be read again
                # (at-least-once) rather than being lost.
                self._log.warning("Deleting consumed tasks from %s failed", stream_name, exc_info=True)
        return tasks

    async def wait_for_response(self, task_id: str, timeout_sec: int) -> dict | None:
        """Return the response for *task_id*, or ``None`` after *timeout_sec* seconds."""
        client = await self._get_client()
        channel = f"evenage:response:{task_id}"
        existing = await client.get(channel)
        if existing:
            return json.loads(existing)
        pubsub = client.pubsub()
        await pubsub.subscribe(channel)
        try:
            deadline = time.monotonic() + timeout_sec
            while (remaining := deadline - time.monotonic()) > 0:
                message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=min(1.0, remaining))
                if message and message["type"] == "message":
                    return json.loads(message["data"])
                # The response may have been cached before the subscription was active.
                existing = await client.get(channel)
                if existing:
                    return json.loads(existing)
            return None
        finally:
            await pubsub.unsubscribe(channel)
            await pubsub.close()

    async def get_registered_agents(self) -> dict[str, dict]:
        """Return every registered agent's decoded metadata keyed by name."""
        client = await self._get_client()
        agents_raw = await client.hgetall("evenage:agents:registry")
        return {name: json.loads(metadata_json) for name, metadata_json in agents_raw.items()}

    async def list_agents(self) -> list[dict]:
        """Return a summary (name, role, goal, status) of each registered agent."""
        registry = await self.get_registered_agents()
        return [
            {"name": name, "role": meta.get("role"), "goal": meta.get("goal"), "status": meta.get("status", "unknown")}
            for name, meta in registry.items()
        ]

    async def get_queue_depth(self, agent_name: str) -> int:
        """Return the number of entries in *agent_name*'s stream, or 0 on error."""
        client = await self._get_client()
        try:
            return await client.xlen(f"evenage:tasks:{agent_name}")
        except Exception:
            return 0

    async def health_check(self) -> bool:
        """Return ``True`` if Redis answers PING."""
        try:
            client = await self._get_client()
            await client.ping()
            return True
        except Exception:
            return False


class MemoryBus:
    """In-process message bus exposing the same coroutine methods as RedisBus.

    All state lives in dicts on the instance, so it is only visible within one
    process. Stored responses are never evicted.

    The attributes ``_storage``, ``_storage_bucket`` and ``_storage_threshold_kb``
    are not set by this class. If something assigns them, responses larger than the
    threshold are written to object storage and replaced by a reference dict.
    """

    _log = logging.getLogger(__name__)

    def __init__(self):
        self._agents: dict[str, dict] = {}
        self._tasks: dict[str, list[dict]] = {}  # agent_name -> FIFO of task dicts
        self._responses: dict[str, dict] = {}  # task_id -> response (kept forever)
        self._waiters: dict[str, list[asyncio.Future]] = {}  # task_id -> pending futures

    async def _offload_if_large(self, task_id: str, response: dict) -> dict:
        """Move an oversized *response* to object storage if storage is configured.

        Returns either *response* unchanged or a reference dict with keys
        ``task_id``, ``offloaded``, ``bucket``, ``key`` and ``size_bytes``.
        Offloading is best-effort: serialization or storage failures are logged and
        the original response is returned.
        """
        storage = getattr(self, "_storage", None)
        threshold_kb = getattr(self, "_storage_threshold_kb", None)
        if not storage or not threshold_kb:
            return response
        try:
            payload_bytes = json.dumps(response).encode("utf-8")
        except (TypeError, ValueError):
            return response  # not JSON-serializable; keep it in memory as-is
        if len(payload_bytes) <= threshold_kb * 1024:
            return response
        bucket = getattr(self, "_storage_bucket", None) or "evenage"
        date_prefix = time.strftime("%Y/%m/%d", time.gmtime())  # UTC date
        key = f"responses/{date_prefix}/{task_id}.json"
        try:
            await storage.ensure_bucket(bucket)
            await storage.put_object(bucket, key, payload_bytes, metadata={"content-type": "application/json"})
        except Exception:
            self._log.warning("Offloading response %s to %s/%s failed; keeping inline", task_id, bucket, key, exc_info=True)
            return response
        return {"task_id": task_id, "offloaded": True, "bucket": bucket, "key": key, "size_bytes": len(payload_bytes)}

    def _discard_waiter(self, task_id: str, fut: asyncio.Future) -> None:
        """Forget *fut* and drop the task's waiter list once it is empty.

        ``publish_response`` pops the whole list when it resolves waiters. The list
        may therefore already be gone, or be a newer list that lacks *fut*.
        """
        waiters = self._waiters.get(task_id)
        if waiters is None:
            return
        if fut in waiters:
            waiters.remove(fut)
        if not waiters:
            del self._waiters[task_id]

    async def register_agent(self, name: str, metadata: dict) -> bool:
        """Record *name* with a fresh heartbeat and ensure it has a task queue.

        ``status`` defaults to ``"active"``. A copy of ``metadata`` is stored, so the
        caller's dict is neither modified nor aliased by the registry.
        """
        self._agents[name] = {
            **metadata,
            "last_heartbeat": time.time(),
            "status": metadata.get("status", "active"),
        }
        self._tasks.setdefault(name, [])
        return True

    async def publish_task(self, agent_name: str, task: dict) -> str:
        """Append *task* to *agent_name*'s queue and return the generated task id."""
        task_id = str(uuid.uuid4())
        message = TaskMessage(task_id=task_id, agent_name=agent_name, payload=task).to_dict()
        self._tasks.setdefault(agent_name, []).append(message)
        return task_id

    async def publish_response(self, task_id: str, response: dict) -> bool:
        """Store *response* and resolve every coroutine waiting on *task_id*."""
        response = await self._offload_if_large(task_id, response)
        self._responses[task_id] = response
        for fut in self._waiters.pop(task_id, []):
            if not fut.done():  # timed-out waiters have already been cancelled
                fut.set_result(response)
        return True

    async def consume_tasks(self, agent_name: str, block_ms: int, count: int) -> list[dict]:
        """Remove and return up to *count* tasks from *agent_name*'s queue.

        If the queue is empty and ``block_ms > 0``, sleeps for ``block_ms``
        milliseconds and then checks again. Tasks published during the wait are
        returned rather than ignored.
        """
        if not self._tasks.get(agent_name) and block_ms > 0:
            await asyncio.sleep(block_ms / 1000.0)
        queue = self._tasks.get(agent_name)
        if not queue:
            return []
        batch, self._tasks[agent_name] = queue[:count], queue[count:]
        return batch

    async def wait_for_response(self, task_id: str, timeout_sec: int) -> dict | None:
        """Return the response for *task_id*, or ``None`` after *timeout_sec* seconds.

        The waiter future is always removed from ``_waiters`` afterwards, so waits
        that time out do not accumulate.
        """
        if task_id in self._responses:
            return self._responses[task_id]
        fut = asyncio.get_running_loop().create_future()
        self._waiters.setdefault(task_id, []).append(fut)
        try:
            return await asyncio.wait_for(fut, timeout=timeout_sec)
        except asyncio.TimeoutError:
            return None
        finally:
            self._discard_waiter(task_id, fut)

    async def get_registered_agents(self) -> dict[str, dict]:
        """Return a shallow copy of the registry (name -> metadata)."""
        return dict(self._agents)

    async def list_agents(self) -> list[dict]:
        """Return a summary (name, role, goal, status) of each registered agent."""
        registry = await self.get_registered_agents()
        return [
            {"name": name, "role": meta.get("role"), "goal": meta.get("goal"), "status": meta.get("status", "unknown")}
            for name, meta in registry.items()
        ]

    async def get_queue_depth(self, agent_name: str) -> int:
        """Return the number of tasks waiting for *agent_name*."""
        return len(self._tasks.get(agent_name, []))

    async def health_check(self) -> bool:
        """Always healthy: there is no external dependency."""
        return True


class KafkaBus:
    """Message bus backed by Kafka topics.

    Tasks are produced to ``evenage.tasks.<agent_name>``. They are read through one
    long-lived consumer per agent in consumer group ``evenage.agents.<agent_name>``,
    so the read position survives between polls and offsets can be committed.

    Responses are handed to waiters in this process and cached in memory, then
    produced to ``evenage.responses`` on a best-effort basis. This class does not
    consume that topic, so ``wait_for_response`` only sees responses published
    through the same instance. The agent registry is also in-memory only.

    Call ``close()`` to stop the producer and any consumers.
    """

    _log = logging.getLogger(__name__)

    def __init__(self, bootstrap_servers: str):
        if AIOKafkaProducer is None or AIOKafkaConsumer is None:
            raise RuntimeError("aiokafka is not installed. Install with: pip install aiokafka")
        self.bootstrap_servers = bootstrap_servers
        self._producer = None
        self._consumers: dict[str, Any] = {}  # agent_name -> started consumer
        self._agents: dict[str, dict] = {}
        self._responses: dict[str, dict] = {}  # task_id -> response (never evicted)
        self._waiters: dict[str, list[asyncio.Future]] = {}

    async def _get_producer(self) -> Any:
        """Return the shared producer, starting it on first use.

        The producer is cached only after ``start()`` succeeds. A failed start is
        retried on the next call instead of leaving an unstarted producer behind.
        """
        if self._producer is None:
            producer = AIOKafkaProducer(bootstrap_servers=self.bootstrap_servers)
            await producer.start()
            self._producer = producer
        return self._producer

    async def _get_consumer(self, agent_name: str) -> Any:
        """Return the cached, started consumer for *agent_name*'s task topic.

        A ``group_id`` is required for ``commit()``; without one aiokafka raises
        ``IllegalOperation``. ``auto_offset_reset="latest"`` only applies when the
        group has no committed offset yet.
        """
        consumer = self._consumers.get(agent_name)
        if consumer is None:
            consumer = AIOKafkaConsumer(
                f"evenage.tasks.{agent_name}",
                bootstrap_servers=self.bootstrap_servers,
                group_id=f"evenage.agents.{agent_name}",
                enable_auto_commit=False,
                auto_offset_reset="latest",
            )
            await consumer.start()
            self._consumers[agent_name] = consumer
        return consumer

    async def _drop_consumer(self, agent_name: str) -> None:
        """Stop and forget the consumer for *agent_name*, ignoring stop errors."""
        consumer = self._consumers.pop(agent_name, None)
        if consumer is None:
            return
        try:
            await consumer.stop()
        except Exception:
            self._log.debug("Stopping consumer for %s failed", agent_name, exc_info=True)

    async def _offload_if_large(self, task_id: str, response: dict) -> dict:
        """Move an oversized *response* to object storage if storage is configured.

        Returns either *response* unchanged or a reference dict with keys
        ``task_id``, ``offloaded``, ``bucket``, ``key`` and ``size_bytes``.
        Offloading is best-effort: serialization or storage failures are logged and
        the original response is returned.
        """
        storage = getattr(self, "_storage", None)
        threshold_kb = getattr(self, "_storage_threshold_kb", None)
        if not storage or not threshold_kb:
            return response
        try:
            payload_bytes = json.dumps(response).encode("utf-8")
        except (TypeError, ValueError):
            return response
        if len(payload_bytes) <= threshold_kb * 1024:
            return response
        bucket = getattr(self, "_storage_bucket", None) or "evenage"
        date_prefix = time.strftime("%Y/%m/%d", time.gmtime())  # UTC date
        key = f"responses/{date_prefix}/{task_id}.json"
        try:
            await storage.ensure_bucket(bucket)
            await storage.put_object(bucket, key, payload_bytes, metadata={"content-type": "application/json"})
        except Exception:
            self._log.warning("Offloading response %s to %s/%s failed; sending inline", task_id, bucket, key, exc_info=True)
            return response
        return {"task_id": task_id, "offloaded": True, "bucket": bucket, "key": key, "size_bytes": len(payload_bytes)}

    def _discard_waiter(self, task_id: str, fut: asyncio.Future) -> None:
        """Forget *fut* and drop the task's waiter list once it is empty."""
        waiters = self._waiters.get(task_id)
        if waiters is None:
            return
        if fut in waiters:
            waiters.remove(fut)
        if not waiters:
            del self._waiters[task_id]

    async def register_agent(self, name: str, metadata: dict) -> bool:
        """Record a copy of *name*'s metadata with a fresh heartbeat (in memory only).

        ``status`` defaults to ``"active"``. The caller's dict is not modified.
        """
        self._agents[name] = {
            **metadata,
            "last_heartbeat": time.time(),
            "status": metadata.get("status", "active"),
        }
        return True

    async def publish_task(self, agent_name: str, task: dict) -> str:
        """Produce *task* to *agent_name*'s topic and return the generated task id."""
        producer = await self._get_producer()
        task_id = str(uuid.uuid4())
        message = TaskMessage(task_id=task_id, agent_name=agent_name, payload=task).to_dict()
        await producer.send_and_wait(f"evenage.tasks.{agent_name}", json.dumps(message).encode("utf-8"))
        return task_id

    async def publish_response(self, task_id: str, response: dict) -> bool:
        """Cache *response*, wake local waiters, then produce it to ``evenage.responses``.

        It is cached so that a ``wait_for_response`` call arriving after this one
        still finds it. Producing to Kafka is best-effort; failures are logged.
        """
        response = await self._offload_if_large(task_id, response)
        self._responses[task_id] = response
        for fut in self._waiters.pop(task_id, []):
            if not fut.done():  # timed-out waiters have already been cancelled
                fut.set_result(response)
        try:
            producer = await self._get_producer()
            await producer.send_and_wait("evenage.responses", json.dumps(response).encode("utf-8"))
        except Exception:
            self._log.warning("Producing response %s to Kafka failed", task_id, exc_info=True)
        return True

    async def consume_tasks(self, agent_name: str, block_ms: int, count: int) -> list[dict]:
        """Return up to *count* tasks for *agent_name*, waiting at most *block_ms* ms.

        Offsets are committed after every non-empty fetch, including one that held
        only malformed messages; those are logged and skipped. A failed commit is
        logged and the tasks are still returned, so they may be redelivered rather
        than lost. If the fetch itself fails, the consumer is discarded so the next
        call starts a fresh one, and the error propagates.
        """
        consumer = await self._get_consumer(agent_name)
        try:
            batch = await consumer.getmany(timeout_ms=block_ms, max_records=count)
        except Exception:
            await self._drop_consumer(agent_name)
            raise
        tasks: list[dict] = []
        received = 0
        for messages in batch.values():
            for msg in messages:
                received += 1
                try:
                    tasks.append(json.loads(msg.value.decode("utf-8")))
                except (AttributeError, ValueError):
                    self._log.warning("Skipping malformed message on evenage.tasks.%s", agent_name)
        if received:
            try:
                await consumer.commit()
            except Exception:
                self._log.warning("Committing offsets for %s failed", agent_name, exc_info=True)
        return tasks

    async def wait_for_response(self, task_id: str, timeout_sec: int) -> dict | None:
        """Return the response for *task_id*, or ``None`` after *timeout_sec* seconds.

        Only responses published through this instance are visible. The waiter
        future is always removed afterwards, so timed-out waits do not accumulate.
        """
        if task_id in self._responses:
            return self._responses[task_id]
        fut = asyncio.get_running_loop().create_future()
        self._waiters.setdefault(task_id, []).append(fut)
        try:
            return await asyncio.wait_for(fut, timeout=timeout_sec)
        except asyncio.TimeoutError:
            return None
        finally:
            self._discard_waiter(task_id, fut)

    async def get_registered_agents(self) -> dict[str, dict]:
        """Return a shallow copy of the in-memory registry (name -> metadata)."""
        return dict(self._agents)

    async def list_agents(self) -> list[dict]:
        """Return a summary (name, role, goal, status) of each registered agent."""
        registry = await self.get_registered_agents()
        return [
            {"name": name, "role": meta.get("role"), "goal": meta.get("goal"), "status": meta.get("status", "unknown")}
            for name, meta in registry.items()
        ]

    async def get_queue_depth(self, agent_name: str) -> int:
        """Not implemented for Kafka: always returns 0."""
        return 0

    async def health_check(self) -> bool:
        """Return ``True`` if the producer is (or can be) started."""
        try:
            await self._get_producer()
            return True
        except Exception:
            return False

    async def close(self) -> None:
        """Stop every cached consumer and the producer. Safe to call more than once."""
        for agent_name in list(self._consumers):
            await self._drop_consumer(agent_name)
        producer, self._producer = self._producer, None
        if producer is not None:
            await producer.stop()
