from pathlib import Path
from typing import Any

from pydantic import BaseModel, Field

from ...utils import resolve_default
from ..utils import _request


class Column(BaseModel):
    # Pydantic model describing one table column.
    # create_table and add_column validate caller input against this model and
    # serialize it with model_dump() into the request payload; Table uses it
    # for its `columns` list.
    # Only `default_value` and `comment` have defaults; every other field must
    # be supplied or validation fails.
    name: str
    """The name of the column."""
    data_type: str
    """The data type of the column (as a postgres data type)."""
    default_value: str | None = None
    """The default value of the column."""
    is_nullable: bool
    """Whether the column can be null."""
    is_unique: bool
    """Whether the column is unique."""
    is_primary: bool
    """Whether the column is a primary key or part of a composite primary key."""
    comment: str | None = None
    """The comment of the column."""


class Table(BaseModel):
    # Pydantic model for a full table definition.
    # get_table and get_tables build instances from the API response with
    # model_validate().
    # `constraints` has no default, so the key must be present in the
    # validated data.
    name: str
    """The name of the table."""
    columns: list[Column]
    """The columns of the table."""
    comment: str | None = None
    """The comment of the table."""
    constraints: list[Any]  # TODO: replace Any with a proper constraint type/enum
    """The constraints on the table."""


async def create_table(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base to create the table in."),
    name: str = Field(..., description="The name of the table to create."),
    columns: list[Column] = Field(..., description="List of column definitions."),
    comment: str | None = Field(None, description="The comment of the table."),
) -> None:
    """
    Create a table inside a knowledge base.

    Issues a POST to ``orgs/{org_id}/kbs/{kb_id}/tables`` whose JSON body
    carries the table name, the serialized column definitions and the comment.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base that will hold the table.
        name (str): The name of the new table.
        columns (list[Column]): Column definitions. Items that are not already
            ``Column`` instances are validated with ``Column.model_validate``.
        comment (str | None): Optional comment describing the table.
    """
    # Project helper; presumably replaces an unresolved Field default with the
    # real default value when the argument was omitted - confirm in utils.
    comment = resolve_default("comment", comment)

    # Accept ready-made Column objects as well as raw column data.
    serialized_columns = []
    for entry in columns:
        model = entry if isinstance(entry, Column) else Column.model_validate(entry)
        serialized_columns.append(model.model_dump())

    await _request(
        "POST",
        f"orgs/{org_id}/kbs/{kb_id}/tables",
        json={
            "name": name,
            "columns": serialized_columns,
            "comment": comment,
        },
    )


async def delete_table(
    org_id: str = Field(
        ...,
        description="The organization ID."
    ),
    kb_id: str = Field(
        ...,
        description="The ID of the knowledge base to delete the table from."
    ),
    name: str = Field(
        ...,
        description="The name of the table to delete."
    ),
    cascade: bool = Field(
        default=True,
        description="Whether to cascade the delete to the table's indexes."
    ),
) -> None:
    """
    Delete a table in a knowledge base.

    Sends a DELETE request to ``orgs/{org_id}/kbs/{kb_id}/tables/{name}`` with
    the ``cascade`` flag in the JSON body.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The knowledge base ID containing the table.
        name (str): The name of the table to delete.
        cascade (bool): Whether to cascade the delete to the table's indexes.
            Defaults to True.
    """
    # Project helper; presumably replaces an unresolved Field default with the
    # real default value when the argument was omitted - confirm in utils.
    cascade = resolve_default("cascade", cascade)

    # No trailing "?": the flag travels in the JSON body, not the query string.
    path = f"orgs/{org_id}/kbs/{kb_id}/tables/{name}"
    payload = {"cascade": cascade}

    await _request("DELETE", path, json=payload)


