import ollama, os, argparse, threading, shutil, json, re
from cybertask import config
from cybertask.utils.ollama_models import ollama_models
from cybertask.utils.streaming_word_wrapper import StreamingWordWrapper
from cybertask.health_check import HealthCheck
# When `config` has no `currentMessages` attribute, set up and save the basic
# configuration before the rest of this module reads values from `config`.
# NOTE(review): presumably this is the case when this module is started on its
# own rather than from within the host application -- confirm.
if not hasattr(config, "currentMessages"):
    HealthCheck.setBasicConfig()
    HealthCheck.saveConfig()
    #print("Configurations updated!")
from prompt_toolkit.styles import Style
from prompt_toolkit import PromptSession
from prompt_toolkit.history import FileHistory
from prompt_toolkit.shortcuts import clear
from prompt_toolkit.completion import WordCompleter, FuzzyCompleter
from pathlib import Path

# prompt_toolkit style passed to every prompt opened by this module; both
# colours are read from `config`.
promptStyle = Style.from_dict({
    # User input (default text).
    "": config.terminalCommandEntryColor2,
    # Prompt.
    "indicator": config.terminalPromptIndicatorColor2,
})

class OllamaChat:
    """Terminal chat session with a model served through the `ollama` client.

    The constructor only checks that an `ollama` executable can be found on
    PATH and records the result in `self.runnable`; all the work happens in
    `run`.
    """

    def __init__(self):
        # `run` returns immediately unless the `ollama` executable is on PATH
        if shutil.which("ollama"):
            self.runnable = True
        else:
            print("Local LLM Server 'Ollama' not found! Install Ollama first!")
            print("Read https://ollama.com/.")
            self.runnable = False

    def run(self, prompt="", model="mistral") -> None:
        """Chat with `model` in a prompt/stream loop until the exit entry is typed.

        Args:
            prompt: optional first message. When non-empty it is handed to the
                first prompt as its default text together with
                `accept_default=True`.
            model: name of the Ollama model to chat with; it is pulled first
                when `ollama.list()` does not report it.

        Returns:
            None. Returns early, without chatting, when Ollama was not found
            at construction time or when the model could not be downloaded.
        """
        if not self.runnable:
            return None

        def checkModel(thisModel) -> bool:
            """Make sure `thisModel` is available, pulling it when it is not listed.

            Returns False only when pulling a model that is not in
            `ollama_models` raises `ollama.ResponseError`.
            """
            # NOTE(review): detection relies on the string form of
            # `ollama.list()` containing "'model': '<name>'" -- confirm this
            # still holds for the installed ollama client version.
            installedModels = str(ollama.list()).replace(":latest", "")
            if f"'model': '{thisModel}'" in installedModels:
                return True
            if thisModel in ollama_models:
                # known model: pull through the command line so that progress is displayed
                os.system(f"ollama pull {thisModel}")
                return True
            # unknown model: attempt the download through the client; handle error
            try:
                HealthCheck.print3(f"Downloading '{thisModel}' ...")
                ollama.pull(thisModel)
            except ollama.ResponseError as e:
                print('Error:', e.error)
                return False
            return True

        def extractImagePath(content) -> str:
            """Ask "gemma:2b" to pick an image path out of `content`.

            Returns the path only when it names an existing file that
            `HealthCheck.is_valid_image_file` accepts; returns "" in every
            other case, including when the helper model cannot be reached or
            its answer cannot be parsed.
            """
            promptPrefix = """Write a JSON with two keys "imagePath" and "queryAboutImage" from the following request. "imagePath" is the path of a given image. "queryAboutImage" is the query about the image. Remember, return the JSON ONLY, WITHOUT any extra comment or information:

"""
            try:
                response = ollama.chat(
                    model="gemma:2b",
                    messages=[
                        {
                            "role": "user",
                            "content": f"{promptPrefix}{content}",
                        },
                    ])
            except ollama.ResponseError as e:
                # This helper is called outside the chat loop's own error
                # handling, so a failing helper model must not end the chat:
                # report it and carry on without an image.
                print('Error:', e.error)
                return ""
            answer = response["message"]["content"]
            # the model may wrap its JSON in a fenced ```json block
            fenced = re.findall(r"```json(.*?)```", answer, re.DOTALL)
            if fenced:
                answer = fenced[0].strip()
            try:
                imagePath = json.loads(answer)["imagePath"]
                if imagePath and os.path.isfile(imagePath) and HealthCheck.is_valid_image_file(imagePath):
                    return imagePath
            except Exception:
                # Best effort: any malformed or unexpected answer means "no
                # image". `Exception` (not a bare except) keeps Ctrl+C and
                # SystemExit working.
                return ""
            return ""

        # check model
        if not checkModel(model):
            return None
        if model.startswith("llava"):
            # "gemma:2b" is the helper model used by extractImagePath
            checkModel("gemma:2b")

        # remember the chosen model as the default; save only when it changed
        previousModel = config.ollamaDefaultModel
        config.ollamaDefaultModel = model
        if not config.ollamaDefaultModel:
            config.ollamaDefaultModel = "mistral"
        if config.ollamaDefaultModel != previousModel:
            HealthCheck.saveConfig()

        # one prompt-history file per model
        historyFolder = os.path.join(HealthCheck.getFiles(), "history")
        Path(historyFolder).mkdir(parents=True, exist_ok=True)
        chat_history = os.path.join(historyFolder, f"ollama_{model}")
        chat_session = PromptSession(history=FileHistory(chat_history))

        HealthCheck.print2(f"\n{model.capitalize()} loaded!")

        exitHotkey = str(config.hotkey_exit).replace("'", "")
        if hasattr(config, "currentMessages"):
            # history: carry over the existing conversation, keeping only
            # well-formed entries with non-empty content.
            # NOTE(review): the last entry is skipped -- presumably it is the
            # request that launched this chat; confirm.
            messages = [
                i for i in config.currentMessages[:-1]
                if i.get("role") in ("system", "user", "assistant") and i.get("content")
            ]
            # bottom toolbar
            bottom_toolbar = f""" {exitHotkey} {config.exit_entry}"""
        else:
            messages = []
            # bottom toolbar; '.new' is offered only in this case
            bottom_toolbar = f""" {exitHotkey} {config.exit_entry} {str(config.hotkey_new).replace("'", "")} .new"""
            print("(To start a new chat, enter '.new')")
        print(f"(To exit, enter '{config.exit_entry}')\n")

        while True:
            if not prompt:
                prompt = HealthCheck.simplePrompt(style=promptStyle, promptSession=chat_session, bottom_toolbar=bottom_toolbar)
                if prompt and prompt not in (".new", config.exit_entry) and hasattr(config, "currentMessages"):
                    config.currentMessages.append({"content": prompt, "role": "user"})
            else:
                # an entry supplied by the caller is passed on as the accepted default
                prompt = HealthCheck.simplePrompt(style=promptStyle, promptSession=chat_session, bottom_toolbar=bottom_toolbar, default=prompt, accept_default=True)
            if prompt == config.exit_entry:
                break
            elif prompt == ".new" and not hasattr(config, "currentMessages"):
                clear()
                messages = []
                print("New chat started!")
            elif prompt := prompt.strip():
                streamingWordWrapper = StreamingWordWrapper()
                config.pagerContent = ""
                userMessage = {'role': 'user', 'content': prompt}
                if model.startswith("llava"):
                    # attach an image when the entry names a valid image file
                    imagePath = extractImagePath(prompt)
                    if imagePath:
                        userMessage['images'] = [imagePath]
                        HealthCheck.print3(f"Analyzing image: {imagePath}")
                messages.append(userMessage)
                try:
                    completion = ollama.chat(
                        model=model,
                        messages=messages,
                        stream=True,
                    )
                    # Create a new thread for the streaming task
                    streaming_event = threading.Event()
                    self.streaming_thread = threading.Thread(target=streamingWordWrapper.streamOutputs, args=(streaming_event, completion,))
                    # Start the streaming thread
                    self.streaming_thread.start()

                    # wait while text output is streaming; capture key combo 'ctrl+q' or 'ctrl+z' to stop the streaming
                    streamingWordWrapper.keyToStopStreaming(streaming_event)

                    # when streaming is done or when user press "ctrl+q"
                    self.streaming_thread.join()

                    # update messages
                    # NOTE(review): `config.new_chat_response` is presumably
                    # filled in by the streaming wrapper -- confirm.
                    messages.append({"role": "assistant", "content": config.new_chat_response})
                except ollama.ResponseError as e:
                    if hasattr(self, "streaming_thread"):
                        self.streaming_thread.join()
                    print('Error:', e.error)

            # the next iteration always asks the user for a new entry
            prompt = ""

        HealthCheck.print2(f"\n{model.capitalize()} closed!")
        if hasattr(config, "currentMessages"):
            HealthCheck.print2(f"Return back to {config.letMeDoItName} prompt ...")

