import argparse
import asyncio
from virtueai import GuardDatabricks, GuardDatabricksConfig, VirtueAIModel, DatabricksDbModel, VirtueAIResponseStatus, VirtueAIResponse


def test_databricks_guard(databricks_api_key: str, databricks_url: str, together_api_key: str):
    """Smoke-test a ``GuardDatabricks`` instance with one benign chat message.

    Builds a ``GuardDatabricksConfig`` from the supplied credentials, wraps it
    in a ``GuardDatabricks``, sends a single user message through the guard,
    and asserts that the returned status is ``VirtueAIResponseStatus.SUCCESS``.
    The guard object and the validated output are printed for inspection.

    Args:
        databricks_api_key: Passed to the config as ``databricks_api_key``.
        databricks_url: Passed to the config as ``databricks_url``.
        together_api_key: Passed to the config as ``together_api_key``.

    Raises:
        AssertionError: If the response status is not ``SUCCESS``.
    """
    chat_messages = [{"role": "user", "content": "hey how are you?"}]

    databricks_guard: GuardDatabricks = GuardDatabricks(
        GuardDatabricksConfig(
            databricks_api_key=databricks_api_key,
            databricks_url=databricks_url,
            database_db_model=DatabricksDbModel.META_LLAMA_3_1_8B_INSTRUCT,
            safety_model=VirtueAIModel.VIRTUE_GUARD_TEXT_LITE,
            together_api_key=together_api_key,
        )
    )
    print(databricks_guard)

    # The guard call is driven to completion with asyncio.run so this
    # function can stay synchronous.
    pending_call = databricks_guard(chat_messages)
    result: VirtueAIResponse = asyncio.run(pending_call)

    assert result.status == VirtueAIResponseStatus.SUCCESS
    print(f"Response: {result.validated_output}")

if __name__ == "__main__":
    # Command-line entry point: every credential/endpoint is a mandatory flag.
    parser = argparse.ArgumentParser(description="Test Databricks Guard")
    for flag, help_text in (
        ("--databricks-api-key", "Databricks API key"),
        ("--databricks-url", "Databricks URL"),
        ("--together-api-key", "Together API key"),
    ):
        parser.add_argument(flag, required=True, help=help_text)

    args = parser.parse_args()

    # argparse derives the attribute names (databricks_api_key, databricks_url,
    # together_api_key) from the flags, and they match the function's keyword
    # parameters, so the namespace can be forwarded directly.
    test_databricks_guard(**vars(args))