import logging
import time
import ctranslate2
import torch
import transformers
from dataclasses import dataclass, field
import huggingface_hub
from whisperlivekit.translation.mapping_languages import get_nllb_code
from whisperlivekit.timed_objects import Translation

logger = logging.getLogger(__name__)

# Idea for the diarization case: we may want to translate only one speaker, or
# at least make sentences start at speaker changes.

# Silence duration, in seconds, from which OnlineTranslation.insert_silence
# clears the token buffer instead of keeping it as context, even if the
# previous sentence is not finished.
MIN_SILENCE_DURATION_DEL_BUFFER = 3

@dataclass
class TranslationModel():
    """A loaded NLLB translation model plus a per-language tokenizer cache.

    Attributes:
        translator: the ``ctranslate2.Translator`` when ``backend_type`` is
            ``'ctranslate2'``, or the transformers seq2seq model when it is
            ``'transformers'`` (see ``load_model``).
        device: device string inputs are moved to ("cuda" or "cpu").
        tokenizer: cache of tokenizers keyed by NLLB source-language code.
        backend_type: ``'ctranslate2'`` or ``'transformers'``.
        model_size: suffix of ``facebook/nllb-200-distilled-<size>``, e.g. '600M'.
    """
    # Annotated as a ctranslate2 Translator, but holds a transformers model
    # for the 'transformers' backend.
    translator: ctranslate2.Translator
    device: str
    tokenizer: dict = field(default_factory=dict)
    backend_type: str = 'ctranslate2'
    model_size: str = '600M'

    def get_tokenizer(self, input_lang):
        """Return (and cache) the NLLB tokenizer for source language ``input_lang``.

        ``input_lang`` is a short language code (e.g. ``"fr"``). It is mapped
        with ``get_nllb_code``, the same mapping used for the target language
        in ``OnlineTranslation.translate``. If the mapping yields nothing
        (falsy), ``input_lang`` is used as-is, so already-converted NLLB
        codes still work.
        """
        # The NLLB tokenizer prepends `src_lang` as the source-language tag.
        # A bare code such as "fr" is not an NLLB language token, so it must
        # be mapped first or the model receives a wrong source tag.
        src_lang = get_nllb_code(input_lang) or input_lang
        if src_lang not in self.tokenizer:
            self.tokenizer[src_lang] = transformers.AutoTokenizer.from_pretrained(
                f"facebook/nllb-200-distilled-{self.model_size}",
                src_lang=src_lang,
                clean_up_tokenization_spaces=True,
            )
        return self.tokenizer[src_lang]
            

