import uuid
import time
from datetime import datetime
from typing import Any, Optional, Iterator
from google import genai
from .storage import LocalStorage
from .remote_storage import RemoteStorage
from .types import TrackedRequest, TrackerOptions


class Client:
    """Wrapper around ``genai.Client`` that records generate-content calls.

    ``__init__`` builds a ``genai.Client`` (kept in ``self._client``) and
    replaces its ``models.generate_content`` and
    ``models.generate_content_stream`` attributes with tracking wrappers.
    Every call made through those methods produces one ``TrackedRequest``
    that is handed to the configured storage backend (``LocalStorage`` or
    ``RemoteStorage``). Any attribute not defined on this class is forwarded
    to the wrapped ``genai.Client``.

    Tracking is best-effort: a failure while building or storing a record is
    printed and never changes what the caller receives from the API.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        vertexai: bool = False,
        project: Optional[str] = None,
        location: Optional[str] = None,
        storage: str = "local",
        local_storage: Optional[dict] = None,
        remote_storage: Optional[dict] = None,
        **kwargs
    ):
        """Create the underlying ``genai.Client`` and install the tracking hooks.

        Args:
            api_key: Forwarded to ``genai.Client`` when truthy.
            vertexai: When true, ``vertexai=True`` plus ``project`` and
                ``location`` (when truthy) are forwarded to ``genai.Client``.
            project: Vertex AI project; ignored unless ``vertexai`` is true.
            location: Vertex AI location; ignored unless ``vertexai`` is true.
            storage: ``"remote"`` selects ``RemoteStorage``; any other value
                selects ``LocalStorage``.
            local_storage: Options passed to ``LocalStorage``.
            remote_storage: Options passed to ``RemoteStorage``; required when
                ``storage == "remote"``.
            **kwargs: Accepted but currently unused.

        Raises:
            ValueError: If ``storage == "remote"`` and ``remote_storage`` is empty.
        """
        # Falsy values are dropped so genai.Client applies its own defaults.
        # NOTE(review): **kwargs is not forwarded to genai.Client, so any extra
        # client options are silently dropped — confirm this is intended.
        client_kwargs = {k: v for k, v in {
            "api_key": api_key,
            "vertexai": vertexai or None,
            "project": project if vertexai else None,
            "location": location if vertexai else None
        }.items() if v}

        self._client = genai.Client(**client_kwargs)

        if storage == "remote":
            if not remote_storage:
                raise ValueError("remote_storage options required when storage mode is 'remote'")
            self._storage = RemoteStorage(remote_storage)
        else:
            self._storage = LocalStorage(local_storage)

        # Keep the unwrapped SDK methods so the wrappers can delegate to them.
        self._original_generate_content = self._client.models.generate_content
        self._original_generate_content_stream = self._client.models.generate_content_stream

        self._client.models.generate_content = self._wrapped_generate_content
        self._client.models.generate_content_stream = self._wrapped_generate_content_stream

    @staticmethod
    def _is_google_search_tool(tool: Any) -> bool:
        """Return True if ``tool`` (object or dict form) enables Google Search."""
        if isinstance(tool, dict):
            return tool.get("google_search") is not None
        # SDK Tool objects presumably expose every tool field (unset ones as
        # None), so attribute presence alone is not enough — check the value.
        return getattr(tool, "google_search", None) is not None

    def _has_search_grounding(self, config: Any, params: dict) -> bool:
        """Return True if the request enables the Google Search tool.

        Tools are read from ``config["tools"]`` (dict config) or
        ``config.tools`` (object config), falling back to a top-level
        ``tools`` keyword argument in ``params``.
        """
        tools = None
        if isinstance(config, dict):
            tools = config.get("tools")
        elif config:
            tools = getattr(config, "tools", None)
        if not tools:
            tools = params.get("tools")

        if not tools:
            return False

        return any(self._is_google_search_tool(tool) for tool in tools)

    def _extract_grounding_metadata(self, response: Any) -> bool:
        """Return True if the first candidate carries search-grounding data.

        Grounding counts as received when the first candidate has grounding
        metadata with a ``search_entry_point`` or non-empty
        ``grounding_chunks``.
        """
        candidates = getattr(response, "candidates", None)
        if not candidates:
            return False
        grounding = getattr(candidates[0], "grounding_metadata", None)
        if not grounding:
            return False
        return (
            getattr(grounding, "search_entry_point", None) is not None
            or bool(getattr(grounding, "grounding_chunks", None))
        )

    def _safe_storage_add(self, request: TrackedRequest, context: str) -> None:
        """Store ``request``; print (rather than raise) any storage error."""
        try:
            self._storage.add(request)
        except Exception as storage_error:
            print(f"Failed to save {context}: {storage_error}")

    def _record(self, context: str, request_info: dict, **outcome: Any) -> None:
        """Build a ``TrackedRequest`` and store it without ever raising.

        Args:
            context: Human-readable label used in failure messages.
            request_info: Keyword arguments describing the request
                (``request_id``, ``timestamp``, ``model``, ``contents``,
                ``search_grounding``, ``metadata``).
            **outcome: ``response``/``response_time``/``stream_chunks`` for a
                completed call, or ``error`` for a failed one.
        """
        try:
            tracked_request = self._build_tracked_request(**request_info, **outcome)
        except Exception as build_error:
            # A bug in tracking must not turn a successful API call into a failure.
            print(f"Failed to build {context}: {build_error}")
            return
        self._safe_storage_add(tracked_request, context)

    def _build_tracked_request(
        self,
        request_id: str,
        timestamp: datetime,
        model: str,
        contents: Any,
        search_grounding: bool,
        metadata: Any,
        response: Any = None,
        response_time: Optional[float] = None,
        error: Optional[Exception] = None,
        stream_chunks: Optional[list] = None,
    ) -> TrackedRequest:
        """Assemble the ``TrackedRequest`` dict for one call.

        When ``error`` is given the record stores its message and no response.
        Otherwise it stores the serialized ``response``, ``response_time`` (ms)
        and token usage. For streams, ``stream_chunks`` holds every received
        chunk; their text is merged into the serialized response so it is not
        limited to the final chunk.
        """
        tracked: TrackedRequest = {
            "id": request_id,
            "timestamp": timestamp,
            "model": model,
            "prompt": contents,
            "search_grounding": search_grounding,
            "metadata": metadata,
        }

        if error is not None:
            tracked["response"] = None
            tracked["error"] = str(error)
            tracked["received_search_grounded_response"] = False
        else:
            tracked["response_time"] = response_time
            serialized = self._serialize_response(response)
            if stream_chunks and len(stream_chunks) > 1:
                self._merge_stream_text(serialized, stream_chunks)
            tracked["response"] = serialized
            tracked["received_search_grounded_response"] = self._extract_grounding_metadata(response)
            usage_metadata = getattr(response, "usage_metadata", None)
            tracked["usage"] = {
                "promptTokens": usage_metadata.prompt_token_count if usage_metadata else None,
                "completionTokens": usage_metadata.candidates_token_count if usage_metadata else None,
                "totalTokens": usage_metadata.total_token_count if usage_metadata else None,
            }

        return tracked

    def _wrapped_generate_content(self, **params: Any) -> Any:
        """Tracking replacement for ``models.generate_content``.

        Accepts the same keyword arguments plus an optional ``metadata``
        value, which is removed before delegating and stored with the record.
        Returns the SDK response unchanged; SDK exceptions are recorded and
        re-raised.
        """
        request_id = str(uuid.uuid4())
        timestamp = datetime.now()
        start_time = time.perf_counter()

        metadata = params.pop("metadata", None)
        request_info = {
            "request_id": request_id,
            "timestamp": timestamp,
            "model": params.get("model", "unknown"),
            "contents": params.get("contents"),
            "search_grounding": self._has_search_grounding(params.get("config"), params),
            "metadata": metadata,
        }

        try:
            response = self._original_generate_content(**params)
        except Exception as error:
            self._record("tracked request with error", request_info, error=error)
            raise

        response_time = (time.perf_counter() - start_time) * 1000
        self._record(
            "tracked request", request_info,
            response=response, response_time=response_time,
        )
        return response

    @staticmethod
    def _serialize_candidate(candidate: Any) -> dict:
        """Convert one response candidate into a camelCase dict."""
        cand_dict: dict = {}
        if hasattr(candidate, "content"):
            content = candidate.content
            # content and parts may be None (e.g. a blocked candidate).
            parts = getattr(content, "parts", None) or []
            cand_dict["content"] = {
                "parts": [{"text": part.text} for part in parts if hasattr(part, "text")],
                "role": getattr(content, "role", None),
            }
        if hasattr(candidate, "finish_reason"):
            reason = candidate.finish_reason
            cand_dict["finishReason"] = (
                None if reason is None else str(reason).replace("FinishReason.", "")
            )
        if hasattr(candidate, "grounding_metadata"):
            # Only a placeholder is recorded; grounding details are not serialized.
            cand_dict["groundingMetadata"] = {}
        if hasattr(candidate, "index"):
            cand_dict["index"] = candidate.index
        return cand_dict

    def _serialize_response(self, response: Any) -> dict:
        """Convert an SDK response (or ``None``) into a camelCase dict.

        Fields that are missing or ``None`` on the response are skipped or
        stored as empty/``None`` values instead of raising.
        """
        result: dict = {}

        if hasattr(response, "sdk_http_response"):
            headers = getattr(response.sdk_http_response, "headers", None)
            result["sdkHttpResponse"] = {"headers": dict(headers) if headers else {}}

        if hasattr(response, "candidates"):
            result["candidates"] = [
                self._serialize_candidate(candidate)
                for candidate in response.candidates or []
            ]

        if hasattr(response, "model_version"):
            result["modelVersion"] = response.model_version
        if hasattr(response, "response_id"):
            result["responseId"] = response.response_id

        if hasattr(response, "usage_metadata"):
            um = response.usage_metadata
            result["usageMetadata"] = {
                "promptTokenCount": getattr(um, "prompt_token_count", None),
                "candidatesTokenCount": getattr(um, "candidates_token_count", None),
                "totalTokenCount": getattr(um, "total_token_count", None),
            }
            if hasattr(um, "prompt_tokens_details"):
                result["usageMetadata"]["promptTokensDetails"] = [
                    {"modality": str(d.modality).replace("MediaModality.", ""), "tokenCount": d.token_count}
                    for d in um.prompt_tokens_details or []
                ]
            if hasattr(um, "thoughts_token_count"):
                result["usageMetadata"]["thoughtsTokenCount"] = um.thoughts_token_count

        return result

    @staticmethod
    def _merge_stream_text(serialized: dict, chunks: list) -> None:
        """Replace candidate text in ``serialized`` with the whole stream's text.

        ``serialized`` is built from the final chunk only, so its parts hold
        just that chunk's text. The text parts of every chunk are concatenated
        per candidate position and written back as a single text part.
        """
        fragments: list = []  # fragments[position] -> list of text pieces
        roles: list = []      # roles[position] -> first role seen
        for chunk in chunks:
            for position, candidate in enumerate(getattr(chunk, "candidates", None) or []):
                while len(fragments) <= position:
                    fragments.append([])
                    roles.append(None)
                content = getattr(candidate, "content", None)
                if content is None:
                    continue
                if roles[position] is None:
                    roles[position] = getattr(content, "role", None)
                for part in getattr(content, "parts", None) or []:
                    text = getattr(part, "text", None)
                    if text:
                        fragments[position].append(text)

        if not any(fragments):
            return

        candidates = serialized.setdefault("candidates", [])
        for position, pieces in enumerate(fragments):
            if not pieces:
                continue
            while len(candidates) <= position:
                candidates.append({})
            candidates[position]["content"] = {
                "parts": [{"text": "".join(pieces)}],
                "role": roles[position],
            }

    def _wrapped_generate_content_stream(self, **params: Any) -> Iterator[Any]:
        """Tracking replacement for ``models.generate_content_stream``.

        Yields the SDK chunks unchanged. One record is stored when the stream
        ends: a success record (with text merged across chunks) when it
        finishes, or when the consumer closes it early; an error record, with
        the exception re-raised, if the SDK raises.
        """
        request_id = str(uuid.uuid4())
        timestamp = datetime.now()
        start_time = time.perf_counter()

        metadata = params.pop("metadata", None)
        request_info = {
            "request_id": request_id,
            "timestamp": timestamp,
            "model": params.get("model", "unknown"),
            "contents": params.get("contents"),
            "search_grounding": self._has_search_grounding(params.get("config"), params),
            "metadata": metadata,
        }
        chunks: list = []

        def record_stream(context: str) -> None:
            # The final chunk supplies usage/grounding; text comes from all chunks.
            self._record(
                context, request_info,
                response=chunks[-1] if chunks else None,
                response_time=(time.perf_counter() - start_time) * 1000,
                stream_chunks=chunks,
            )

        try:
            for chunk in self._original_generate_content_stream(**params):
                chunks.append(chunk)
                yield chunk
        except GeneratorExit:
            # The consumer stopped iterating early (break / close()); still
            # record what was received before propagating the close.
            record_stream("partially consumed streamed tracked request")
            raise
        except Exception as error:
            self._record("streamed tracked request with error", request_info, error=error)
            raise

        record_stream("streamed tracked request")

    def __getattr__(self, name: str) -> Any:
        """Forward attributes not defined here to the wrapped ``genai.Client``."""
        # __getattr__ only runs when normal lookup fails. If ``_client`` itself
        # is missing (e.g. __init__ raised, or the instance was created via
        # __new__), ``self._client`` would re-enter __getattr__ forever.
        try:
            client = self.__dict__["_client"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(client, name)

    def get_tracked_requests(self) -> list[TrackedRequest]:
        """Return all tracked requests held by the storage backend."""
        return self._storage.get_all()

    def load_history(self) -> None:
        """Load stored history if the storage backend supports ``load``."""
        if hasattr(self._storage, "load"):
            self._storage.load()

    def save_history(self) -> None:
        """Persist history if the storage backend supports ``save``."""
        if hasattr(self._storage, "save"):
            self._storage.save()

    def clear_history(self) -> None:
        """Remove all tracked requests from the storage backend."""
        self._storage.clear()
