"""
Storage backends for PromptSim.
"""

import json
import sqlite3
from abc import ABC, abstractmethod
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional

from .models import PromptTemplate, PromptExecution, Experiment, BusinessMetric
from .exceptions import StorageError, PromptNotFoundError


class StorageBackend(ABC):
    """Interface shared by all PromptSim storage backends.

    Concrete backends must override every method below; because this is an
    ``ABC``, a subclass that leaves any of them abstract cannot be instantiated.
    """

    @abstractmethod
    def save_prompt(self, prompt: PromptTemplate) -> None:
        """Persist the given prompt template."""

    @abstractmethod
    def get_prompt(self, key: str, version: Optional[int] = None) -> PromptTemplate:
        """Return the template stored under ``key``.

        ``version`` selects an exact version; ``None`` lets the backend choose
        its default version for the key.
        """

    @abstractmethod
    def list_prompts(self, key: Optional[str] = None) -> List[PromptTemplate]:
        """Return stored templates, restricted to ``key`` when one is given."""

    @abstractmethod
    def save_execution(self, execution: PromptExecution) -> None:
        """Persist a record of one prompt execution."""

    @abstractmethod
    def get_executions(self, prompt_key: Optional[str] = None, limit: int = 100) -> List[PromptExecution]:
        """Return at most ``limit`` executions, optionally only those for ``prompt_key``."""