def load_model(src_langs, backend='ctranslate2', model_size='600M'):
    """Load an NLLB-200 distilled model and pre-load its tokenizers.

    Args:
        src_langs: source-language codes whose tokenizers are loaded up front.
            ``'auto'`` entries are skipped; tokenizers for detected languages
            are loaded lazily by ``TranslationModel.get_tokenizer``.
        backend: ``'ctranslate2'`` downloads
            ``entai2965/nllb-200-distilled-<size>-ctranslate2`` into the local
            directory ``nllb-200-distilled-<size>-ctranslate2``.
            ``'transformers'`` loads ``facebook/nllb-200-distilled-<size>``.
        model_size: NLLB distilled size suffix, e.g. ``'600M'``.

    Returns:
        A ``TranslationModel`` placed on CUDA when available, else CPU.

    Raises:
        ValueError: if ``backend`` is not one of the supported backends.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if backend == 'ctranslate2':
        ct2_model_dir = f'nllb-200-distilled-{model_size}-ctranslate2'
        huggingface_hub.snapshot_download('entai2965/' + ct2_model_dir, local_dir=ct2_model_dir)
        translator = ctranslate2.Translator(ct2_model_dir, device=device)
    elif backend == 'transformers':
        translator = transformers.AutoModelForSeq2SeqLM.from_pretrained(
            f"facebook/nllb-200-distilled-{model_size}"
        )
        # OnlineTranslation.translate moves tokenized inputs to `device`,
        # so the model has to live on the same device.
        translator = translator.to(device)
    else:
        raise ValueError(
            f"Unsupported translation backend {backend!r}; "
            "expected 'ctranslate2' or 'transformers'."
        )

    translation_model = TranslationModel(
        translator=translator,
        backend_type=backend,
        device=device,
        model_size=model_size,
    )
    # Pre-load tokenizers through the model's own cache rather than from the
    # local ctranslate2 directory, which does not exist for the transformers
    # backend.
    for src_lang in src_langs:
        if src_lang != 'auto':
            translation_model.get_tokenizer(src_lang)
    return translation_model

class OnlineTranslation:
    """Incrementally translate a stream of transcribed tokens.

    Tokens are accumulated in ``self.buffer``. On each ``process`` call,
    every span ending with a punctuation token is translated and appended to
    ``self.validated``. The unfinished tail stays in the buffer, and its
    translation is kept in ``self.translation_remaining``.

    Only the first entries of ``input_languages`` / ``output_languages`` are
    currently used.
    """

    def __init__(self, translation_model: TranslationModel, input_languages: list, output_languages: list):
        self.buffer = []
        # Prefix committed by compute_common_prefix.
        self.commited = []
        self.len_processed_buffer = 0
        self.translation_remaining = Translation()
        self.validated = []
        self.translation_pending_validation = ''
        self.translation_model = translation_model
        self.input_languages = input_languages
        self.output_languages = output_languages

    def compute_common_prefix(self, results):
        """Commit the prefix shared by the current buffer and ``results``.

        At the first position where the buffer and ``results`` differ, the
        agreeing part of the buffer is appended to ``self.commited``. The
        buffer is then replaced by the diverging tail of ``results``. If one
        sequence is a prefix of the other, nothing changes. The result is not
        pruned for the moment.
        """
        if not self.buffer:
            self.buffer = results
            return
        for i, (previous, current) in enumerate(zip(self.buffer, results)):
            if previous != current:
                self.commited.extend(self.buffer[:i])
                self.buffer = results[i:]
                # Stop at the first divergence: after rebinding the buffer,
                # later indices no longer line up with `results`.
                break

    def translate(self, input, input_lang, output_lang):
        """Translate the string ``input`` from ``input_lang`` to ``output_lang``.

        Returns the translated text, or ``""`` when ``input`` is empty.
        """
        if not input:
            return ""
        nllb_output_lang = get_nllb_code(output_lang)

        tokenizer = self.translation_model.get_tokenizer(input_lang)
        tokenizer_output = tokenizer(input, return_tensors="pt").to(self.translation_model.device)

        if self.translation_model.backend_type == 'ctranslate2':
            # ctranslate2 consumes token strings rather than ids.
            source = tokenizer.convert_ids_to_tokens(tokenizer_output['input_ids'][0])
            # The target-language code is forced as the decoder prefix...
            results = self.translation_model.translator.translate_batch([source], target_prefix=[[nllb_output_lang]])
            # ...and dropped again before decoding.
            target = results[0].hypotheses[0][1:]
            return tokenizer.decode(tokenizer.convert_tokens_to_ids(target))

        translated_tokens = self.translation_model.translator.generate(
            **tokenizer_output,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(nllb_output_lang),
        )
        return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]

    def translate_tokens(self, tokens):
        """Translate a run of tokens as a single text span.

        Returns a ``Translation`` covering ``tokens[0].start`` to
        ``tokens[-1].end``, or ``None`` when ``tokens`` is empty.
        """
        if not tokens:
            return None
        text = ' '.join(token.text for token in tokens)
        if self.input_languages[0] == 'auto':
            # Use the language detected for the first token of the span.
            input_lang = tokens[0].detected_language
        else:
            input_lang = self.input_languages[0]
        translated_text = self.translate(text, input_lang, self.output_languages[0])
        return Translation(
            text=translated_text,
            start=tokens[0].start,
            end=tokens[-1].end,
        )

    def insert_tokens(self, tokens):
        """Append newly transcribed tokens to the pending buffer."""
        self.buffer.extend(tokens)

    def process(self):
        """Translate new buffer content and return all translations so far.

        Each span ending in a punctuation token is translated and moved to
        ``self.validated``. The unfinished tail remains buffered, and its
        translation replaces ``self.translation_remaining``.

        Returns ``self.validated + [self.translation_remaining]``.
        """
        # Re-translate only once at least 3 new tokens have arrived.
        if len(self.buffer) < self.len_processed_buffer + 3:
            return self.validated + [self.translation_remaining]

        sentence_start = 0
        for i, token in enumerate(self.buffer):
            if token.is_punctuation():
                self.validated.append(self.translate_tokens(self.buffer[sentence_start:i + 1]))
                sentence_start = i + 1
        self.buffer = self.buffer[sentence_start:]

        remaining = self.translate_tokens(self.buffer)
        # An empty tail yields an empty Translation rather than None, so
        # callers always receive Translation objects.
        self.translation_remaining = remaining if remaining is not None else Translation()
        self.len_processed_buffer = len(self.buffer)
        return self.validated + [self.translation_remaining]

    def insert_silence(self, silence_duration: float):
        """Flush the pending buffer after a long enough silence.

        When ``silence_duration`` reaches ``MIN_SILENCE_DURATION_DEL_BUFFER``
        and tokens are pending, the current partial translation is promoted
        to ``self.validated``. The buffer is then cleared so the next
        sentence starts without the old context.
        """
        if silence_duration < MIN_SILENCE_DURATION_DEL_BUFFER:
            return
        if self.buffer:
            self.validated.append(self.translation_remaining)
        self.buffer = []
        # Reset per-buffer state. Otherwise the flushed translation would be
        # returned again by process() (and re-flushed on the next silence),
        # and new tokens would have to outgrow the old buffer length before
        # being processed.
        self.translation_remaining = Translation()
        self.len_processed_buffer = 0

if __name__ == '__main__':
    # Manual smoke test: load the ctranslate2 NLLB model and translate one
    # short French snippet into English.
    demo_src, demo_tgt = "fr", 'en'
    sample = " J 'ai  tent"

    model = load_model([demo_src], backend='ctranslate2')
    streamer = OnlineTranslation(
        model,
        input_languages=[demo_src],
        output_languages=[demo_tgt],
    )
    print(streamer.translate(sample, demo_src, demo_tgt))
    