from typing import Any

import numpy as np
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor


class Segmentor(object):
    """Base class for segmentors that turn an image into mask probabilities.

    Subclasses implement ``__call__``. This base only stores the probability
    threshold ``pth`` and the ``device`` string, and provides a sigmoid helper
    for converting logits into probabilities.
    """

    __slots__ = ["_model", "_pth", "_device"]

    def __init__(self, pth: float = 0.5, device: str = "cuda") -> None:
        """Store the probability threshold and device.

        Args:
            pth: Probability threshold, exposed read-only via ``pth``. This
                class does not apply it itself.
            device: Device string for subclasses to use (e.g. ``"cuda"``).
        """
        self._pth = pth
        self._device = device

    @property
    def pth(self) -> float:
        """Probability threshold given at construction."""
        return self._pth

    def __call__(self, img: np.ndarray, *args: Any, **kwargs: Any) -> Any:
        """Segment ``img``; subclasses may accept additional prompt arguments.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError

    def _sigmoid(self, logits: np.ndarray) -> np.ndarray:
        """Return the element-wise logistic sigmoid of ``logits``.

        Uses ``e = exp(-|x|)``, which is always in ``[0, 1]``. The result is
        ``1 / (1 + e)`` for ``x >= 0`` and ``e / (1 + e)`` otherwise. This
        never overflows; the naive ``1 / (1 + exp(-x))`` overflows in ``exp``
        and warns for large negative ``x``. It also stays precise in both
        tails. For array inputs the floating dtype is kept.
        """
        x = np.asarray(logits)
        e = np.exp(-np.abs(x))
        return np.where(x >= 0, 1 / (1 + e), e / (1 + e))


class Sam2Segmentor(Segmentor):
    """Point-prompted segmentor backed by a pretrained SAM 2 image predictor.

    Each call encodes the image and decodes a single mask from point prompts.
    It returns sigmoid probabilities rather than a thresholded mask. ``pth``
    is stored but not applied here; presumably callers threshold with it.
    """

    def __init__(
        self,
        model_id: str = "facebook/sam2-hiera-large",
        pth: float = 0.5,
        device: str = "cuda",
    ) -> None:
        """Load the SAM 2 predictor on ``device``.

        Args:
            model_id: Model id forwarded to ``SAM2ImagePredictor.from_pretrained``.
            pth: Probability threshold, exposed via ``pth``.
            device: Torch device string, e.g. ``"cuda"``, ``"cuda:1"`` or
                ``"cpu"``. The model is built on this device.
        """
        super().__init__(pth=pth, device=device)
        # Forward the device to sam2's builder. Otherwise the builder falls
        # back to its own ``device`` default ("cuda" in current sam2
        # releases; confirm against the installed version).
        self._model: SAM2ImagePredictor = SAM2ImagePredictor.from_pretrained(
            model_id, device=device
        )

    def __call__(
        self, img: np.ndarray, input_points: np.ndarray, input_labels: np.ndarray
    ) -> np.ndarray:
        """Segment ``img`` from point prompts and return mask probabilities.

        Args:
            img: Image passed to ``SAM2ImagePredictor.set_image`` (presumably
                an HWC RGB array; confirm with callers).
            input_points: Forwarded as ``point_coords`` (presumably (N, 2)
                pixel coordinates).
            input_labels: Forwarded as ``point_labels`` (presumably 1 for
                foreground, 0 for background).

        Returns:
            Element-wise sigmoid of the logits of the first predicted mask.
            With ``multimask_output=False`` this is expected to be the only
            mask.
        """
        # autocast takes a device *type* ("cuda"/"cpu"). An indexed device
        # string such as "cuda:1" is rejected, so normalise it first.
        device_type = torch.device(self._device).type
        # Encode the image inside the same no-grad/bf16 context as the decoder
        # (as in the SAM 2 usage example), so the image encoder also runs
        # under autocast.
        with torch.inference_mode(), torch.autocast(device_type, dtype=torch.bfloat16):
            self._model.set_image(img)
            mask_logits, _, _ = self._model.predict(
                point_coords=input_points,
                point_labels=input_labels,
                multimask_output=False,
                return_logits=True,
            )
        return self._sigmoid(mask_logits[0])
