from typing import List, Union, Tuple
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification


class Grader:
    """Score sentence pairs with a Hugging Face sequence-classification model.

    The tokenizer and model are loaded from ``model_name_or_path``. The model is
    put in eval mode on a single device, and is optionally wrapped in
    ``torch.nn.DataParallel`` when more than one GPU is used.
    """

    def __init__(
            self,
            model_name_or_path: Union[str, None] = None,
            use_fp16: bool = False,
            num_gpus: int = 1,
            main_gpu: int = 0,
    ) -> None:
        """Load the tokenizer/model and place the model on a device.

        Args:
            model_name_or_path: Hub id or local path passed to ``from_pretrained``.
                Required despite the ``None`` default (kept for compatibility).
            use_fp16: Cast the model to half precision. Ignored when CUDA is not
                available, because half precision is poorly supported on CPU.
            num_gpus: Number of GPUs to use. Values > 1 enable DataParallel. The
                value is capped at the number of visible CUDA devices.
            main_gpu: Index of the CUDA device that holds the model parameters
                and gathers DataParallel outputs.

        Raises:
            ValueError: If ``model_name_or_path`` is None.
        """
        if model_name_or_path is None:
            raise ValueError("model_name_or_path must be provided")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)

        if torch.cuda.is_available():
            # DataParallel requires the module's parameters to live on
            # device_ids[0], so main_gpu is placed first in the device list.
            self.device = torch.device(f'cuda:{main_gpu}')
            other_gpus = [i for i in range(torch.cuda.device_count()) if i != main_gpu]
            device_ids = ([main_gpu] + other_gpus)[:max(num_gpus, 1)]
        else:
            # No CUDA: run on CPU, without fp16 and without DataParallel.
            self.device = torch.device('cpu')
            device_ids = []
            use_fp16 = False

        if use_fp16:
            self.model.half()

        self.model = self.model.to(self.device)
        self.model.eval()
        # Number of GPUs actually in use (0 on CPU). compute_score uses it to
        # scale the batch size.
        self.num_gpus = len(device_ids)

        # DataParallel divides each batch across the listed GPUs.
        if self.num_gpus > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)

    @staticmethod
    def _normalize_pairs(sentence_pairs):
        """Return ``sentence_pairs`` as a list of pairs.

        A single pair (a list or tuple whose first element is a str) is wrapped
        in a one-element list. A tuple of pairs is converted to a list.

        Raises:
            TypeError: If ``sentence_pairs`` is neither a list nor a tuple.
        """
        if isinstance(sentence_pairs, tuple):
            sentence_pairs = list(sentence_pairs)
        if not isinstance(sentence_pairs, list):
            raise TypeError(
                f"sentence_pairs must be a list or tuple, got {type(sentence_pairs).__name__}"
            )
        if sentence_pairs and isinstance(sentence_pairs[0], str):
            return [sentence_pairs]
        return sentence_pairs

    @torch.no_grad()
    def compute_score(self, sentence_pairs: Union[List[Tuple[str, str]], Tuple[str, str]], batch_size: int = 256,
                      max_length: int = 512) -> Union[float, List[float]]:
        """Score one pair or a list of pairs.

        Args:
            sentence_pairs: A single ``(text_a, text_b)`` pair, or a list/tuple
                of such pairs.
            batch_size: Pairs per forward pass per GPU.
            max_length: Token limit passed to the tokenizer. Longer inputs are
                truncated.

        Returns:
            A list of float scores, one per flattened logit. When exactly one
            score is produced (including for a one-element list), the bare float
            is returned instead. An empty input returns ``[]``.

        Raises:
            TypeError: If ``sentence_pairs`` is not a list or tuple.
        """
        pairs = self._normalize_pairs(sentence_pairs)
        if not pairs:
            return []

        # DataParallel splits every batch across GPUs. Scaling the batch keeps
        # the per-GPU batch size at roughly ``batch_size``.
        if self.num_gpus > 0:
            batch_size = batch_size * self.num_gpus

        all_scores = []
        for start_index in tqdm(range(0, len(pairs), batch_size), desc="Compute Scores",
                                disable=len(pairs) < 128):
            batch = pairs[start_index:start_index + batch_size]
            inputs = self.tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors='pt',
                max_length=max_length,
            ).to(self.device)

            logits = self.model(**inputs, return_dict=True).logits
            # Flattened to 1-D. This presumes a single-logit head (num_labels == 1)
            # so that there is one score per pair. .float() undoes fp16 before
            # the values leave the device.
            all_scores.extend(logits.view(-1).float().cpu().tolist())

        if len(all_scores) == 1:
            return all_scores[0]
        return all_scores

