"""
AWS Bedrock wrapper for LangSpend tracking
"""

import asyncio
import inspect
import io
import json
import time
from typing import Any, Optional

from ..client import LangSpend
from ..types import LangSpendTags, FeatureContext, TrackingData
from ..detection import detect_endpoint, detect_callsite, normalize_feature_name


def detect_feature_context(langspend_tags: Optional[LangSpendTags] = None, request: Optional[Any] = None) -> FeatureContext:
    """
    Attribute the current LLM call to a product feature.

    Detection sources are consulted in priority order and the first match wins:

    1. ``langspend_tags.feature`` supplied by the caller  -> source "manual"
    2. an endpoint returned by ``detect_endpoint(request)`` -> source "endpoint"
    3. a call site returned by ``detect_callsite()``        -> source "function"
    4. nothing matched                                       -> source "uncategorized"

    Any exception raised while detecting is swallowed and the "uncategorized"
    context is returned instead, so attribution can never block the LLM call.
    """
    def _uncategorized() -> FeatureContext:
        return FeatureContext(
            feature_name="uncategorized",
            feature_source="uncategorized",
            endpoint=None,
            call_site=None,
        )

    try:
        if langspend_tags and langspend_tags.feature:
            return FeatureContext(
                feature_name=langspend_tags.feature.name,
                feature_source="manual",
                endpoint=None,
                call_site=None,
            )

        detected_endpoint = detect_endpoint(request)
        if detected_endpoint:
            return FeatureContext(
                feature_name=normalize_feature_name(detected_endpoint),
                feature_source="endpoint",
                endpoint=detected_endpoint,
                call_site=None,
            )

        site = detect_callsite()
        if site:
            location = f"{site.file}:{site.function_name}"
            return FeatureContext(
                feature_name=normalize_feature_name(site.function_name),
                feature_source="function",
                endpoint=None,
                call_site=location,
            )
    except Exception:
        # Attribution is best-effort: ignore the failure and fall through to "uncategorized".
        pass

    return _uncategorized()