async def get_table(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base to get the table from."),
    name: str = Field(..., description="The name of the table to get."),
) -> Table:
    """
    Fetch the full definition of one table in a knowledge base.

    Requests the table with ``format=full`` and validates the ``"table"``
    entry of the JSON response as a ``Table`` model.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base holding the table.
        name (str): The name of the table to fetch.

    Returns:
        Table: The table's name, columns, comment and constraints.
    """
    response = await _request(
        "GET",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{name}",
        params={"format": "full"},
    )
    body = response.json()
    return Table.model_validate(body.get("table"))


async def get_table_sql(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base."),
    name: str = Field(..., description="The name of the table to get the definition for."),
) -> str:
    """
    Fetch the SQL definition of one table in a knowledge base.

    Requests the table with ``format=definition`` and returns the ``"table"``
    entry of the JSON response unchanged.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base holding the table.
        name (str): The name of the table whose definition is wanted.

    Returns:
        str: The table definition as SQL text (a CREATE TABLE statement with
        its columns, constraints and associated indexes).
    """
    response = await _request(
        "GET",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{name}",
        params={"format": "definition"},
    )
    body = response.json()
    return body.get("table")


async def get_tables(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base."),
) -> list[Table]:
    """
    Fetch the full definition of every table in a knowledge base.

    Requests the table collection with ``format=full`` and validates each item
    under the ``"tables"`` key of the JSON response as a ``Table`` model.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The knowledge base ID.

    Returns:
        list[Table]: One ``Table`` per entry; empty when the response has no
        ``"tables"`` key.
    """
    response = await _request(
        "GET",
        f"orgs/{org_id}/kbs/{kb_id}/tables",
        params={"format": "full"},
    )
    raw_tables = response.json().get("tables", [])
    return list(map(Table.model_validate, raw_tables))


# TODO: Consider joining the list into a single string (LLM-oriented output).
async def get_tables_sql(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base."),
) -> list[str]:
    """
    Fetch the SQL definition of every table in a knowledge base.

    Requests the table collection with ``format=definition`` and returns the
    ``"tables"`` entry of the JSON response unchanged.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The knowledge base ID.

    Returns:
        list[str]: A list of full table SQL definitions; empty when the
        response has no ``"tables"`` key.
    """
    response = await _request(
        "GET",
        f"orgs/{org_id}/kbs/{kb_id}/tables",
        params={"format": "definition"},
    )
    body = response.json()
    return body.get("tables", [])


async def import_records(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base to import records to."),
    table_name: str = Field(..., description="The name of the table to import records to."),
    records: list[dict[str, Any]] = Field(..., description="The records to import."),
) -> None:
    """
    Insert a batch of records into a knowledge base table.

    Posts the records as ``{"records": [...]}`` to the table's ``records``
    endpoint.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base holding the table.
        table_name (str): The name of the target table.
        records (list[dict[str, Any]]): The rows to import, one dictionary per
            row, with keys matching the table's column names.

    Example:
        # For an example table "Documents" with columns: id, filename, content
        await import_records(
            org_id="10",
            kb_id="48",
            table_name="Documents",
            records=[
                {"id": 1, "filename": "foo.txt", "content": "Hello world!"},
                {"id": 2, "filename": "bar.txt", "content": "Another document"}
            ]
        )
    """
    await _request(
        "POST",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{table_name}/records",
        json={"records": records},
    )


async def import_csv(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base containing the table."),
    table_name: str = Field(..., description="The name of the table to upload the CSV to."),
    csv_path: str | Path = Field(..., description="The path to the CSV file."),
) -> None:
    """
    Upload the contents of a CSV file to a knowledge base table.

    The whole file is read into memory as bytes and posted to the table's
    ``csv`` endpoint with a ``text/csv`` content type. The CSV is expected to
    match the table's schema (column names and types).

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base containing the table.
        table_name (str): The name of the table to upload the CSV to.
        csv_path (str | Path): The path to the CSV file on disk.
    """
    # Read raw bytes: no text decoding, the server receives the file as-is.
    csv_bytes = Path(csv_path).read_bytes()

    await _request(
        "POST",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{table_name}/csv",
        headers={"Content-Type": "text/csv"},
        content=csv_bytes,
    )