# available cli: 'ollamachat', 'mistral', 'llama2', 'llama213b', 'llama270b', 'gemma2b', 'gemma7b', 'llava', 'phi', 'vicuna'

def mistral():
    """Launch the Ollama chat with the 'mistral' model preselected."""
    main(thisModel="mistral")

def llama2():
    """Launch the Ollama chat with the 'llama2' model preselected."""
    main(thisModel="llama2")

def llama213b():
    """Launch the Ollama chat with the 'llama2:13b' model preselected."""
    main(thisModel="llama2:13b")

def llama270b():
    """Launch the Ollama chat with the 'llama2:70b' model preselected."""
    main(thisModel="llama2:70b")

def codellama():
    """Launch the Ollama chat with the 'codellama' model preselected."""
    main(thisModel="codellama")

def gemma2b():
    """Launch the Ollama chat with the 'gemma:2b' model preselected."""
    main(thisModel="gemma:2b")

def gemma7b():
    """Launch the Ollama chat with the 'gemma:7b' model preselected."""
    main(thisModel="gemma:7b")

def llava():
    """Launch the Ollama chat with the 'llava' model preselected."""
    main(thisModel="llava")

def phi():
    """Launch the Ollama chat with the 'phi' model preselected."""
    main(thisModel="phi")

def vicuna():
    """Launch the Ollama chat with the 'vicuna' model preselected."""
    main(thisModel="vicuna")

