"""Multi-model management and orchestration for pantsonfire v2.0"""

from enum import Enum
from typing import Any, Callable, Dict, Literal, Optional

from pydantic import BaseModel, Field

from .config import Config


class ModelTier(str, Enum):
    """Capability/cost class that a model is assigned to.

    The enum mixes in ``str``, so every member is also a plain string and
    compares equal to its lowercase value (``ModelTier.FAST == "fast"``).
    """

    # Quick, cheap extraction (e.g. Claude 3 Haiku).
    FAST = "fast"

    # Web-search enabled models (e.g. Perplexity Sonar).
    SEARCH = "search"

    # Deep analysis (e.g. Claude Opus, GPT-4 Turbo, Grok).
    REASONING = "reasoning"


class ModelConfig(BaseModel):
    """Static description of one LLM: identity, tier, price and capabilities.

    ``name``, ``tier`` and ``cost_per_1k_tokens`` are required; the remaining
    fields fall back to the defaults declared below.
    """

    # Identifier string for the model (an OpenRouter model id).
    name: str = Field(
        description="Model identifier for OpenRouter",
    )

    # Tier the model is classified under.
    tier: ModelTier = Field(
        description="Model tier classification",
    )

    # Approximate price for 1,000 tokens; used for cost tracking/estimation.
    cost_per_1k_tokens: float = Field(
        description="Approximate cost per 1K tokens",
    )

    # Capability flag: structured JSON output (defaults to True).
    supports_json: bool = Field(
        description="Supports structured JSON output",
        default=True,
    )

    # Capability flag: built-in web search (defaults to False).
    supports_web_search: bool = Field(
        description="Has web search capabilities",
        default=False,
    )

    # NOTE(review): the description calls this the context window, while the
    # field name suggests an output-token limit -- confirm which is intended.
    max_tokens: int = Field(
        description="Maximum context window",
        default=4096,
    )

    # Task names this model is recommended for; a fresh empty list per instance.
    recommended_for: list[str] = Field(
        description="Use cases",
        default_factory=list,
    )