async def search_table(
    org_id: str = Field(
        ...,
        description="The organization ID."
    ),
    kb_id: str = Field(
        ...,
        description="The ID of the knowledge base containing the table."
    ),
    table_name: str = Field(
        ...,
        description="The name of the table to search."
    ),
    embedding_names: list[str] = Field(
        ...,
        description="The names of the embeddings to use for the search."
    ),
    query: str = Field(
        ...,
        description="Semantic search query"
    ),
    fts_columns: list[str] | None = Field(
        None,
        description="List of column names to use for full-text search. When specified together with embedding_names, search will be hybrid."
    ),
    limit: int = Field(
        10,
        description="The maximum number of results to return.",
    ),
    offset: int = Field(
        0,
        description="The offset to use for pagination.",
    ),
    select_columns: list[str] | None = Field(
        None,
        description="The columns to select."
    ),
    where: str | None = Field(
        None,
        description="SQL WHERE expression to apply to the search. ."
    ),
) -> list[dict[str, Any]]:
    """
    Search a knowledge base table using embeddings, optionally with full-text search.

    Posts the query and search options to the table's ``search`` endpoint and
    returns the decoded JSON response. Optional arguments left unset are sent
    as their defaults (``None`` for ``fts_columns``, ``select_columns`` and
    ``where``).

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base containing the table.
        table_name (str): The name of the table to search.
        embedding_names (list[str]): The names of the embeddings to use for the search.
        query (str): The natural language query or search phrase.
        fts_columns (list[str] | None, optional): Column names to use for
            full-text search. When specified together with embedding_names,
            search will be hybrid. Defaults to None.
        limit (int, optional): The maximum number of results to return. Defaults to 10.
        offset (int, optional): The offset for pagination. Defaults to 0.
        select_columns (list[str] | None, optional): The columns to select.
            Defaults to None (all columns).
        where (str | None, optional): SQL WHERE expression to apply to the
            search. Defaults to None.

    Returns:
        list[dict[str, Any]]: A list of ordered records, each as a dictionary. The structure of each record
        depends on the table schema and the specified select columns.
    """
    # Project helper; presumably replaces an unresolved Field default with the
    # real default value when the argument was omitted - confirm in utils.
    fts_columns = resolve_default("fts_columns", fts_columns)
    limit = resolve_default("limit", limit)
    offset = resolve_default("offset", offset)
    select_columns = resolve_default("select_columns", select_columns)
    where = resolve_default("where", where)

    path = f"orgs/{org_id}/kbs/{kb_id}/tables/{table_name}/search"
    payload = {
        "query": query,
        "embedding_names": embedding_names,
        "fts_columns": fts_columns,
        "limit": limit,
        "offset": offset,
        "select_columns": select_columns,
        "where": where,
    }

    res = await _request("POST", path, json=payload)
    return res.json()


async def query(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base."),
    sql: str = Field(..., description="The SQL query to execute."),
) -> list[dict[str, Any]]:
    """
    Run a SQL statement against a knowledge base (PostgreSQL dialect).

    Posts ``{"sql": sql}`` to the knowledge base's ``query`` endpoint and
    returns the decoded JSON response.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The knowledge base ID.
        sql (str): The SQL query to execute. This must be valid PostgreSQL SQL.

    Returns:
        list[dict[str, Any]]: The query results, one dictionary per row.

    Notes:
        - SQL syntax must follow PostgreSQL conventions.
        - Table and column names are case sensitive. Identifiers that use
          uppercase or mixed case must be wrapped in double quotes
          (e.g., "Documents", "FileName").
        - Unquoted identifiers are automatically lowercased by PostgreSQL.

    Example:
        # Count all the documents in the table
        await query(
            org_id="10",
            kb_id="48",
            sql='SELECT COUNT(*) FROM "Documents"'
        )
    """
    response = await _request(
        "POST",
        f"orgs/{org_id}/kbs/{kb_id}/query",
        json={"sql": sql},
    )
    return response.json()
    

