import whisper
from whisper_timestamped import transcribe_timestamped


# Checkpoint download URLs keyed by whisper model name (e.g. "tiny", "base",
# "large-v3"); the ".en" variants are the English-only checkpoints.
# NOTE(review): nothing in this file's visible code reads this table --
# whisper.load_model() is called with a name, not a URL. It appears to mirror
# the table shipped inside the openai-whisper package and is presumably kept
# for reference / manual downloads; confirm before relying on it.
# NOTE(review): the 64-hex path segment is presumably the SHA-256 of the
# checkpoint file -- verify before using it as a checksum.
_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    # "large" is an alias: same URL (same checkpoint) as "large-v3" above.
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}

# whisper records/tmp.wav --language zh --model tiny --model_dir ./models
# whisper records/tmp.wav --language zh --model base --model_dir ./models
# whisper records/tmp.wav --language zh --model small --model_dir ./models
# whisper records/tmp.wav --language zh --model medium --model_dir ./models
# whisper records/tmp.wav --language zh --model large --model_dir ./models


def whisper_transcribe(
    audio_path,
    download_root,
    model_size="base",
    target_lang="zh",
    *,
    beam_size=5,
    verbose=True,
):
    """Transcribe an audio file with openai-whisper and return the result.

    Previously the result was only printed and the function returned None,
    so callers could not use the transcript or the word timestamps it
    requests; it is now returned.

    Args:
        audio_path: Path of the audio file to transcribe.
        download_root: Directory passed to ``whisper.load_model`` as
            ``download_root`` (where the checkpoint is looked up/downloaded).
        model_size: Model name passed to ``whisper.load_model``
            (e.g. "tiny", "base", "small", "medium", "large").
        target_lang: Language code passed to ``transcribe`` as ``language``.
        beam_size: Beam width for decoding (default 5, unchanged from before).
        verbose: If true (default), print the raw result dict and the
            stripped transcript, exactly as the original script did.

    Returns:
        The dict returned by ``model.transcribe``. Its ``"text"`` entry holds
        the transcript; word-level timestamps are requested via
        ``word_timestamps=True``.
    """
    audio_model = whisper.load_model(model_size, download_root=download_root)
    result = audio_model.transcribe(
        audio_path,
        language=target_lang,
        beam_size=beam_size,
        word_timestamps=True,
        # Prompt each decoding window with the previously decoded text.
        condition_on_previous_text=True,
    )
    if verbose:
        print(result)
        print(result["text"].strip())
    return result


# whisper_timestamped records/tmp.wav --language zh --model tiny --model_dir ./models
# whisper_timestamped records/tmp.wav --language zh --model base --model_dir ./models
# whisper_timestamped records/tmp.wav --language zh --model small --model_dir ./models
# whisper_timestamped records/tmp.wav --language zh --model medium --model_dir ./models
# whisper_timestamped records/tmp.wav --language zh --model large --model_dir ./models


def whisper_transcribe_timestamped(
    audio_path,
    download_root,
    model_size="base",
    target_lang="zh",
    *,
    verbose=True,
):
    """Transcribe an audio file with whisper_timestamped and return the result.

    Previously the result was only printed and the function returned None;
    it is now returned so callers can use the transcript and timestamps.

    Args:
        audio_path: Path of the audio file to transcribe.
        download_root: Directory passed to ``whisper.load_model`` as
            ``download_root`` (where the checkpoint is looked up/downloaded).
        model_size: Model name passed to ``whisper.load_model``
            (e.g. "tiny", "base", "small", "medium", "large").
        target_lang: Language code passed to ``transcribe_timestamped`` as
            ``language``.
        verbose: If true (default), print the raw result dict and the
            stripped transcript, exactly as the original script did.

    Returns:
        The dict returned by ``transcribe_timestamped``; its ``"text"`` entry
        holds the transcript.
    """
    # whisper_timestamped drives a regular openai-whisper model instance.
    audio_model = whisper.load_model(model_size, download_root=download_root)
    result = transcribe_timestamped(audio_model, audio_path, language=target_lang)
    if verbose:
        print(result)
        print(result["text"].strip())
    return result


if __name__ == "__main__":
    # Command-line entry point: pick a backend and transcribe one file.
    import argparse

    parser = argparse.ArgumentParser(
        description="Transcribe an audio file with whisper or whisper_timestamped."
    )
    parser.add_argument(
        "--model_type",
        "-t",
        type=str,
        default="whisper",
        # Restrict to the two supported backends: without `choices`, a typo
        # (e.g. "whisper-timestamped") silently fell through to the plain
        # whisper branch below instead of being reported as an error.
        choices=["whisper", "whisper_timestamped"],
        help="choice whisper | whisper_timestamped",
    )
    parser.add_argument(
        "--audio_path", "-a", type=str, default="./records/tmp.wav", help="audio path"
    )
    # Not restricted with `choices`: any name whisper.load_model accepts works.
    parser.add_argument("--model_size", "-s", type=str, default="base", help="model size")
    parser.add_argument("--model_path", "-m", type=str, default="./models", help="model root path")
    parser.add_argument("--lang", "-l", type=str, default="zh", help="target language")
    args = parser.parse_args()

    if args.model_type == "whisper_timestamped":
        whisper_transcribe_timestamped(args.audio_path, args.model_path, args.model_size, args.lang)
    else:
        # Only "whisper" can reach here now that `choices` is enforced.
        whisper_transcribe(args.audio_path, args.model_path, args.model_size, args.lang)
