"""
Database layer with SQLAlchemy models and high-level API.

Moved from evenage.core.database to unified backends module.
"""
from __future__ import annotations

from collections.abc import Generator
from contextlib import contextmanager
from datetime import datetime
from typing import Any
import uuid

from sqlalchemy import JSON, Column, DateTime, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker


# Declarative base shared by every ORM model in this module; its ``metadata``
# is what DatabaseService passes to ``create_all`` to create the tables.
# NOTE(review): ``sqlalchemy.ext.declarative.declarative_base`` is the legacy
# import location (``sqlalchemy.orm.declarative_base`` since SQLAlchemy 1.4);
# confirm the pinned SQLAlchemy version before migrating the import.
Base = declarative_base()


class Job(Base):
    """A pipeline run: its inputs, lifecycle status, and final outputs.

    Rows are inserted by ``DatabaseService.create_job`` and finalized by
    ``DatabaseService.save_result``.
    """

    __tablename__ = "jobs"
    # Caller-supplied id, or a UUID4 string (36 chars) generated by create_job.
    job_id = Column(String(36), primary_key=True)
    pipeline_name = Column(String(255), nullable=False)
    # Free-form status string; "pending" on insert. save_result writes
    # "completed" by default, and get_metrics also counts "failed".
    status = Column(String(50), default="pending")
    inputs = Column(JSON, nullable=False)
    # Left unset by create_job; filled in by save_result.
    outputs = Column(JSON, nullable=True)
    # Naive UTC timestamp (datetime.utcnow) applied at INSERT time.
    created_at = Column(DateTime, default=datetime.utcnow)
    # Stamped by save_result; NULL until then.
    completed_at = Column(DateTime, nullable=True)


class PromptExecution(Base):
    """A single prompt execution; mirrors :class:`Job` with a prompt name.

    Rows are inserted by ``DatabaseService.create_prompt`` and finalized by
    ``DatabaseService.save_prompt_result``.
    """

    __tablename__ = "prompt_executions"
    # Caller-supplied id, or a UUID4 string (36 chars) generated by create_prompt.
    prompt_id = Column(String(36), primary_key=True)
    prompt_name = Column(String(255), nullable=False)
    # Free-form status string; "pending" on insert, "completed" by default
    # once save_prompt_result runs.
    status = Column(String(50), default="pending")
    inputs = Column(JSON, nullable=False)
    # Left unset by create_prompt; filled in by save_prompt_result.
    outputs = Column(JSON, nullable=True)
    # Naive UTC timestamp (datetime.utcnow) applied at INSERT time.
    created_at = Column(DateTime, default=datetime.utcnow)
    # Stamped by save_prompt_result; NULL until then.
    completed_at = Column(DateTime, nullable=True)


class Trace(Base):
    """An event emitted by an agent while working on a job.

    Rows are written by ``DatabaseService.append_trace``; the service in this
    module never updates or deletes them.
    """

    __tablename__ = "traces"
    # Auto-incrementing integer key (note: not a UUID like Job.job_id).
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Plain string column, no foreign key to jobs.job_id is declared.
    job_id = Column(String(36), nullable=False, index=True)
    agent_name = Column(String(255), nullable=False, index=True)
    # Event kind; get_metrics counts "tool_call" and "cache_hit" events.
    event_type = Column(String(100), nullable=False)
    payload = Column(JSON, nullable=False)
    # Naive UTC timestamp applied at INSERT; indexed for the ordered listings.
    timestamp = Column(DateTime, default=datetime.utcnow, index=True)


class Memory(Base):
    """A JSON value stored for an agent under a key, optionally scoped to a job.

    A NULL ``job_id`` is what ``DatabaseService.memory_put`` stores when it is
    called without a job id (presumably an agent-level, job-independent entry).

    NOTE(review): no unique constraint covers (agent_name, job_id, key);
    uniqueness relies on memory_put's read-then-insert, which can create
    duplicates under concurrent writers. Confirm whether a constraint is wanted.
    """

    __tablename__ = "memory"
    id = Column(Integer, primary_key=True, autoincrement=True)
    agent_name = Column(String(255), nullable=False, index=True)
    job_id = Column(String(36), nullable=True, index=True)
    key = Column(String(255), nullable=False)
    value = Column(JSON, nullable=False)
    # Naive UTC; set at INSERT and refreshed by SQLAlchemy on every UPDATE.
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class AgentRegistry(Base):
    """One row per known agent, upserted by ``DatabaseService.register_agent``."""

    __tablename__ = "agent_registry"
    agent_name = Column(String(255), primary_key=True)
    # Exposed as "metadata" by list_agents; the attribute is presumably named
    # ``meta`` because ``metadata`` is reserved on declarative classes.
    meta = Column(JSON, nullable=False)
    # Naive UTC; set at INSERT, refreshed on UPDATE, and also set explicitly
    # by register_agent when re-registering.
    last_seen = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)