async def add_column(
    org_id: str = Field(
        ...,
        description="The organization ID."
    ),
    kb_id: str = Field(
        ...,
        description="The ID of the knowledge base containing the table."
    ),
    table_name: str = Field(
        ...,
        description="The name of the table to add the column to."
    ),
    column: Column = Field(
        ...,
        description="The column to add."
    ),
) -> None:
    """
    Add a column to an existing table in a knowledge base.

    Sends a PATCH request to the table with a single ``add_column`` operation
    holding the serialized column definition.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The knowledge base ID containing the table.
        table_name (str): The name of the table to add the column to.
        column (Column): The column definition to add. A value that is not
            already a ``Column`` is validated with ``Column.model_validate``.
    """
    # Accept a ready-made Column as well as raw column data.
    column = column if isinstance(column, Column) else Column.model_validate(column)

    path = f"orgs/{org_id}/kbs/{kb_id}/tables/{table_name}"
    payload = {"operations": [{"add_column": column.model_dump()}]}

    await _request("PATCH", path, json=payload)
    

async def drop_column(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base containing the table."),
    table_name: str = Field(..., description="The name of the table to drop the column from."),
    name: str = Field(..., description="The name of the column to drop."),
    cascade: bool = Field(True, description="Whether to cascade the drop to dependent objects."),
) -> None:
    """
    Remove a column from an existing table in a knowledge base.

    Sends a PATCH request to the table with a single ``drop_column``
    operation.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base containing the table.
        table_name (str): The name of the table that owns the column.
        name (str): The name of the column to drop.
        cascade (bool): Whether to cascade the drop to dependent objects.
            Defaults to True.
    """
    # Project helper; presumably replaces an unresolved Field default with the
    # real default value when the argument was omitted - confirm in utils.
    cascade = resolve_default("cascade", cascade)

    operation = {"drop_column": {"name": name, "cascade": cascade}}
    await _request(
        "PATCH",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{table_name}",
        json={"operations": [operation]},
    )


async def rename_column(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base containing the table."),
    table_name: str = Field(..., description="The name of the table to rename the column in."),
    name: str = Field(..., description="The name of the column to rename."),
    new_name: str = Field(..., description="The new name for the column."),
) -> None:
    """
    Give a column of an existing table a new name.

    Sends a PATCH request to the table with a single ``rename_column``
    operation.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base containing the table.
        table_name (str): The name of the table that owns the column.
        name (str): The current name of the column.
        new_name (str): The new name for the column.
    """
    operation = {"rename_column": {"name": name, "new_name": new_name}}
    await _request(
        "PATCH",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{table_name}",
        json={"operations": [operation]},
    )


async def rename_table(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base containing the table."),
    name: str = Field(..., description="The name of the table to rename."),
    new_name: str = Field(..., description="The new name for the table."),
) -> None:
    """
    Give an existing table in a knowledge base a new name.

    Sends a PATCH request addressed to the table's current name with a single
    ``rename_table`` operation.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base containing the table.
        name (str): The current name of the table.
        new_name (str): The new name for the table.
    """
    operation = {"rename_table": {"new_name": new_name}}
    await _request(
        "PATCH",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{name}",
        json={"operations": [operation]},
    )


