"""
Google Vertex AI wrapper for LangSpend tracking
"""

import asyncio
import inspect
import json
import logging
import threading
import time
from datetime import datetime, timezone
from typing import Any, Optional

from ..client import LangSpend
from ..types import LangSpendTags, FeatureContext, TrackingData
from ..detection import detect_endpoint, detect_callsite, normalize_feature_name


def detect_feature_context(langspend_tags: Optional[LangSpendTags] = None, request: Optional[Any] = None) -> FeatureContext:
    """
    Work out which feature the current LLM call should be attributed to.

    Sources are tried in order and the first one that yields a value wins:
    manual tag -> HTTP endpoint -> calling function -> "uncategorized".

    Args:
        langspend_tags: Optional tags; a truthy ``feature`` on them is used as-is.
        request: Optional request object handed to ``detect_endpoint``.

    Returns:
        A FeatureContext describing the feature name and where it came from.
        Any exception raised while detecting also yields the
        "uncategorized" context, so detection never blocks an LLM call.
    """
    try:
        # 1) Explicit feature supplied by the caller.
        manual_feature = langspend_tags.feature if langspend_tags else None
        if manual_feature:
            return FeatureContext(
                feature_name=manual_feature.name,
                feature_source="manual",
                endpoint=None,
                call_site=None,
            )

        # 2) HTTP endpoint, when one can be detected from the request.
        route = detect_endpoint(request)
        if route:
            return FeatureContext(
                feature_name=normalize_feature_name(route),
                feature_source="endpoint",
                endpoint=route,
                call_site=None,
            )

        # 3) Calling function, for non-HTTP contexts.
        caller = detect_callsite()
        if caller:
            location = f"{caller.file}:{caller.function_name}"
            return FeatureContext(
                feature_name=normalize_feature_name(caller.function_name),
                feature_source="function",
                endpoint=None,
                call_site=location,
            )
    except Exception:
        # Detection is best-effort: fall through to the uncategorized result.
        pass

    # 4) Nothing matched, or detection failed.
    return FeatureContext(
        feature_name="uncategorized",
        feature_source="uncategorized",
        endpoint=None,
        call_site=None,
    )




_logger = logging.getLogger(__name__)

# Model name reported when the model parameters do not name one.
_DEFAULT_MODEL_NAME = 'claude-sonnet-4.5'

# asyncio keeps only weak references to tasks, so in-flight tracking tasks are
# held here until they finish to stop them being garbage-collected mid-run.
_PENDING_TRACKING_TASKS: set = set()


def _resolve_model_name(args: tuple, kwargs: dict) -> Any:
    """Find the model name in the arguments passed to get_generative_model."""
    for candidate in (*args[:1], kwargs.get('model_params'), kwargs):
        if isinstance(candidate, str) and candidate:
            return candidate
        getter = getattr(candidate, 'get', None)
        if callable(getter):
            name = getter('model')
            if name:
                return name
    return _DEFAULT_MODEL_NAME


def _extract_usage_metadata(payload: Any) -> Any:
    """Return ``payload.response_metadata['usage_metadata']`` or ``{}`` when absent."""
    try:
        response_metadata = getattr(payload, 'response_metadata', None) or {}
        return response_metadata.get('usage_metadata') or {}
    except AttributeError:
        # response_metadata is not mapping-like: treat as "no usage data".
        return {}


def _finalize_tracking_task(task: Any) -> None:
    """Drop a finished tracking task and log (rather than leak) its failure."""
    _PENDING_TRACKING_TASKS.discard(task)
    if not task.cancelled() and task.exception() is not None:
        _logger.debug("LangSpend tracking task failed", exc_info=task.exception())


def _run_tracking_in_thread(awaitable: Any) -> None:
    """Drive a tracking awaitable to completion on a private event loop."""
    async def _consume() -> None:
        await awaitable

    try:
        asyncio.run(_consume())
    except Exception:
        _logger.debug("LangSpend tracking failed", exc_info=True)


def _dispatch_tracking(langspend: Any, tracking_data: Any) -> None:
    """Send tracking data without blocking the caller, with or without a running loop."""
    outcome = langspend.track(tracking_data)
    if not inspect.isawaitable(outcome):
        return  # track() completed synchronously; nothing left to schedule

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Plain synchronous caller: asyncio.create_task would raise here, so
        # the awaitable is run on its own loop in a daemon thread instead.
        threading.Thread(
            target=_run_tracking_in_thread,
            args=(outcome,),
            name="langspend-track",
            daemon=True,
        ).start()
        return

    task = asyncio.ensure_future(outcome)
    _PENDING_TRACKING_TASKS.add(task)
    task.add_done_callback(_finalize_tracking_task)