class DatabaseService:
    """High-level persistence API over the ORM models in this module.

    Every public method opens its own short-lived session via
    :meth:`get_session`, so each call is committed (or rolled back) as an
    independent unit of work. Read methods build plain ``dict``/``list``
    results while the session is still open, so callers never receive ORM
    instances.
    """

    def __init__(self, database_url: str):
        """Create the engine and session factory, then ensure tables exist.

        Args:
            database_url: SQLAlchemy database URL passed to ``create_engine``.
        """
        self.engine = create_engine(database_url, echo=False)
        self.SessionLocal = sessionmaker(bind=self.engine)
        # Create missing tables eagerly so the service is usable immediately.
        Base.metadata.create_all(self.engine)

    def create_tables(self):
        """Create any missing tables for all models registered on ``Base``."""
        Base.metadata.create_all(self.engine)

    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """Yield a session that commits on success and rolls back on error.

        The session is always closed on exit. Exceptions raised inside the
        ``with`` body are re-raised after the rollback.
        """
        session = self.SessionLocal()
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    # -- serialization helpers -------------------------------------------

    @staticmethod
    def _iso(value: datetime | None) -> str | None:
        """Return ``value.isoformat()``, or ``None`` when ``value`` is unset."""
        return value.isoformat() if value else None

    @classmethod
    def _trace_to_dict(cls, trace: Trace) -> dict:
        """Serialize a :class:`Trace` row into a JSON-friendly dict."""
        return {
            "id": trace.id,
            "job_id": trace.job_id,
            "agent_name": trace.agent_name,
            "event_type": trace.event_type,
            "payload": trace.payload,
            "timestamp": cls._iso(trace.timestamp),
        }

    # -- jobs --------------------------------------------------------------

    def create_job(self, job_id: str | None, pipeline: str, inputs: dict) -> str:
        """Insert a new ``pending`` job and return its id.

        Args:
            job_id: Id to use, or ``None`` to generate a random UUID4 string.
            pipeline: Stored as ``Job.pipeline_name``.
            inputs: JSON-serializable job inputs.

        Returns:
            The id of the inserted job.
        """
        if job_id is None:
            job_id = str(uuid.uuid4())
        with self.get_session() as session:
            session.add(Job(job_id=job_id, pipeline_name=pipeline, status="pending", inputs=inputs))
        return job_id

    def save_result(self, job_id: str, outputs: dict, status: str = "completed") -> None:
        """Record a job's outputs and final status and stamp ``completed_at``.

        Does nothing when no job with ``job_id`` exists.
        """
        with self.get_session() as session:
            job = session.query(Job).filter_by(job_id=job_id).first()
            if job is not None:
                job.outputs = outputs
                job.status = status
                job.completed_at = datetime.utcnow()

    def get_result(self, job_id: str) -> dict | None:
        """Return the full job record as a dict, or ``None`` if it is unknown."""
        with self.get_session() as session:
            job = session.query(Job).filter_by(job_id=job_id).first()
            if job is None:
                return None
            return {
                "job_id": job.job_id,
                "pipeline_name": job.pipeline_name,
                "status": job.status,
                "inputs": job.inputs,
                "outputs": job.outputs,
                "created_at": self._iso(job.created_at),
                "completed_at": self._iso(job.completed_at),
            }

    def list_jobs(self, limit: int = 50) -> list[dict]:
        """Return summaries of the ``limit`` most recently created jobs."""
        with self.get_session() as session:
            jobs = session.query(Job).order_by(Job.created_at.desc()).limit(limit).all()
            return [
                {
                    "job_id": job.job_id,
                    "pipeline_name": job.pipeline_name,
                    "status": job.status,
                    "created_at": self._iso(job.created_at),
                }
                for job in jobs
            ]

    # -- prompt executions -------------------------------------------------

    def create_prompt(self, prompt_id: str | None, prompt_name: str, inputs: dict) -> str:
        """Insert a new ``pending`` prompt execution and return its id.

        Args:
            prompt_id: Id to use, or ``None`` to generate a random UUID4 string.
            prompt_name: Stored as ``PromptExecution.prompt_name``.
            inputs: JSON-serializable prompt inputs.
        """
        if prompt_id is None:
            prompt_id = str(uuid.uuid4())
        with self.get_session() as session:
            session.add(
                PromptExecution(prompt_id=prompt_id, prompt_name=prompt_name, status="pending", inputs=inputs)
            )
        return prompt_id

    def save_prompt_result(self, prompt_id: str, outputs: dict, status: str = "completed") -> None:
        """Record a prompt's outputs and final status and stamp ``completed_at``.

        Does nothing when no prompt execution with ``prompt_id`` exists.
        """
        with self.get_session() as session:
            prompt = session.query(PromptExecution).filter_by(prompt_id=prompt_id).first()
            if prompt is not None:
                prompt.outputs = outputs
                prompt.status = status
                prompt.completed_at = datetime.utcnow()

    def get_prompt_result(self, prompt_id: str) -> dict | None:
        """Return the full prompt-execution record, or ``None`` if unknown."""
        with self.get_session() as session:
            prompt = session.query(PromptExecution).filter_by(prompt_id=prompt_id).first()
            if prompt is None:
                return None
            return {
                "prompt_id": prompt.prompt_id,
                "prompt_name": prompt.prompt_name,
                "status": prompt.status,
                "inputs": prompt.inputs,
                "outputs": prompt.outputs,
                "created_at": self._iso(prompt.created_at),
                "completed_at": self._iso(prompt.completed_at),
            }

    def list_prompts(self, limit: int = 50) -> list[dict]:
        """Return summaries of the ``limit`` most recently created prompts."""
        with self.get_session() as session:
            prompts = (
                session.query(PromptExecution)
                .order_by(PromptExecution.created_at.desc())
                .limit(limit)
                .all()
            )
            return [
                {
                    "prompt_id": prompt.prompt_id,
                    "prompt_name": prompt.prompt_name,
                    "status": prompt.status,
                    "created_at": self._iso(prompt.created_at),
                }
                for prompt in prompts
            ]

    # -- traces ------------------------------------------------------------

    def get_job_traces(self, job_id: str) -> list[dict]:
        """Return every trace for ``job_id`` in ascending timestamp order."""
        with self.get_session() as session:
            traces = session.query(Trace).filter_by(job_id=job_id).order_by(Trace.timestamp).all()
            return [self._trace_to_dict(trace) for trace in traces]

    def append_trace(self, job_id: str, agent_name: str, event_type: str, payload: dict) -> None:
        """Insert one trace event; its timestamp is set by the column default."""
        with self.get_session() as session:
            session.add(Trace(job_id=job_id, agent_name=agent_name, event_type=event_type, payload=payload))

    def list_traces(self, agent: str | None = None, limit: int = 100) -> list[dict]:
        """Return the ``limit`` newest traces, optionally only for ``agent``.

        A falsy ``agent`` (``None`` or ``""``) disables the agent filter.
        """
        with self.get_session() as session:
            query = session.query(Trace)
            if agent:
                query = query.filter_by(agent_name=agent)
            traces = query.order_by(Trace.timestamp.desc()).limit(limit).all()
            return [self._trace_to_dict(trace) for trace in traces]

    def get_trace(self, trace_id: int | str) -> dict | None:
        """Return one trace by its integer primary key, or ``None``.

        ``Trace.id`` is an auto-incrementing integer, so ``trace_id`` is
        coerced with ``int()``. Ids that are not integer-like (e.g. a UUID
        string) cannot match any row and yield ``None``. This avoids sending
        an ill-typed comparison that some backends reject with an error.
        """
        try:
            trace_pk = int(trace_id)
        except (TypeError, ValueError):
            return None
        with self.get_session() as session:
            trace = session.query(Trace).filter_by(id=trace_pk).first()
            return self._trace_to_dict(trace) if trace is not None else None

    # -- agent memory ------------------------------------------------------

    def memory_put(self, agent_name: str, key: str, value: Any, job_id: str | None = None) -> None:
        """Insert or overwrite the memory entry for (agent_name, job_id, key).

        ``job_id=None`` targets the entry whose ``job_id`` column is NULL.
        The read-then-write is not atomic across concurrent writers.
        """
        with self.get_session() as session:
            existing = session.query(Memory).filter_by(agent_name=agent_name, key=key, job_id=job_id).first()
            if existing is not None:
                existing.value = value
                existing.updated_at = datetime.utcnow()
            else:
                session.add(Memory(agent_name=agent_name, job_id=job_id, key=key, value=value))

    def memory_get(self, agent_name: str, key: str, job_id: str | None = None) -> Any:
        """Return the stored value for (agent_name, job_id, key), or ``None``.

        ``job_id=None`` matches only the entry whose ``job_id`` is NULL.
        """
        with self.get_session() as session:
            memory = session.query(Memory).filter_by(agent_name=agent_name, key=key, job_id=job_id).first()
            return memory.value if memory is not None else None

    def memory_list(self, agent_name: str, job_id: str | None = None) -> dict[str, Any]:
        """Return ``{key: value}`` for an agent's memory entries.

        When ``job_id`` is falsy, entries from every job and the NULL-job
        entries are merged. On duplicate keys, whichever row the query
        yields last wins; the row order is not specified.
        """
        with self.get_session() as session:
            query = session.query(Memory).filter_by(agent_name=agent_name)
            if job_id:
                query = query.filter_by(job_id=job_id)
            return {mem.key: mem.value for mem in query.all()}

    # -- agent registry ----------------------------------------------------

    def register_agent(self, agent_name: str, metadata: dict) -> None:
        """Insert or update an agent's registry row and refresh ``last_seen``."""
        with self.get_session() as session:
            existing = session.query(AgentRegistry).filter_by(agent_name=agent_name).first()
            if existing is not None:
                existing.meta = metadata
                existing.last_seen = datetime.utcnow()
            else:
                session.add(AgentRegistry(agent_name=agent_name, meta=metadata))

    def list_agents(self) -> list[dict]:
        """Return every registered agent with its metadata and ``last_seen``."""
        with self.get_session() as session:
            agents = session.query(AgentRegistry).all()
            return [
                {
                    "agent_name": agent.agent_name,
                    "metadata": agent.meta,
                    "last_seen": self._iso(agent.last_seen),
                }
                for agent in agents
            ]

    # -- metrics -----------------------------------------------------------

    def get_metrics(self) -> dict[str, Any]:
        """Return aggregate counters over jobs, agents, and traces.

        ``avg_latency_ms`` is the mean of ``completed_at - created_at`` over
        completed jobs that have both timestamps. It is truncated to whole
        milliseconds and is 0 when no such job exists.
        """
        with self.get_session() as session:
            total_jobs = session.query(Job).count()
            failed_jobs = session.query(Job).filter_by(status="failed").count()
            # Select only the two timestamp columns: loading whole Job rows
            # would also pull the (possibly large) JSON inputs/outputs. This
            # single query serves both the completed count and the latency.
            completed_rows = (
                session.query(Job.created_at, Job.completed_at)
                .filter(Job.status == "completed")
                .all()
            )
            latencies = [
                (completed_at - created_at).total_seconds()
                for created_at, completed_at in completed_rows
                if created_at and completed_at
            ]
            avg_latency_sec = sum(latencies) / len(latencies) if latencies else 0
            active_agents = session.query(AgentRegistry).count()
            tool_calls = session.query(Trace).filter_by(event_type="tool_call").count()
            cache_hits = session.query(Trace).filter_by(event_type="cache_hit").count()
            return {
                "total_jobs": total_jobs,
                "completed_jobs": len(completed_rows),
                "failed_jobs": failed_jobs,
                "avg_latency_ms": int(avg_latency_sec * 1000),
                "active_agents_count": active_agents,
                "total_tools_called": tool_calls,
                "cache_hits": cache_hits,
            }
