import argparse
import asyncio
from prompt_toolkit import PromptSession
from prompt_toolkit.history import InMemoryHistory
from virtueai.guard import GuardDatabricks, GuardDatabricksConfig
from virtueai.models import VirtueAIModel, DatabricksDbModel


async def chat():
    """Yield each non-blank line the user types at an interactive prompt.

    Prints a banner, then repeatedly prompts with ``"You: "`` using a
    ``PromptSession`` that keeps an in-memory input history for the session.
    The loop ends, after printing a farewell, when:

    * the user types ``exit``, ``quit`` or ``q``. Matching is case-insensitive
      and ignores surrounding whitespace, so ``"exit "`` ends the session
      instead of being forwarded as a chat message.
    * the prompt raises ``KeyboardInterrupt`` or ``EOFError``. prompt_toolkit
      raises these on Ctrl-C and Ctrl-D.

    Yields:
        str: each non-blank line, exactly as typed (not stripped).
    """
    exit_commands = {"exit", "quit", "q"}
    session = PromptSession(history=InMemoryHistory())

    print("Interactive Chat (type 'exit' or 'quit' to stop)\n")

    while True:
        # Only the prompt call raises KeyboardInterrupt/EOFError, so keep the
        # try body limited to it rather than wrapping the yield as well.
        try:
            user_input = await session.prompt_async("You: ")
        except (KeyboardInterrupt, EOFError):
            print("\nGoodbye!")
            break

        command = user_input.strip()

        if command.lower() in exit_commands:
            print("Goodbye!")
            break

        # Whitespace-only input is ignored rather than sent for processing.
        if not command:
            continue

        # Hand the raw input to the consumer for processing.
        yield user_input


async def main(databricks_api_key: str, databricks_url: str, together_api_key: str):
    """Run the interactive chat loop, routing every message through the guard.

    Builds a ``GuardDatabricks`` client for the Llama 3.1 8B Instruct
    Databricks model with the Virtue Guard Text Lite safety model. For each
    line produced by ``chat()`` it sends a standalone one-message conversation
    (earlier turns are not included) and prints the reply.

    Args:
        databricks_api_key: credential passed through to the guard config.
        databricks_url: Databricks endpoint URL passed through to the config.
        together_api_key: credential passed through to the guard config.
    """
    guard = GuardDatabricks(
        GuardDatabricksConfig(
            databricks_api_key=databricks_api_key,
            databricks_url=databricks_url,
            database_db_model=DatabricksDbModel.META_LLAMA_3_1_8B_INSTRUCT,
            safety_model=VirtueAIModel.VIRTUE_GUARD_TEXT_LITE,
            together_api_key=together_api_key,
        )
    )

    async for prompt_text in chat():
        result = await guard([{"role": "user", "content": prompt_text}])
        # Show the validated output when it is truthy; otherwise fall back
        # to the response's message.
        reply = result.validated_output or result.message
        print(f"Assistant: {reply}\n")


def main_cli() -> None:
    """Parse command-line options and run the interactive guard demo.

    Each credential can be given as a flag or through the environment
    variable named in its help text. The flag wins when both are set. A
    flag is required only when its environment variable is unset or empty.
    Reading secrets from the environment keeps them out of shell history
    and out of the process list, where command-line arguments are visible
    to other local users.
    """
    import os

    parser = argparse.ArgumentParser(description="Interactive Databricks Guard Demo")

    # (flag, fallback environment variable, help description)
    credential_options = (
        ("--databricks-api-key", "DATABRICKS_API_KEY", "Databricks API key"),
        ("--databricks-url", "DATABRICKS_URL", "Databricks URL"),
        ("--together-api-key", "TOGETHER_API_KEY", "Together API key"),
    )
    for flag, env_var, description in credential_options:
        # Treat an empty variable the same as an unset one.
        env_value = os.environ.get(env_var) or None
        parser.add_argument(
            flag,
            default=env_value,
            required=env_value is None,
            help=f"{description} (or set ${env_var})",
        )

    args = parser.parse_args()

    asyncio.run(main(
        databricks_api_key=args.databricks_api_key,
        databricks_url=args.databricks_url,
        together_api_key=args.together_api_key
    ))


if __name__ == "__main__":
    # Entry point when executed as a script; nothing runs on import.
    main_cli()