"""
OpenAI integration with cost calculation.
"""

import time
from typing import Dict, Any, Optional

from .exceptions import OpenAIError

try:
    import tiktoken
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
    tiktoken = None
    OpenAI = None


class OpenAIClient:
    """OpenAI chat client that reports token usage, cost and latency.

    Token counts come from the ``usage`` block of the API response when it is
    present (the server-reported counts); otherwise they are estimated locally
    with ``tiktoken``. Costs are computed from :attr:`PRICING`.
    """
    
    # Pricing in USD per 1K tokens (as of 2024). Keys are base model names;
    # dated snapshots such as "gpt-4-0613" resolve via _pricing_for().
    PRICING = {
        "gpt-4": {"input": 0.03, "output": 0.06},
        "gpt-4-turbo": {"input": 0.01, "output": 0.03},
        "gpt-3.5-turbo": {"input": 0.0015, "output": 0.002},
        "gpt-3.5-turbo-16k": {"input": 0.003, "output": 0.004},
    }
    
    def __init__(self, api_key: Optional[str] = None):
        """Initialize the underlying OpenAI client.

        Args:
            api_key: API key passed to ``OpenAI``; ``None`` lets the SDK pick
                its own default.

        Raises:
            OpenAIError: If the ``openai``/``tiktoken`` packages are missing.
        """
        if not OPENAI_AVAILABLE:
            raise OpenAIError("OpenAI and tiktoken packages are required. Install with: pip install openai tiktoken")
        self.client = OpenAI(api_key=api_key)
    
    def count_tokens(self, text: str, model: str = "gpt-3.5-turbo") -> int:
        """Count tokens in ``text`` using the tokenizer for ``model``.

        Unknown models fall back to the ``cl100k_base`` encoding. Text that
        looks like a special token (e.g. ``"<|endoftext|>"``) is counted as
        ordinary text instead of raising. Empty/``None`` text counts as 0.
        """
        if not text:
            return 0
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            # Fallback to cl100k_base for unknown models
            encoding = tiktoken.get_encoding("cl100k_base")
        # tiktoken's default disallowed_special="all" raises ValueError when
        # user text contains special-token strings; we only want a count.
        return len(encoding.encode(text, disallowed_special=()))
    
    def _pricing_for(self, model: str) -> Dict[str, float]:
        """Return the PRICING entry for ``model``.

        Exact names win. Otherwise the longest PRICING key that prefixes
        ``model`` at a "-" boundary is used, so dated snapshots such as
        "gpt-4-0613" or "gpt-3.5-turbo-16k-0613" get their base model's price.
        Anything else falls back to gpt-3.5-turbo pricing.
        """
        if model in self.PRICING:
            return self.PRICING[model]
        matches = [name for name in self.PRICING if model.startswith(name + "-")]
        if matches:
            return self.PRICING[max(matches, key=len)]
        # Use gpt-3.5-turbo pricing as fallback
        return self.PRICING["gpt-3.5-turbo"]
    
    def calculate_cost(self, input_tokens: int, output_tokens: int, model: str) -> float:
        """Return the USD cost of ``input_tokens`` + ``output_tokens`` for ``model``."""
        pricing = self._pricing_for(model)
        input_cost = (input_tokens / 1000) * pricing["input"]
        output_cost = (output_tokens / 1000) * pricing["output"]
        return input_cost + output_cost
    
    @staticmethod
    def _message_text(message: Any) -> str:
        """Return the plain text of one chat message, for local token estimates.

        Accepts dicts or objects with a ``content`` attribute. ``None``/missing
        content yields ""; list content contributes its string parts and the
        ``text`` of ``{"type": "text", ...}`` parts, joined by newlines.
        """
        if isinstance(message, dict):
            content = message.get("content")
        else:
            content = getattr(message, "content", None)
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            texts = []
            for part in content:
                if isinstance(part, str):
                    texts.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    texts.append(part.get("text") or "")
            return "\n".join(texts)
        return str(content)
    
    def _build_result(
        self, response: Any, model: str, latency_ms: int, input_text: str
    ) -> Dict[str, Any]:
        """Turn a chat-completion response into the result dict returned to callers.

        Prefers the server-reported ``usage`` token counts; falls back to local
        tiktoken estimates of ``input_text`` and the reply when they are absent.
        """
        # content may be None (e.g. a reply that only carries tool calls);
        # normalise to "" so "text" is always a str and counting cannot fail.
        output_text = response.choices[0].message.content or ""

        usage = response.usage
        prompt_tokens = getattr(usage, "prompt_tokens", None)
        completion_tokens = getattr(usage, "completion_tokens", None)
        if prompt_tokens is not None and completion_tokens is not None:
            input_tokens, output_tokens = prompt_tokens, completion_tokens
        else:
            input_tokens = self.count_tokens(input_text, model)
            output_tokens = self.count_tokens(output_text, model)

        return {
            "text": output_text,
            "tokens_used": input_tokens + output_tokens,
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost_usd": self.calculate_cost(input_tokens, output_tokens, model),
            "latency_ms": latency_ms,
            "model": model,
            "usage": usage.model_dump() if usage else {}
        }
    
    def complete(
        self,
        prompt: str,
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Complete a single user prompt using the OpenAI chat API.

        Equivalent to :meth:`complete_with_messages` with one ``user`` message;
        see that method for the returned keys and the errors raised.
        """
        return self.complete_with_messages(
            [{"role": "user", "content": prompt}],
            model=model,
            temperature=temperature,
            max_tokens=max_tokens,
            **kwargs
        )
    
    def complete_with_messages(
        self,
        messages: list,
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        **kwargs
    ) -> Dict[str, Any]:
        """Complete using a list of chat messages.

        Extra ``kwargs`` are forwarded to ``chat.completions.create``.

        Returns:
            Dict with ``text`` (reply content, "" if none), ``tokens_used``,
            ``input_tokens``, ``output_tokens``, ``cost_usd``, ``latency_ms``
            (API call only), ``model`` (as requested) and ``usage`` (raw usage
            dump, or {} if absent).

        Raises:
            OpenAIError: Wrapping any failure; the original exception is kept
                as ``__cause__``.
        """
        try:
            # Only used if the response carries no usage block (approximate).
            input_text = "\n".join(self._message_text(msg) for msg in messages)

            # perf_counter is monotonic, unlike time.time().
            start = time.perf_counter()
            response = self.client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                **kwargs
            )
            latency_ms = int((time.perf_counter() - start) * 1000)

            return self._build_result(response, model, latency_ms, input_text)

        except Exception as e:
            raise OpenAIError(f"OpenAI API call failed: {str(e)}") from e