"""Databricks Guard implementation."""
import asyncio
import logging
from dataclasses import dataclass, field

from openai import OpenAI
from virtueai.models import VirtueAIModel, DatabricksDbModel, VirtueAIResponseStatus, VirtueAIResponse
from virtueai.models import SafetyModel

from .together_safety_model import TogetherSafetyModel

@dataclass
class GuardDatabricksConfig:
    databricks_api_key: str
    databricks_url: str
    database_db_model: DatabricksDbModel
    safety_model: VirtueAIModel
    together_api_key: str | None = None
    max_tokens: int = 256

class GuardDatabricks:
    """Wrap a Databricks-hosted chat model with input and output safety checks.

    Calling an instance safety-checks the user messages, forwards the
    conversation to the Databricks model, safety-checks the reply, and returns
    a ``VirtueAIResponse`` describing the outcome.

    The guard fails closed: if no safety model is configured, or the safety
    model raises, the check counts as failed and the request is rejected.
    """

    def __init__(self, config: "GuardDatabricksConfig"):
        """Build the Databricks client and, when a Together key is set, the safety model.

        Args:
            config: Connection and model settings.

        Raises:
            Exception: Whatever the OpenAI client constructor raises; it is
                logged and re-raised unchanged.
        """
        self.config = config
        # Log only non-secret settings: the config also carries API keys,
        # which must never reach log output.
        logging.info(
            "Initializing GuardDatabricks (url=%s, model=%s, safety_model=%s, max_tokens=%s)",
            config.databricks_url,
            config.database_db_model,
            config.safety_model,
            config.max_tokens,
        )
        try:
            self.databricks_client = OpenAI(
                api_key=self.config.databricks_api_key,
                base_url=self.config.databricks_url,
            )
        except Exception as e:
            logging.error(f"Error initializing Databricks client: {e}")
            raise

        self.safety_model: SafetyModel | None = None
        if self.config.together_api_key:
            self.safety_model = TogetherSafetyModel(
                api_key=self.config.together_api_key,
                safety_model=self.config.safety_model,
            )
        else:
            # Without a checker every safety check returns False, so make the
            # consequence visible at start-up rather than on the first request.
            logging.warning(
                "No together_api_key configured: safety checks will fail and every request will be rejected"
            )

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten a chat message ``content`` value into plain text.

        Strings pass through unchanged, ``None`` becomes ``""``, and a list of
        content parts is reduced to the space-joined ``text`` of its dict parts
        whose ``type`` is ``"text"``. Any other value is converted with ``str``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return " ".join(
                part.get("text") or ""
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        return str(content)

    @staticmethod
    def _user_text(messages) -> str:
        """Return the space-joined text of every message whose role is ``"user"``."""
        return " ".join(
            GuardDatabricks._content_to_text(message.get("content"))
            for message in messages
            if message.get("role") == "user"
        )

    def databricks_chat(self, messages: list[dict]) -> str:
        """Send ``messages`` to the Databricks model and return the reply text.

        This call is synchronous and blocks until the completion returns.

        Args:
            messages: Chat messages, forwarded as-is to the completions API.

        Returns:
            The first choice's message content as a string. List-style content
            is reduced to its text parts; missing content yields ``""``.

        Raises:
            Exception: Any error from the client call or from reading the
                response; it is logged and re-raised unchanged.
        """
        try:
            completion = self.databricks_client.chat.completions.create(
                model=self.config.database_db_model.value,
                messages=messages,
                max_tokens=self.config.max_tokens,
            )
            # Some models return a list of content parts instead of a string.
            return self._content_to_text(completion.choices[0].message.content)
        except Exception as e:
            logging.error(f"Error calling Databricks model: {e}")
            raise

    async def __safety_check(self, query: str) -> bool:
        """Return the safety model's verdict for ``query``.

        A truthy result lets the caller proceed. Returns ``False`` when no
        safety model is configured or when the safety model raises (the error
        is logged, not propagated).
        """
        if not self.safety_model:
            return False
        try:
            return await self.safety_model(query)
        except Exception as e:
            logging.error(f"Error calling Safety model: {e}")
            return False

    async def __call__(
        self,
        messages,
    ) -> "VirtueAIResponse":
        """Run the guarded chat pipeline for ``messages``.

        Args:
            messages: Chat messages as dicts with ``role`` and ``content``.

        Returns:
            An ``UNSAFE`` response with a refusal message if the user input or
            the model output fails the safety check; otherwise a ``SUCCESS``
            response carrying the model output as ``validated_output``.

        Raises:
            Exception: Any error raised by ``databricks_chat``.
        """
        user_content = self._user_text(messages)

        if not await self.__safety_check(user_content):
            return VirtueAIResponse(status=VirtueAIResponseStatus.UNSAFE, message="Sorry, I can't help with that.")

        # The Databricks client is synchronous; run it in a worker thread so
        # the event loop stays responsive while the request is in flight.
        assistant_output = await asyncio.to_thread(self.databricks_chat, messages)

        # Safety-check the model's reply, treating it as a standalone message.
        if not await self.__safety_check(assistant_output):
            return VirtueAIResponse(status=VirtueAIResponseStatus.UNSAFE, message="Sorry, I can't help with that.")

        return VirtueAIResponse(status=VirtueAIResponseStatus.SUCCESS, validated_output=assistant_output)