def _track_usage(
    langspend: Any,
    payload: Any,
    usage_metadata: Any,
    model_name: Any,
    langspend_tags: Any,
    feature_context: Any,
    start_time: float,
    streaming: bool = False,
) -> None:
    """Build a TrackingData record for one response/chunk and dispatch it; never raises."""
    try:
        candidates = getattr(payload, 'candidates', None)
        first_candidate = candidates[0] if candidates else None
        safety_ratings = getattr(first_candidate, 'safety_ratings', None)
        customer = getattr(langspend_tags, 'customer', None) if langspend_tags else None
        feature = getattr(langspend_tags, 'feature', None) if langspend_tags else None

        metadata = {
            "customer": customer.__dict__ if customer else None,
            "feature": feature.__dict__ if feature else None,
            "session_id": getattr(langspend_tags, 'session_id', None) if langspend_tags else None,
            "environment": getattr(langspend_tags, 'environment', None) if langspend_tags else None,
            "finish_reason": getattr(first_candidate, 'finish_reason', None),
            "safety_ratings": [rating.__dict__ for rating in safety_ratings] if safety_ratings else None,
            "feature_name": feature_context.feature_name,
            "feature_source": feature_context.feature_source,
            "endpoint": feature_context.endpoint,
            "call_site": feature_context.call_site,
        }
        if streaming:
            metadata["streaming"] = True

        tracking_data = TrackingData(
            provider="google-vertex",
            model=model_name,
            input_tokens=usage_metadata.get('prompt_token_count', 0),
            output_tokens=usage_metadata.get('candidates_token_count', 0),
            cached_input_tokens=usage_metadata.get('cached_content_token_count', 0),
            metadata=metadata,
            tags=langspend_tags,
            # datetime (unlike time.strftime) supports %f, and the time is
            # taken in UTC so the trailing "Z" is accurate.
            timestamp=datetime.now(timezone.utc).strftime('%Y-%m-%dT%H:%M:%S.%fZ'),
            latency=(time.time() - start_time) * 1000,
        )
        _dispatch_tracking(langspend, tracking_data)
    except Exception:
        # Tracking is best-effort: a failure here must never affect the LLM result.
        _logger.debug("LangSpend tracking failed; the LLM response is unaffected", exc_info=True)


def wrap_vertex(client: Any, langspend: LangSpend) -> Any:
    """
    Wrap a Google Vertex AI client to track requests with LangSpend

    ``client.get_generative_model`` is replaced in place. Every model it
    returns has ``generate_content`` (and ``generate_content_stream``, when
    the model has one) replaced by a wrapper that:

    * removes the ``langspend_tags`` and ``req`` keyword arguments before
      calling the original method (``req`` is passed to feature detection as
      the request object);
    * reads ``response_metadata['usage_metadata']`` from the response and,
      when it is non-empty, sends a ``TrackingData`` record through
      ``langspend.track``;
    * never lets a tracking failure reach the caller - failures are logged
      at DEBUG level and the LLM result is returned unchanged.

    Tracking is scheduled on the running asyncio loop when there is one;
    otherwise it runs on a short-lived daemon thread, so the wrapper also
    works from plain synchronous code.

    For streams, usage is taken from the last chunk that carries usage
    metadata and is tracked once, when the stream ends or is closed; the
    reported latency therefore covers the whole stream.

    Args:
        client: Google Vertex AI client instance
        langspend: LangSpend client instance

    Returns:
        Wrapped Vertex AI client with tracking enabled (the same object
        that was passed in)

    Example:
        ```python
        from google.cloud import vertexai
        from langspend import LangSpend, wrap_vertex

        vertex = vertexai.VertexAI(project="your-project", location="us-central1")
        langspend = LangSpend(api_key="lsp_your_api_key")
        wrapped = wrap_vertex(vertex, langspend)

        # Use as normal - tracking happens automatically
        model = wrapped.get_generative_model(model="claude-3-5-sonnet-20241022")
        response = model.generate_content(
            contents=[{"role": "user", "parts": [{"text": "Hello!"}]}],
            langspend_tags={"customer_id": "user_123"}
        )
        ```

    Note:
        Tag fields are read as attributes (``customer``, ``feature``,
        ``session_id``, ``environment``). A plain dict, as in the example,
        is forwarded as ``tags`` but those metadata fields are recorded as
        ``None``.
    """
    # Store original get_generative_model method
    original_get_generative_model = client.get_generative_model

    def wrapped_get_generative_model(*model_args, **model_kwargs) -> Any:
        # Arguments are forwarded untouched so both the positional-dict form
        # and keyword forms such as get_generative_model(model="...") work.
        model = original_get_generative_model(*model_args, **model_kwargs)
        model_name = _resolve_model_name(model_args, model_kwargs)

        # Store original generate_content method
        original_generate_content = model.generate_content

        def wrapped_generate_content(*args, **kwargs):
            start_time = time.time()

            # LangSpend-only kwargs must not reach the underlying SDK call
            langspend_tags = kwargs.pop('langspend_tags', None)
            request = kwargs.pop('req', None)

            # Detect feature context
            feature_context = detect_feature_context(langspend_tags, request)

            # Make original LLM call
            response = original_generate_content(*args, **kwargs)

            usage_metadata = _extract_usage_metadata(response)
            if usage_metadata:
                _track_usage(
                    langspend, response, usage_metadata, model_name,
                    langspend_tags, feature_context, start_time,
                )

            return response

        # Replace the method
        model.generate_content = wrapped_generate_content

        # Handle streaming if available
        if hasattr(model, 'generate_content_stream'):
            original_generate_content_stream = model.generate_content_stream

            def wrapped_generate_content_stream(*args, **kwargs):
                start_time = time.time()

                # LangSpend-only kwargs must not reach the underlying SDK call
                langspend_tags = kwargs.pop('langspend_tags', None)
                request = kwargs.pop('req', None)

                # Detect feature context
                feature_context = detect_feature_context(langspend_tags, request)

                usage_chunk = None
                usage_metadata = {}
                try:
                    for chunk in original_generate_content_stream(*args, **kwargs):
                        chunk_usage = _extract_usage_metadata(chunk)
                        if chunk_usage:
                            # Remember the latest chunk that reports usage
                            usage_chunk, usage_metadata = chunk, chunk_usage
                        yield chunk
                finally:
                    # Runs on exhaustion, on error, and when the consumer
                    # closes the stream early; tracks whatever usage was seen.
                    if usage_chunk is not None:
                        _track_usage(
                            langspend, usage_chunk, usage_metadata, model_name,
                            langspend_tags, feature_context, start_time,
                            streaming=True,
                        )

            model.generate_content_stream = wrapped_generate_content_stream

        return model

    # Replace the method
    client.get_generative_model = wrapped_get_generative_model

    return client