class SQLiteStorage(StorageBackend):
    """SQLite storage backend.

    Each operation opens its own short-lived connection to ``db_path``. The
    connection commits on success, rolls back on error, and is always closed
    afterwards. ``sqlite3.Error`` raised by the public methods is re-raised as
    ``StorageError`` with the driver error chained as ``__cause__``.
    """

    # Explicit column lists keep the row -> model mapping independent of the
    # physical column order that ``SELECT *`` would silently depend on.
    _PROMPT_COLUMNS = (
        "id, key, name, template, model, version, parameters, tags, "
        "is_active, created_at, created_by, description"
    )
    _EXECUTION_COLUMNS = (
        "id, prompt_id, prompt_key, prompt_version, user_id, input_variables, "
        "input_text, output_text, model, tokens_used, cost_usd, latency_ms, "
        "error, metadata, executed_at, experiment_id, variant_name"
    )

    def __init__(self, db_path: str = "./promptsim.db"):
        self.db_path = db_path
        self._init_db()

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection scoped to one transaction and close it afterwards.

        Using a ``sqlite3.Connection`` as a context manager only commits or
        rolls back the transaction; it does not close the connection, which is
        why the ``finally`` below is needed to avoid leaking handles.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, rollback on exception
                yield conn
        finally:
            conn.close()

    def _init_db(self) -> None:
        """Create the prompt, execution, experiment and metric tables if missing."""
        with self._connection() as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS prompt_templates (
                    id TEXT PRIMARY KEY,
                    key TEXT NOT NULL,
                    name TEXT NOT NULL,
                    template TEXT NOT NULL,
                    model TEXT NOT NULL,
                    version INTEGER NOT NULL,
                    parameters TEXT,
                    tags TEXT,
                    is_active BOOLEAN DEFAULT TRUE,
                    created_at TEXT NOT NULL,
                    created_by TEXT NOT NULL,
                    description TEXT,
                    UNIQUE(key, version)
                )
            """)

            conn.execute("""
                CREATE TABLE IF NOT EXISTS prompt_executions (
                    id TEXT PRIMARY KEY,
                    prompt_id TEXT NOT NULL,
                    prompt_key TEXT NOT NULL,
                    prompt_version INTEGER NOT NULL,
                    user_id TEXT,
                    input_variables TEXT,
                    input_text TEXT NOT NULL,
                    output_text TEXT NOT NULL,
                    model TEXT NOT NULL,
                    tokens_used INTEGER NOT NULL,
                    cost_usd REAL NOT NULL,
                    latency_ms INTEGER NOT NULL,
                    error TEXT,
                    metadata TEXT,
                    executed_at TEXT NOT NULL,
                    experiment_id TEXT,
                    variant_name TEXT
                )
            """)

            conn.execute("""
                CREATE TABLE IF NOT EXISTS experiments (
                    id TEXT PRIMARY KEY,
                    name TEXT NOT NULL,
                    prompt_key TEXT NOT NULL,
                    description TEXT,
                    variants TEXT,
                    status TEXT NOT NULL,
                    traffic_allocation TEXT,
                    start_date TEXT,
                    end_date TEXT,
                    created_at TEXT NOT NULL,
                    created_by TEXT NOT NULL,
                    success_metrics TEXT
                )
            """)

            conn.execute("""
                CREATE TABLE IF NOT EXISTS business_metrics (
                    id TEXT PRIMARY KEY,
                    execution_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    metric_name TEXT NOT NULL,
                    metric_value REAL NOT NULL,
                    metric_type TEXT NOT NULL,
                    recorded_at TEXT NOT NULL,
                    metadata TEXT
                )
            """)

    @staticmethod
    def _row_to_prompt(row: tuple) -> PromptTemplate:
        """Build a PromptTemplate from a row selected with ``_PROMPT_COLUMNS``."""
        (prompt_id, key, name, template, model, version, parameters, tags,
         is_active, created_at, created_by, description) = row
        return PromptTemplate(
            id=prompt_id, key=key, name=name, template=template, model=model,
            version=version, parameters=json.loads(parameters or "{}"),
            tags=json.loads(tags or "[]"), is_active=bool(is_active),
            created_at=datetime.fromisoformat(created_at), created_by=created_by,
            description=description or "",
        )

    @staticmethod
    def _row_to_execution(row: tuple) -> PromptExecution:
        """Build a PromptExecution from a row selected with ``_EXECUTION_COLUMNS``."""
        (execution_id, prompt_id, prompt_key, prompt_version, user_id,
         input_variables, input_text, output_text, model, tokens_used,
         cost_usd, latency_ms, error, metadata, executed_at, experiment_id,
         variant_name) = row
        return PromptExecution(
            id=execution_id, prompt_id=prompt_id, prompt_key=prompt_key,
            prompt_version=prompt_version, user_id=user_id,
            input_variables=json.loads(input_variables or "{}"),
            input_text=input_text, output_text=output_text, model=model,
            tokens_used=tokens_used, cost_usd=cost_usd, latency_ms=latency_ms,
            error=error, metadata=json.loads(metadata or "{}"),
            executed_at=datetime.fromisoformat(executed_at),
            experiment_id=experiment_id, variant_name=variant_name,
        )

    def save_prompt(self, prompt: PromptTemplate) -> None:
        """Insert or replace a prompt template.

        ``INSERT OR REPLACE`` together with the table's ``UNIQUE(key, version)``
        constraint means an existing row with the same id or the same
        (key, version) pair is overwritten.

        Raises:
            StorageError: The write failed.
        """
        try:
            with self._connection() as conn:
                conn.execute(f"""
                    INSERT OR REPLACE INTO prompt_templates
                    ({self._PROMPT_COLUMNS})
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    prompt.id, prompt.key, prompt.name, prompt.template, prompt.model,
                    prompt.version, json.dumps(prompt.parameters), json.dumps(prompt.tags),
                    prompt.is_active, prompt.created_at.isoformat(), prompt.created_by,
                    prompt.description,
                ))
        except sqlite3.Error as e:
            raise StorageError(f"Failed to save prompt: {e}") from e

    def get_prompt(self, key: str, version: Optional[int] = None) -> PromptTemplate:
        """Return the prompt template stored under ``key``.

        Args:
            key: Prompt key.
            version: Exact version to fetch. When ``None``, the highest active
                version is returned; if no version is active, the highest
                version overall is returned instead.

        Raises:
            PromptNotFoundError: No matching row exists.
            StorageError: The query failed.
        """
        try:
            with self._connection() as conn:
                if version is None:
                    # Active rows (stored as 1) sort before inactive ones (0);
                    # within each group the newest version wins, so a key whose
                    # versions are all inactive still resolves to its latest.
                    cursor = conn.execute(f"""
                        SELECT {self._PROMPT_COLUMNS} FROM prompt_templates
                        WHERE key = ?
                        ORDER BY is_active DESC, version DESC LIMIT 1
                    """, (key,))
                else:
                    cursor = conn.execute(f"""
                        SELECT {self._PROMPT_COLUMNS} FROM prompt_templates
                        WHERE key = ? AND version = ?
                    """, (key, version))
                row = cursor.fetchone()
        except sqlite3.Error as e:
            raise StorageError(f"Failed to get prompt: {e}") from e

        if row is None:
            if version is None:
                raise PromptNotFoundError(f"Prompt '{key}' not found")
            raise PromptNotFoundError(f"Prompt '{key}' version {version} not found")
        return self._row_to_prompt(row)

    def list_prompts(self, key: Optional[str] = None) -> List[PromptTemplate]:
        """List prompt templates ordered by key, newest version first.

        Args:
            key: If truthy, only templates with this key are returned.

        Raises:
            StorageError: The query failed.
        """
        try:
            with self._connection() as conn:
                if key:
                    cursor = conn.execute(f"""
                        SELECT {self._PROMPT_COLUMNS} FROM prompt_templates
                        WHERE key = ? ORDER BY version DESC
                    """, (key,))
                else:
                    cursor = conn.execute(f"""
                        SELECT {self._PROMPT_COLUMNS} FROM prompt_templates
                        ORDER BY key, version DESC
                    """)
                rows = cursor.fetchall()
        except sqlite3.Error as e:
            raise StorageError(f"Failed to list prompts: {e}") from e

        return [self._row_to_prompt(row) for row in rows]

    def save_execution(self, execution: PromptExecution) -> None:
        """Insert a prompt execution record.

        Raises:
            StorageError: The write failed (including a duplicate ``id``).
        """
        try:
            with self._connection() as conn:
                conn.execute(f"""
                    INSERT INTO prompt_executions
                    ({self._EXECUTION_COLUMNS})
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    execution.id, execution.prompt_id, execution.prompt_key,
                    execution.prompt_version, execution.user_id,
                    json.dumps(execution.input_variables), execution.input_text,
                    execution.output_text, execution.model, execution.tokens_used,
                    execution.cost_usd, execution.latency_ms, execution.error,
                    json.dumps(execution.metadata), execution.executed_at.isoformat(),
                    execution.experiment_id, execution.variant_name,
                ))
        except sqlite3.Error as e:
            raise StorageError(f"Failed to save execution: {e}") from e

    def get_executions(self, prompt_key: Optional[str] = None, limit: int = 100) -> List[PromptExecution]:
        """Return up to ``limit`` executions, newest ``executed_at`` first.

        Args:
            prompt_key: If truthy, only executions of this prompt key are returned.
            limit: Maximum number of rows; passed straight to SQL ``LIMIT``.

        Raises:
            StorageError: The query failed.
        """
        try:
            with self._connection() as conn:
                if prompt_key:
                    cursor = conn.execute(f"""
                        SELECT {self._EXECUTION_COLUMNS} FROM prompt_executions
                        WHERE prompt_key = ?
                        ORDER BY executed_at DESC LIMIT ?
                    """, (prompt_key, limit))
                else:
                    cursor = conn.execute(f"""
                        SELECT {self._EXECUTION_COLUMNS} FROM prompt_executions
                        ORDER BY executed_at DESC LIMIT ?
                    """, (limit,))
                rows = cursor.fetchall()
        except sqlite3.Error as e:
            raise StorageError(f"Failed to get executions: {e}") from e

        return [self._row_to_execution(row) for row in rows]


class FileStorage(StorageBackend):
    """JSON file storage backend.

    Prompts and executions are stored as JSON arrays in ``prompts.json`` and
    ``executions.json`` under ``data_dir``. Every save rewrites the whole file
    by writing a sibling ``.tmp`` file and renaming it over the target, so an
    interrupted write cannot leave truncated JSON behind. There is no locking:
    concurrent writers can still overwrite each other's updates.
    """

    def __init__(self, data_dir: str = "./prompt_data"):
        self.data_dir = Path(data_dir)
        # parents=True so a nested, not-yet-existing directory path works.
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.prompts_file = self.data_dir / "prompts.json"
        self.executions_file = self.data_dir / "executions.json"

        # Seed each file with an empty array so later reads see valid JSON.
        for path in (self.prompts_file, self.executions_file):
            if not path.exists():
                path.write_text("[]", encoding="utf-8")

    @staticmethod
    def _read_records(path: Path) -> List[Dict[str, Any]]:
        """Return the JSON array stored at ``path``, or ``[]`` if the file is missing.

        Raises:
            StorageError: The file exists but does not contain valid JSON.
                Treating it as empty would make the next save silently
                overwrite every stored record.
        """
        try:
            return json.loads(path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            return []
        except json.JSONDecodeError as e:
            raise StorageError(f"Corrupted storage file {path}: {e}") from e

    @staticmethod
    def _write_records(path: Path, records: List[Dict[str, Any]]) -> None:
        """Replace ``path`` with ``records`` serialized as indented JSON.

        The data is written to a sibling temp file first and then renamed over
        ``path`` (``Path.replace`` / ``os.replace``), so readers see either the
        old or the new content, never a partially written file.
        """
        # Serialize before touching the disk so a serialization error leaves
        # the existing file untouched. default=str keeps the original leniency
        # for values json cannot encode natively.
        payload = json.dumps(records, indent=2, default=str)
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            tmp_path.write_text(payload, encoding="utf-8")
            tmp_path.replace(path)
        except OSError:
            # Best-effort cleanup of the temp file; the original error is re-raised.
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    def _load_prompts(self) -> List[Dict[str, Any]]:
        """Load prompt dicts from the prompts file."""
        return self._read_records(self.prompts_file)

    def _save_prompts(self, prompts: List[Dict[str, Any]]) -> None:
        """Save prompt dicts to the prompts file."""
        self._write_records(self.prompts_file, prompts)

    def _load_executions(self) -> List[Dict[str, Any]]:
        """Load execution dicts from the executions file."""
        return self._read_records(self.executions_file)

    def _save_executions(self, executions: List[Dict[str, Any]]) -> None:
        """Save execution dicts to the executions file."""
        self._write_records(self.executions_file, executions)

    @staticmethod
    def _dict_to_prompt(prompt_dict: Dict[str, Any]) -> PromptTemplate:
        """Build a PromptTemplate from a stored dict, defaulting optional fields."""
        return PromptTemplate(
            id=prompt_dict["id"],
            key=prompt_dict["key"],
            name=prompt_dict["name"],
            template=prompt_dict["template"],
            model=prompt_dict["model"],
            version=prompt_dict["version"],
            parameters=prompt_dict.get("parameters", {}),
            tags=prompt_dict.get("tags", []),
            is_active=prompt_dict.get("is_active", True),
            created_at=datetime.fromisoformat(prompt_dict["created_at"]),
            created_by=prompt_dict.get("created_by", "system"),
            description=prompt_dict.get("description", ""),
        )

    @staticmethod
    def _dict_to_execution(exec_dict: Dict[str, Any]) -> PromptExecution:
        """Build a PromptExecution from a stored dict, defaulting optional fields."""
        return PromptExecution(
            id=exec_dict["id"],
            prompt_id=exec_dict["prompt_id"],
            prompt_key=exec_dict["prompt_key"],
            prompt_version=exec_dict["prompt_version"],
            user_id=exec_dict.get("user_id"),
            input_variables=exec_dict.get("input_variables", {}),
            input_text=exec_dict["input_text"],
            output_text=exec_dict["output_text"],
            model=exec_dict["model"],
            tokens_used=exec_dict["tokens_used"],
            cost_usd=exec_dict["cost_usd"],
            latency_ms=exec_dict["latency_ms"],
            error=exec_dict.get("error"),
            metadata=exec_dict.get("metadata", {}),
            executed_at=datetime.fromisoformat(exec_dict["executed_at"]),
            experiment_id=exec_dict.get("experiment_id"),
            variant_name=exec_dict.get("variant_name"),
        )

    def save_prompt(self, prompt: PromptTemplate) -> None:
        """Save a prompt template, replacing any entry with the same key and version.

        Raises:
            StorageError: The prompts file is corrupted.
        """
        prompts = self._load_prompts()

        prompt_dict = {
            "id": prompt.id,
            "key": prompt.key,
            "name": prompt.name,
            "template": prompt.template,
            "model": prompt.model,
            "version": prompt.version,
            "parameters": prompt.parameters,
            "tags": prompt.tags,
            "is_active": prompt.is_active,
            "created_at": prompt.created_at.isoformat(),
            "created_by": prompt.created_by,
            "description": prompt.description,
        }

        for i, existing in enumerate(prompts):
            if existing["key"] == prompt.key and existing["version"] == prompt.version:
                prompts[i] = prompt_dict
                break
        else:
            prompts.append(prompt_dict)

        self._save_prompts(prompts)

    def get_prompt(self, key: str, version: Optional[int] = None) -> PromptTemplate:
        """Return the prompt template stored under ``key``.

        Args:
            key: Prompt key.
            version: Exact version to fetch. When ``None``, the highest active
                version is returned (entries without ``is_active`` count as
                active); if none is active, the highest version overall.

        Raises:
            PromptNotFoundError: No matching entry exists.
            StorageError: The prompts file is corrupted.
        """
        prompts = self._load_prompts()

        if version is None:
            matching = [p for p in prompts if p["key"] == key]
            if not matching:
                raise PromptNotFoundError(f"Prompt '{key}' not found")
            active = [p for p in matching if p.get("is_active", True)]
            prompt_dict = max(active or matching, key=lambda p: p["version"])
        else:
            prompt_dict = next(
                (p for p in prompts if p["key"] == key and p["version"] == version),
                None,
            )
            if prompt_dict is None:
                raise PromptNotFoundError(f"Prompt '{key}' version {version} not found")

        return self._dict_to_prompt(prompt_dict)

    def list_prompts(self, key: Optional[str] = None) -> List[PromptTemplate]:
        """List prompt templates ordered by key, newest version first.

        Args:
            key: If truthy, only templates with this key are returned.

        Raises:
            StorageError: The prompts file is corrupted.
        """
        prompts = self._load_prompts()
        if key:
            prompts = [p for p in prompts if p["key"] == key]
        result = [self._dict_to_prompt(p) for p in prompts]
        return sorted(result, key=lambda x: (x.key, -x.version))

    def save_execution(self, execution: PromptExecution) -> None:
        """Append a prompt execution record.

        Raises:
            StorageError: The executions file is corrupted.
        """
        executions = self._load_executions()

        executions.append({
            "id": execution.id,
            "prompt_id": execution.prompt_id,
            "prompt_key": execution.prompt_key,
            "prompt_version": execution.prompt_version,
            "user_id": execution.user_id,
            "input_variables": execution.input_variables,
            "input_text": execution.input_text,
            "output_text": execution.output_text,
            "model": execution.model,
            "tokens_used": execution.tokens_used,
            "cost_usd": execution.cost_usd,
            "latency_ms": execution.latency_ms,
            "error": execution.error,
            "metadata": execution.metadata,
            "executed_at": execution.executed_at.isoformat(),
            "experiment_id": execution.experiment_id,
            "variant_name": execution.variant_name,
        })

        self._save_executions(executions)

    def get_executions(self, prompt_key: Optional[str] = None, limit: int = 100) -> List[PromptExecution]:
        """Return up to ``limit`` executions, newest ``executed_at`` first.

        Args:
            prompt_key: If truthy, only executions of this prompt key are returned.
            limit: Maximum number of records. A negative value means no limit,
                matching SQLite ``LIMIT`` semantics used by the SQLite backend.

        Raises:
            StorageError: The executions file is corrupted.
        """
        executions = self._load_executions()

        if prompt_key:
            executions = [e for e in executions if e["prompt_key"] == prompt_key]

        # ISO-8601 strings from isoformat() sort chronologically as text when
        # they share the same format/offset.
        executions.sort(key=lambda e: e["executed_at"], reverse=True)
        if limit >= 0:
            executions = executions[:limit]

        return [self._dict_to_execution(e) for e in executions]