def parse_bedrock_response(response_body: str) -> dict:
    """
    Parse response from Bedrock Claude models
    """
    try:
        parsed = json.loads(response_body)
        
        # Handle Claude 3.x response format
        if 'usage' in parsed:
            return {
                'input_tokens': parsed['usage'].get('input_tokens', 0),
                'output_tokens': parsed['usage'].get('output_tokens', 0),
                'stop_reason': parsed.get('stop_reason'),
            }
        
        # Handle Claude 2.x response format
        if 'completion' in parsed:
            # Claude 2.x doesn't always return token counts, estimate if missing
            input_tokens = parsed.get('prompt_token_count', 0)
            output_tokens = parsed.get('completion_token_count', 0)
            if output_tokens == 0:
                output_tokens = max(1, len(parsed['completion']) // 4)  # Rough estimate
            return {
                'input_tokens': input_tokens,
                'output_tokens': output_tokens,
                'stop_reason': parsed.get('stop_reason'),
            }
        
        return {'input_tokens': 0, 'output_tokens': 0, 'stop_reason': None}
    except Exception as e:
        print(f"Failed to parse Bedrock response: {e}")
        return {'input_tokens': 0, 'output_tokens': 0, 'stop_reason': None}


# Strong references to in-flight async tracking tasks. The event loop keeps only
# weak references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_PENDING_TRACKING_TASKS = set()


def _get_field(source: Any, key: str, default: Any) -> Any:
    """Read *key* from a mapping-like object (via ``.get``), else as an attribute."""
    getter = getattr(source, 'get', None)
    if callable(getter):
        return getter(key, default)
    return getattr(source, key, default)


def _read_response_body(response: Any) -> str:
    """
    Decode ``response.body`` to text without stealing it from the caller.

    A body exposing ``.read()`` can only be consumed once, so after reading it
    is replaced by an in-memory buffer (``BytesIO``/``StringIO``) holding the
    same payload. If the response object rejects the assignment, the stream
    stays drained. NOTE(review): callers relying on stream-specific methods
    beyond ``read()`` on the original body type may need adjusting.
    """
    body = response.body
    if not hasattr(body, 'read'):
        return body.decode('utf-8') if isinstance(body, (bytes, bytearray)) else str(body)

    raw = body.read()
    # Restore the payload before decoding, so the caller gets it back even if decoding fails.
    replacement = io.StringIO(raw) if isinstance(raw, str) else io.BytesIO(raw)
    try:
        response.body = replacement
    except (AttributeError, TypeError):
        pass  # Read-only response object: nothing more can be done.
    return raw if isinstance(raw, str) else bytes(raw).decode('utf-8')


def _dispatch_tracking(langspend: Any, tracking_data: Any) -> None:
    """
    Report *tracking_data* through LangSpend.

    - ``langspend.track`` is a coroutine function and an event loop is running:
      schedule it as a background task so the caller is not delayed.
    - ``langspend.track`` is a coroutine function but no loop is running:
      run it to completion with ``asyncio.run`` (blocks until done, on a fresh
      loop each call). NOTE(review): if ``track`` holds loop-bound resources,
      ``track_sync`` may be the better choice here — confirm.
    - Otherwise call ``langspend.track_sync`` directly.
    """
    if not inspect.iscoroutinefunction(langspend.track):
        langspend.track_sync(tracking_data)
        return

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        asyncio.run(langspend.track(tracking_data))
        return

    task = loop.create_task(langspend.track(tracking_data))
    _PENDING_TRACKING_TASKS.add(task)
    task.add_done_callback(_PENDING_TRACKING_TASKS.discard)


def wrap_bedrock(client: Any, langspend: LangSpend) -> Any:
    """
    Wrap an AWS Bedrock client so ``InvokeModelCommand`` calls are tracked.

    ``client.send`` is replaced in place; commands other than
    ``InvokeModelCommand`` pass straight through. The wrapper accepts extra
    keyword arguments that are removed before the real ``send`` is called:

    - ``langspend_tags``: tags attached to the tracking record and used for
      manual feature attribution (stripped for every command).
    - ``request``: inbound HTTP request used for endpoint-based feature
      detection (InvokeModel calls only).

    Tracking is best-effort. Exceptions from the underlying ``send`` propagate
    unchanged. Once the model call has succeeded, any failure while reading
    usage or reporting it is swallowed (printed when ``langspend.config.debug``
    is set), and the original response is returned.

    Args:
        client: Bedrock client instance exposing ``send(command, **kwargs)``.
        langspend: LangSpend client instance.

    Returns:
        The same ``client`` object, with ``send`` wrapped.
    """
    original_send = client.send

    def _debug_enabled() -> bool:
        """True when ``langspend.config.debug`` is set; tolerant of a missing config."""
        return bool(getattr(getattr(langspend, 'config', None), 'debug', False))

    def _record(command: Any, response: Any, langspend_tags: Any, feature_context: Any, latency: float) -> None:
        """Extract usage from *response* and report it; returns early when there is nothing to track."""
        if not getattr(response, 'body', None):
            return

        try:
            response_body = _read_response_body(response)
        except Exception as e:
            print(f"Failed to parse Bedrock response body: {e}")
            return

        usage = parse_bedrock_response(response_body)
        if usage['input_tokens'] == 0 and usage['output_tokens'] == 0:
            return

        tracking_data = TrackingData(
            provider="aws-bedrock",
            model=_get_field(getattr(command, 'input', None), 'modelId', 'unknown'),
            input_tokens=usage['input_tokens'],
            output_tokens=usage['output_tokens'],
            cached_input_tokens=0,  # cache usage fields are not extracted from the response here
            metadata={
                "region": _get_field(getattr(client, 'config', None), 'region', 'unknown'),
                "stop_reason": usage['stop_reason'],
                "feature_name": feature_context.feature_name,
                "feature_source": feature_context.feature_source,
                "endpoint": feature_context.endpoint,
                "call_site": feature_context.call_site,
            },
            tags=langspend_tags,
            timestamp=time.time(),
            latency=latency,
        )
        _dispatch_tracking(langspend, tracking_data)

    def wrapped_send(command: Any, **kwargs: Any) -> Any:
        """Drop-in replacement for ``client.send`` that tracks InvokeModel usage."""
        # LangSpend-only option: never forward it to the real client, whatever the command.
        langspend_tags = kwargs.pop('langspend_tags', None)

        # Only track InvokeModel commands
        if command.__class__.__name__ != 'InvokeModelCommand':
            return original_send(command, **kwargs)

        request = kwargs.pop('request', None)
        feature_context = detect_feature_context(langspend_tags, request)

        # Time only the model call itself (monotonic clock, seconds).
        started = time.perf_counter()
        response = original_send(command, **kwargs)
        latency = time.perf_counter() - started

        try:
            _record(command, response, langspend_tags, feature_context, latency)
        except Exception as e:
            # Tracking must never break a call that already succeeded.
            if _debug_enabled():
                print(f"LangSpend tracking failed: {e}")

        return response

    # Replace the send method
    client.send = wrapped_send

    return client