async def add_fk(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base containing the table."),
    table_name: str = Field(..., description="The name of the table to add the foreign key to."),
    column_names: list[str] = Field(..., description="The name of the columns to add the foreign key to."),
    fk_table_name: str = Field(..., description="The name of the referenced table."),
    fk_column_names: list[str] = Field(..., description="The name of the referenced columns in the foreign table."),
    name: str = Field(..., description="The name of the foreign key constraint."),
    on_delete_action: str = Field(..., description="The action to take on delete (e.g., 'NO ACTION', 'CASCADE')."),
    on_update_action: str = Field(..., description="The action to take on update (e.g., 'NO ACTION', 'CASCADE')."),
) -> None:
    """
    Create a foreign key constraint on an existing table in a knowledge base.

    Sends a PATCH request to the table with a single ``add_constraint``
    operation of type ``FOREIGN KEY``.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base containing the table.
        table_name (str): The name of the table that receives the foreign key.
        column_names (list[str]): The local columns that make up the foreign key.
        fk_table_name (str): The name of the referenced table.
        fk_column_names (list[str]): The referenced columns in the foreign table.
        name (str): The name of the foreign key constraint.
        on_delete_action (str): The action to take on delete (e.g., 'NO ACTION', 'CASCADE').
        on_update_action (str): The action to take on update (e.g., 'NO ACTION', 'CASCADE').
    """
    constraint = {
        "type": "FOREIGN KEY",
        "name": name,
        "columns": column_names,
        "reference_table": fk_table_name,
        "reference_columns": fk_column_names,
        "on_delete_action": on_delete_action,
        "on_update_action": on_update_action,
    }
    await _request(
        "PATCH",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{table_name}",
        json={"operations": [{"add_constraint": constraint}]},
    )


async def drop_constraint(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base containing the table."),
    table_name: str = Field(..., description="The name of the table to drop the constraint from."),
    name: str = Field(..., description="The name of the constraint to drop."),
    cascade: bool = Field(True, description="Whether to cascade the drop to dependent objects."),
) -> None:
    """
    Remove a constraint from an existing table in a knowledge base.

    Sends a PATCH request to the table with a single ``drop_constraint``
    operation.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base containing the table.
        table_name (str): The name of the table that owns the constraint.
        name (str): The name of the constraint to drop.
        cascade (bool): Whether to cascade the drop to dependent objects.
            Defaults to True.
    """
    # Project helper; presumably replaces an unresolved Field default with the
    # real default value when the argument was omitted - confirm in utils.
    cascade = resolve_default("cascade", cascade)

    operation = {"drop_constraint": {"name": name, "cascade": cascade}}
    await _request(
        "PATCH",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{table_name}",
        json={"operations": [operation]},
    )


async def add_check(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base containing the table."),
    table_name: str = Field(..., description="The name of the table to add the check constraint to."),
    name: str = Field(..., description="The name of the check constraint."),
    expression: str = Field(..., description="The check constraint expression (e.g., 'length(name) < 10')."),
) -> None:
    """
    Create a check constraint on an existing table in a knowledge base.

    Sends a PATCH request to the table with a single ``add_constraint``
    operation of type ``CHECK``.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base containing the table.
        table_name (str): The name of the table that receives the constraint.
        name (str): The name of the check constraint.
        expression (str): The check constraint expression (e.g., 'length(name) < 10').
    """
    constraint = {"type": "CHECK", "name": name, "expression": expression}
    await _request(
        "PATCH",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{table_name}",
        json={"operations": [{"add_constraint": constraint}]},
    )


async def add_unique(
    org_id: str = Field(..., description="The organization ID."),
    kb_id: str = Field(..., description="The ID of the knowledge base containing the table."),
    table_name: str = Field(..., description="The name of the table to add the unique constraint to."),
    name: str = Field(..., description="The name of the unique constraint."),
    columns: list[str] = Field(..., description="The columns to add the unique constraint to."),
) -> None:
    """
    Create a unique constraint on an existing table in a knowledge base.

    Sends a PATCH request to the table with a single ``add_constraint``
    operation of type ``UNIQUE``.

    Args:
        org_id (str): The organization ID.
        kb_id (str): The ID of the knowledge base containing the table.
        table_name (str): The name of the table that receives the constraint.
        name (str): The name of the unique constraint.
        columns (list[str]): The columns covered by the unique constraint.
    """
    constraint = {"type": "UNIQUE", "name": name, "columns": columns}
    await _request(
        "PATCH",
        f"orgs/{org_id}/kbs/{kb_id}/tables/{table_name}",
        json={"operations": [{"add_constraint": constraint}]},
    )
    