class ModelManager:
    """
    Manage multiple LLM models for cost optimization and capability matching.

    Strategy:
    - Use FAST models for bulk extraction
    - Use SEARCH models for fact verification
    - Use REASONING models for conflicts and complex analysis

    The manager also keeps per-tier usage counters (calls, tokens, cost) in
    ``usage_stats``, keyed by the tier's string value.
    """

    # Predefined model configurations, keyed by model identifier.
    MODELS: Dict[str, ModelConfig] = {
        # Fast tier - for extraction
        "anthropic/claude-3-haiku": ModelConfig(
            name="anthropic/claude-3-haiku",
            tier=ModelTier.FAST,
            cost_per_1k_tokens=0.00025,
            supports_json=True,
            max_tokens=4096,
            recommended_for=["factoid_extraction", "classification", "simple_comparison"],
        ),

        # Search tier - for verification
        "perplexity/sonar-pro": ModelConfig(
            name="perplexity/sonar-pro",
            tier=ModelTier.SEARCH,
            cost_per_1k_tokens=0.003,
            supports_json=True,
            supports_web_search=True,
            max_tokens=4096,
            recommended_for=["fact_verification", "web_search", "current_info"],
        ),
        "perplexity/sonar": ModelConfig(
            name="perplexity/sonar",
            tier=ModelTier.SEARCH,
            cost_per_1k_tokens=0.001,
            supports_json=True,
            supports_web_search=True,
            max_tokens=4096,
            recommended_for=["fact_verification", "web_search"],
        ),

        # Reasoning tier - for complex analysis
        "anthropic/claude-opus-4": ModelConfig(
            name="anthropic/claude-opus-4",
            tier=ModelTier.REASONING,
            cost_per_1k_tokens=0.015,
            supports_json=True,
            max_tokens=8192,
            recommended_for=["conflict_resolution", "deep_analysis", "remediation"],
        ),
        "openai/gpt-4-turbo": ModelConfig(
            name="openai/gpt-4-turbo",
            tier=ModelTier.REASONING,
            cost_per_1k_tokens=0.01,
            supports_json=True,
            max_tokens=4096,
            recommended_for=["conflict_resolution", "complex_reasoning"],
        ),
        "x-ai/grok-2-1212": ModelConfig(
            name="x-ai/grok-2-1212",
            tier=ModelTier.REASONING,
            cost_per_1k_tokens=0.002,
            supports_json=True,
            max_tokens=8192,
            recommended_for=["reasoning", "analysis"],
        ),
    }

    # Task name -> tier that should handle it. Tasks missing from this table
    # are routed to the FAST tier by get_model_for_task().
    _TASK_TIERS = {
        "factoid_extraction": ModelTier.FAST,
        "classification": ModelTier.FAST,
        "simple_comparison": ModelTier.FAST,
        "fact_verification": ModelTier.SEARCH,
        "web_search": ModelTier.SEARCH,
        "current_info": ModelTier.SEARCH,
        "conflict_resolution": ModelTier.REASONING,
        "deep_analysis": ModelTier.REASONING,
        "remediation": ModelTier.REASONING,
    }

    def __init__(self, config: Config):
        """Store the configuration and start with zeroed usage counters.

        Args:
            config: Project configuration object; kept on ``self.config``.
        """
        self.config = config
        # One counter bucket per tier, derived from the enum so the buckets
        # cannot drift out of sync with ModelTier.
        self.usage_stats: Dict[str, Dict[str, Any]] = {
            tier.value: {"calls": 0, "tokens": 0, "cost": 0.0}
            for tier in ModelTier
        }

    def get_model_for_task(self, task: str) -> str:
        """
        Get the best (cheapest suitable) model for a specific task.

        The task is mapped to a tier (unknown tasks map to FAST). Within that
        tier, models that list the task in ``recommended_for`` are preferred;
        if none do, every model of the tier is a candidate. The candidate with
        the lowest ``cost_per_1k_tokens`` wins.

        Args:
            task: Task type (factoid_extraction, fact_verification, etc.)

        Returns:
            Model name to use
        """
        tier = self._TASK_TIERS.get(task, ModelTier.FAST)

        in_tier = [
            (name, cfg) for name, cfg in self.MODELS.items()
            if cfg.tier == tier
        ]
        recommended = [
            (name, cfg) for name, cfg in in_tier
            if task in cfg.recommended_for
        ]
        # Fall back to any model in the tier when none is recommended.
        candidates = recommended or in_tier

        if not candidates:
            # Ultimate fallback - use centralized default
            from .models_config import DEFAULT_MODEL
            return DEFAULT_MODEL

        # min() returns the first minimal entry, so ties resolve in MODELS
        # declaration order.
        cheapest_name, _ = min(candidates, key=lambda entry: entry[1].cost_per_1k_tokens)
        return cheapest_name

    def track_usage(
        self,
        model: str,
        tokens_used: int,
        tier: Optional[ModelTier] = None
    ) -> None:
        """Track model usage for cost monitoring.

        Args:
            model: Model identifier the call was made with.
            tokens_used: Number of tokens consumed by the call.
            tier: Tier to book the usage under. When omitted, the tier is
                taken from ``MODELS``; models not present there are booked
                under FAST.

        Calls and tokens are always counted. Cost is only added when the
        model is present in ``MODELS``, since no price is known otherwise.
        """
        model_config = self.MODELS.get(model)

        if tier is None:
            tier = model_config.tier if model_config is not None else ModelTier.FAST

        stats = self.usage_stats.get(tier.value)
        if stats is None:
            # No bucket for this tier; nothing to record.
            return

        stats["calls"] += 1
        stats["tokens"] += tokens_used

        if model_config is not None:
            stats["cost"] += (tokens_used / 1000) * model_config.cost_per_1k_tokens

    def get_usage_summary(self) -> Dict[str, Any]:
        """Get summary of model usage and costs.

        Returns:
            Dict with ``total_calls``, ``total_tokens``, ``total_cost``,
            ``avg_cost_per_call`` (0.0 when nothing was tracked) and
            ``by_tier``. ``by_tier`` is a copy of the per-tier counters, so
            the summary is a consistent snapshot and mutating it does not
            affect the manager's internal state.
        """
        by_tier = {tier: dict(stats) for tier, stats in self.usage_stats.items()}

        total_cost = sum(stats["cost"] for stats in by_tier.values())
        total_calls = sum(stats["calls"] for stats in by_tier.values())
        total_tokens = sum(stats["tokens"] for stats in by_tier.values())

        return {
            "total_calls": total_calls,
            "total_tokens": total_tokens,
            "total_cost": total_cost,
            "by_tier": by_tier,
            "avg_cost_per_call": total_cost / total_calls if total_calls > 0 else 0.0
        }

    def estimate_cost(
        self,
        num_factoids: int,
        use_web_search: bool = False,
        complexity: Literal["simple", "medium", "complex"] = "simple"
    ) -> Dict[str, Any]:
        """
        Estimate cost for an analysis job.

        The estimated token count is priced once with the extraction model
        and, when ``use_web_search`` is set, once more with the search model.
        No reasoning-model cost is included in the estimate.

        Args:
            num_factoids: Expected number of factoids to process
            use_web_search: Whether web search will be used
            complexity: Job complexity level

        Returns:
            Cost estimate breakdown

        Raises:
            KeyError: If ``complexity`` is not one of the supported levels.
        """
        # Token estimates per factoid
        tokens_per_factoid = {
            "simple": 500,   # Just extraction
            "medium": 1500,  # Extraction + verification
            "complex": 3000  # Extraction + verification + conflict resolution
        }

        tokens_needed = num_factoids * tokens_per_factoid[complexity]

        # Model selection - use centralized defaults, falling back to known
        # entries when the configured default is not in MODELS.
        from .models_config import DEFAULT_FAST_MODEL, DEFAULT_SEARCH_MODEL
        extraction_model = self.MODELS.get(DEFAULT_FAST_MODEL, self.MODELS["anthropic/claude-3-haiku"])
        search_model = self.MODELS.get(DEFAULT_SEARCH_MODEL, self.MODELS["perplexity/sonar-pro"])

        thousands_of_tokens = tokens_needed / 1000
        extraction_cost = thousands_of_tokens * extraction_model.cost_per_1k_tokens
        search_cost = (
            thousands_of_tokens * search_model.cost_per_1k_tokens
            if use_web_search
            else 0.0
        )

        return {
            "num_factoids": num_factoids,
            "estimated_tokens": tokens_needed,
            "extraction_cost": extraction_cost,
            "search_cost": search_cost,
            "total_cost": extraction_cost + search_cost,
            "models_used": {
                "extraction": extraction_model.name,
                "search": search_model.name if use_web_search else None
            }
        }

    def should_use_reasoning_model(
        self,
        factoid_count: int,
        conflict_score: float,
        importance: float
    ) -> bool:
        """
        Decide if a reasoning model should be used.

        A reasoning model is chosen when the claim is highly important
        (``importance >= 0.8``), or when there is a significant conflict
        (``conflict_score >= 0.7``) across at least three factoids.

        Args:
            factoid_count: Number of conflicting factoids
            conflict_score: How different the truth scores are
            importance: How important this claim is (0-1)

        Returns:
            True if reasoning model should be used
        """
        high_importance = importance >= 0.8
        broad_conflict = conflict_score >= 0.7 and factoid_count >= 3
        return high_importance or broad_conflict