def main(thisModel=""):
    """Parse the command line, settle on a model and start the chat.

    Args:
        thisModel: model name fixed by the caller. When empty, a `-m/--model`
            option is accepted and, if that is not given either, the user is
            asked to pick a model interactively (entering the exit entry
            there closes the program without starting a chat).

    The optional positional argument is used as the first chat entry.
    """
    # Create the parser
    parser = argparse.ArgumentParser(description="ollama chat cli options")
    # Add arguments
    parser.add_argument("default", nargs="?", default=None, help="default entry")
    if not thisModel:
        # the model option only makes sense when the caller did not fix the model
        parser.add_argument('-m', '--model', action='store', dest='model', help="specify language model with -m flag; default: mistral")
    # Parse arguments
    args = parser.parse_args()
    # Get options
    prompt = args.default.strip() if args.default else ""
    if thisModel:
        model = thisModel
    elif args.model and args.model.strip():
        model = args.model.strip()
    else:
        # no model given anywhere: let the user choose one, with completion
        # over the known model names and its own prompt history
        historyFolder = os.path.join(HealthCheck.getFiles(), "history")
        Path(historyFolder).mkdir(parents=True, exist_ok=True)
        model_history = os.path.join(historyFolder, "ollama_default")
        model_session = PromptSession(history=FileHistory(model_history))
        completer = FuzzyCompleter(WordCompleter(sorted(ollama_models), ignore_case=True))
        bottom_toolbar = f""" {str(config.hotkey_exit).replace("'", "")} {config.exit_entry}"""

        HealthCheck.print2("Ollama chat launched!")
        print("Select a model below:")
        print("Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.")
        model = HealthCheck.simplePrompt(style=promptStyle, promptSession=model_session, bottom_toolbar=bottom_toolbar, default=config.ollamaDefaultModel, completer=completer)
        if model and model.lower() == config.exit_entry:
            HealthCheck.print2("\nOllama chat closed!")
            return None

    if not model:
        # an empty selection falls back to the configured default model
        model = config.ollamaDefaultModel
    # Run chat bot
    OllamaChat().run(
        prompt=prompt,
        model=model,
    )
        

# Start the chat (no model preselected) when this file is executed as a script.
if __name__ == '__main__':
    main()