class BatchProcessor:
    """Process factoids in batches for efficiency"""

    def __init__(self, batch_size: int = 10, rate_limit_delay: float = 1.0):
        """Configure batch size and the pause between batches.

        Args:
            batch_size: Number of items per batch; must be at least 1.
            rate_limit_delay: Seconds to sleep between consecutive batches;
                must not be negative.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 or
                ``rate_limit_delay`` is negative.
        """
        # A zero batch size would divide by zero later, and a negative one
        # would make the batching range empty and silently drop every item.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Fail fast instead of erroring inside time.sleep() after the first
        # batch has already been processed.
        if rate_limit_delay < 0:
            raise ValueError(
                f"rate_limit_delay must be >= 0, got {rate_limit_delay}"
            )
        self.batch_size = batch_size
        self.rate_limit_delay = rate_limit_delay

    def process_in_batches(
        self,
        items: list,
        processor_func: Callable[[Any], Any],
        progress_callback: Optional[Callable[[int, int], None]] = None
    ) -> list:
        """
        Process items in batches with rate limiting.

        Items are processed sequentially in their original order. Before each
        batch the optional callback is invoked, and the processor sleeps for
        ``rate_limit_delay`` seconds between batches (not after the last one).

        Args:
            items: Items to process
            processor_func: Function to call for each item
            progress_callback: Optional callback for progress updates; called
                as ``progress_callback(batch_number, total_batches)`` with a
                1-based batch number

        Returns:
            List of processed results, one per item, in input order
        """
        import time

        item_count = len(items)
        # Ceiling division: a trailing partial batch still counts as a batch.
        total_batches = (item_count + self.batch_size - 1) // self.batch_size

        results = []
        for batch_number, start in enumerate(
            range(0, item_count, self.batch_size), start=1
        ):
            if progress_callback:
                progress_callback(batch_number, total_batches)

            batch = items[start:start + self.batch_size]
            results.extend(processor_func(item) for item in batch)

            # Rate limiting: pause only when another batch follows.
            has_more = start + self.batch_size < item_count
            if has_more and self.rate_limit_delay > 0:
                time.sleep(self.rate_limit_delay)

        return results

