# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations

import ast
import re
from dataclasses import dataclass, field
from textwrap import dedent
from typing import Any, Literal, Optional, Union

from marimo import _loggers
from marimo._dependencies.dependencies import DependencyManager

LOGGER = _loggers.marimo_logger()

# File extensions that mark a table name in a query as a file path rather
# than a table reference (find_sql_refs skips names ending in one of these).
COMMON_FILE_EXTENSIONS = (
    ".csv",
    ".parquet",
    ".json",
    ".txt",
    ".db",
    ".tsv",
    ".xlsx",
)

# The kinds of SQL objects this module tracks definitions/references for.
SQLKind = Literal["table", "view", "schema", "catalog"]

# SQLKind plus "any"; "any" is the default `kind` used by
# SQLRef.matches_hierarchical_ref.
SQLTypes = Union[SQLKind, Literal["any"]]


class SQLVisitor(ast.NodeVisitor):
    """
    Find any SQL queries in the AST.
    This should be inside a function called `.execute` or `.sql`.
    """

    def __init__(self, raw: bool = False) -> None:
        super().__init__()
        self._sqls: list[str] = []
        self._raw = raw

    def visit_Call(self, node: ast.Call) -> None:
        # Check if the call is a method call and the method is named
        # either 'execute' or 'sql'
        if isinstance(node.func, ast.Attribute) and node.func.attr in (
            "execute",
            "sql",
        ):
            # Check if there are arguments and the first argument is a
            # string or f-string
            if node.args:
                first_arg = node.args[0]
                sql: Optional[str] = None
                if isinstance(first_arg, ast.Constant):
                    sql = first_arg.value
                elif isinstance(first_arg, ast.JoinedStr):
                    if self._raw:
                        f_sql = ast.unparse(first_arg)
                        sql = dedent(
                            f_sql[1:]
                            .strip(f_sql[1])
                            .encode()
                            .decode("unicode_escape")
                        )
                    else:
                        sql = normalize_sql_f_string(first_arg)

                if sql is not None:
                    # Append the SQL query to the list
                    self._sqls.append(sql)
        # Continue walking through the AST
        self.generic_visit(node)

    def get_sqls(self) -> list[str]:
        return self._sqls


def normalize_sql_f_string(node: ast.JoinedStr) -> str:
    """
    Flatten an f-string AST node into a plain SQL string.

    Literal segments are kept verbatim, constant `{...}` values are
    stringified, nested f-strings are flattened recursively, and any other
    `{...}` expression is replaced with the placeholder `null`. This yields
    a syntactically valid query that can be handed to SQL tooling which
    cannot evaluate Python expressions.
    """
    pieces: list[str] = []
    for value in node.values:
        # Unwrap `{...}` down to the expression it formats; conversions
        # (`!r`) and format specs are ignored.
        expr = value
        while isinstance(expr, ast.FormattedValue):
            expr = expr.value

        if isinstance(expr, ast.JoinedStr):
            pieces.append(normalize_sql_f_string(expr))
        elif isinstance(expr, ast.Constant):
            pieces.append(str(expr.value))
        else:
            pieces.append("null")
    return "".join(pieces)


class TokenExtractor:
    """
    Recover token text from a SQL statement and its tokenizer output.

    Each entry of `tokens` is indexed as `(start_offset, token_type)` (as
    produced by `duckdb.tokenize`). Only start offsets are available, so
    `token_str` reconstructs where each token ends.
    """

    # A non-quoted token ends at whitespace, or at '-' / '/', which may
    # start a comment.
    _UNQUOTED_END = re.compile(r"[\s\-/]")

    def __init__(self, sql_statement: str, tokens: list[Any]) -> None:
        self.sql_statement = sql_statement
        self.tokens = tokens

    def token_str(self, i: int) -> str:
        """
        Return the source text of the i-th token.

        - Quoted tokens ("..." or '...') run through the closing quote; if
          the quote is never closed they run to the end of the statement.
        - Escape strings (e'...' or E'...') are returned without the prefix,
          i.e. as '...'.
        - Other tokens end at the first whitespace, '-' or '/', or where the
          next token starts, whichever comes first.
        """
        sql_statement, tokens = self.sql_statement, self.tokens
        start = tokens[i][0]
        length = len(sql_statement)

        first_char = sql_statement[start]
        if (
            first_char in ("e", "E")
            and sql_statement[start + 1 : start + 2] == "'"
        ):
            # Escape string: skip the prefix and treat it as a '...' string
            start += 1
            first_char = "'"

        if first_char in ('"', "'"):
            close = sql_statement.find(first_char, start + 1)
            # An unterminated quote extends to the end of the statement
            # (find() returns -1, which would otherwise yield "").
            end = close + 1 if close != -1 else length
        else:
            match = self._UNQUOTED_END.search(sql_statement, start)
            end = match.start() if match else length
            if i + 1 < len(tokens):
                # For tokens squashed together e.g. '(select' or 'x);;'
                # in (select * from x);;
                end = min(end, tokens[i + 1][0])

        return sql_statement[start:end]

    def is_keyword(self, i: int, match: str) -> bool:
        """
        Return True if the i-th token is a DuckDB keyword whose text equals
        `match`. The token text is lowercased before comparing, so `match`
        must be lowercase.
        """
        import duckdb

        if self.tokens[i][1] != duckdb.token_type.keyword:
            return False
        return self.token_str(i).lower() == match

    def strip_quotes(self, token: str) -> str:
        """
        Remove surrounding double or single quotes from a token.

        Note that str.strip removes every repeated quote character at both
        ends, not just one.
        """
        if token.startswith('"') and token.endswith('"'):
            return token.strip('"')
        elif token.startswith("'") and token.endswith("'"):
            return token.strip("'")
        return token


@dataclass
class SQLDefs:
    """
    Definitions found in a SQL statement, as collected by `find_sql_defs`.
    """

    # Tables created with CREATE TABLE
    tables: list[SQLRef] = field(default_factory=list)
    # Views created with CREATE VIEW
    views: list[SQLRef] = field(default_factory=list)
    # Schemas created with CREATE SCHEMA (last name component only)
    schemas: list[str] = field(default_factory=list)
    # Catalogs attached with ATTACH
    catalogs: list[str] = field(default_factory=list)

    # The schemas referenced in the CREATE SQL statement
    reffed_schemas: list[str] = field(default_factory=list)
    # The catalogs referenced in the CREATE SQL statement
    reffed_catalogs: list[str] = field(default_factory=list)


def _catalog_name_from_attach_path(path: str) -> str:
    """
    Derive the catalog name for an ATTACH target that has no AS alias.

    Uses the last path component without its extension, following DuckDB's
    default of naming the database after the file, e.g. "db.sqlite" -> "db"
    and "data/file.duckdb" -> "file". For prefixed targets such as
    "md:my_db", the part after the first ":" is used ("my_db").
    """
    # Keep only the last path component ('/' or '\' separators), so that
    # directories and URLs such as "s3://bucket/file.db" are dropped.
    name = re.split(r"[\\/]", path)[-1]
    if ":" in name:
        # e.g. "md:my_db"
        # split on ":" and take the second part
        name = name.split(":")[1]
    if "." in name:
        # e.g. "db.sqlite"
        # strip the extension from the name
        name = name.split(".")[0]
    return name


def find_sql_defs(sql_statement: str) -> SQLDefs:
    """
    Find the tables, views, schemas, and catalogs created/attached in a SQL statement.

    This function uses the DuckDB tokenizer to find the tables created
    and schemas attached in a SQL statement. It returns a list of the table
    names created, views created, schemas created, and catalogs attached in the
    statement, together with the catalogs/schemas used to qualify created
    names. The default catalog "memory" and schema "main" are never reported
    as referenced.

    Args:
        sql_statement: The SQL statement to parse.

    Returns:
        SQLDefs (empty when duckdb is not installed)
    """
    if not DependencyManager.duckdb.has():
        return SQLDefs()

    import duckdb

    tokens = duckdb.tokenize(sql_statement)
    token_extractor = TokenExtractor(
        sql_statement=sql_statement, tokens=tokens
    )
    created_tables: list[SQLRef] = []
    created_views: list[SQLRef] = []
    created_schemas: list[str] = []
    created_catalogs: list[str] = []

    reffed_schemas: list[str] = []
    reffed_catalogs: list[str] = []
    i = 0

    # See
    #
    #   https://duckdb.org/docs/sql/statements/create_table#syntax
    #   https://duckdb.org/docs/sql/statements/create_view#syntax
    #
    # for the CREATE syntax, and
    #
    #   https://duckdb.org/docs/sql/statements/attach#attach-syntax
    #
    # for ATTACH syntax
    while i < len(tokens):
        if token_extractor.is_keyword(i, "create"):
            # CREATE TABLE, CREATE VIEW, CREATE SCHEMA have the same syntax
            i += 1
            if i < len(tokens) and token_extractor.is_keyword(i, "or"):
                i += 2  # Skip 'OR REPLACE'
            if i < len(tokens) and (
                token_extractor.is_keyword(i, "temporary")
                or token_extractor.is_keyword(i, "temp")
            ):
                i += 1  # Skip 'TEMPORARY' or 'TEMP'

            is_table = False
            is_view = False
            is_schema = False

            if i < len(tokens) and (
                (is_table := token_extractor.is_keyword(i, "table"))
                or (is_view := token_extractor.is_keyword(i, "view"))
                or (is_schema := token_extractor.is_keyword(i, "schema"))
            ):
                i += 1
                if i < len(tokens) and token_extractor.is_keyword(i, "if"):
                    i += 3  # Skip 'IF NOT EXISTS'
                if i < len(tokens):
                    # Get table name parts, this could be:
                    # - catalog.schema.table
                    # - catalog.table (this is shorthand for catalog.main.table)
                    # - table

                    parts: list[str] = []
                    while i < len(tokens):
                        part = token_extractor.strip_quotes(
                            token_extractor.token_str(i)
                        )
                        parts.append(part)
                        # next token is a dot, so we continue getting parts
                        if (
                            i + 1 < len(tokens)
                            and token_extractor.token_str(i + 1) == "."
                        ):
                            i += 2
                            continue
                        break

                    # Assert parts is either 1, 2, or 3
                    if len(parts) not in (1, 2, 3):
                        LOGGER.warning(
                            "Unexpected number of parts in CREATE TABLE: %s",
                            parts,
                        )

                    if is_table or is_view:
                        # only add the table/view name
                        ref = SQLRef.from_parts(parts)
                        if is_table:
                            created_tables.append(ref)
                        else:
                            created_views.append(ref)
                        # add the catalog and schema if exist
                        if len(parts) == 3:
                            reffed_catalogs.append(parts[0])
                            reffed_schemas.append(parts[1])
                        elif len(parts) == 2:
                            reffed_catalogs.append(parts[0])
                    elif is_schema:
                        # only add the schema name
                        created_schemas.append(parts[-1])
                        # add the catalog if exist
                        if len(parts) == 2:
                            reffed_catalogs.append(parts[0])
        elif token_extractor.is_keyword(i, "attach"):
            catalog_name = None
            i += 1
            if i < len(tokens) and token_extractor.is_keyword(i, "database"):
                i += 1  # Skip 'DATABASE'
            if i < len(tokens) and token_extractor.is_keyword(i, "if"):
                i += 3  # Skip "IF NOT EXISTS"
            if i < len(tokens):
                # Without an AS clause the name is derived from the path
                catalog_name = _catalog_name_from_attach_path(
                    token_extractor.strip_quotes(token_extractor.token_str(i))
                )
            if i + 1 < len(tokens) and token_extractor.is_keyword(i + 1, "as"):
                # Skip over database-path 'AS'
                i += 2
                # AS clause gets precedence in creating database
                if i < len(tokens):
                    catalog_name = token_extractor.strip_quotes(
                        token_extractor.token_str(i)
                    )
            if catalog_name is not None:
                created_catalogs.append(catalog_name)

        i += 1

    # 'memory' (catalog) and 'main' (schema) are the defaults and don't have
    # a def. Drop every occurrence, not just the first one, and compare
    # case-insensitively since DuckDB identifiers are case-insensitive.
    reffed_catalogs = [c for c in reffed_catalogs if c.lower() != "memory"]
    reffed_schemas = [s for s in reffed_schemas if s.lower() != "main"]

    return SQLDefs(
        tables=created_tables,
        views=created_views,
        schemas=created_schemas,
        catalogs=created_catalogs,
        reffed_schemas=reffed_schemas,
        reffed_catalogs=reffed_catalogs,
    )


@dataclass(frozen=True)
class SQLRef:
    """
    A reference to a table or view, optionally qualified by schema and catalog.

    The dataclass is frozen, so instances are hashable and can be collected
    in sets (as `find_sql_refs` does). `schema` and `catalog` are None when
    the name was not qualified with them.
    """

    # Tables are synonymous with views,
    # since we can't know the difference in queries
    table: str
    schema: Optional[str] = None
    catalog: Optional[str] = None

    @classmethod
    def from_parts(
        cls,
        parts: list[str],
    ) -> SQLRef:
        """
        Build a SQLRef from the dot-separated components of a name.

        Three parts are read as catalog.schema.table, two as schema.table,
        and one as table. Any other number of parts produces an empty table
        name with no qualifiers. Every component is lowercased.
        """
        catalog = None
        schema = None
        table = ""
        if len(parts) == 3:
            catalog, schema, table = parts
            catalog = catalog.lower()
            schema = schema.lower()
        elif len(parts) == 2:
            # A single qualifier is treated as a schema (it could also be a
            # catalog, which cannot be known from the name alone)
            schema, table = parts
            schema = schema.lower()
        elif len(parts) == 1:
            table = parts[0]
        return cls(table=table.lower(), schema=schema, catalog=catalog)

    @property
    def qualified_name(self) -> str:
        """
        Convert a SQLRef to a fully qualified name to be used as a reference in the visitor.

        Joins the non-None components as [catalog.][schema.]table and
        lowercases the result.
        """
        parts = []
        if self.catalog is not None:
            parts.append(self.catalog)
        if self.schema is not None:
            parts.append(self.schema)

        # Table is always required
        parts.append(self.table)
        name = ".".join(parts)
        return name.lower()

    def matches_hierarchical_ref(
        self, name: str, ref: str, kind: SQLTypes = "any"
    ) -> bool:
        """
        Determine if a hierarchical reference string matches a SQLRef.

        Only `name` and `ref` are lowercased here; this SQLRef's components
        are compared as stored (from_parts stores them lowercased).

        - kind "catalog": compares `name`, self.catalog and the first part
          of `ref`; when self.catalog is None it is handled as "schema".
        - kind "schema": compares `name`, self.schema and the schema
          position of `ref` (first part for 1-2 parts, second for 3 parts).
        - otherwise the number of parts in `ref` decides, see below.

        Args:
            name: The name to match against (could be catalog, schema, or table).
            ref: The fully qualified reference string (e.g., "schema.table", "catalog.schema.table").
            kind: The kind of reference ("table", "view", "schema", "catalog", or "any").

        Returns:
            True if the reference matches the SQLRef's structure and values, False otherwise.
        """
        ref = ref.lower()
        name = name.lower()
        parts = ref.split(".")
        num_parts = len(parts)

        # str.split always returns at least one element, so this is only a
        # defensive guard
        if num_parts == 0:
            return False

        if kind == "catalog":
            if self.catalog is not None:
                return name == self.catalog == parts[0]
            # Fallback to schema if catalog is None
            kind = "schema"

        if kind == "schema":
            if num_parts < 3:
                return name == self.schema == parts[0]
            return name == self.schema == parts[1]

        # Otherwise, kind is "table" or "view", and we should check the ordering
        # and return accordingly
        # ("catalog" and "schema" always return above, so of the SQLTypes
        # values only "table", "view" and "any" reach this point)
        if num_parts == 1:
            # Only table name provided
            return name == self.table == parts[0] and kind in (
                "table",
                "view",
                "any",
            )

        if num_parts == 2:
            # Format: schema.table or catalog.table
            # sqlglot cannot differentiate between schema and catalog
            # so we check if the qualifier matches either
            qualifier, table = parts
            # Try matching as schema or catalog
            if (self.schema, self.catalog) == (None, None):
                return name == self.table == table and kind in (
                    "table",
                    "view",
                    "any",
                )
            if qualifier not in (self.schema, self.catalog):
                return False

            # NOTE(review): once the qualifier matches, the table component
            # of `ref` is not compared with self.table — confirm intended.
            return name in (
                self.catalog,
                self.schema,
                self.table,
            ) and kind in (
                "table",
                "view",
                "catalog",
                "schema",
                "any",
            )

        if num_parts == 3:
            # Format: catalog.schema.table
            catalog, schema, table = parts
            if self.catalog:
                if catalog != self.catalog:
                    return False
                if schema != self.schema:
                    # Same catalog, different schema: only the catalog matches
                    return name == self.catalog and kind in ("catalog", "any")
            elif self.schema:
                if schema != self.schema:
                    return False
                return name == self.schema and kind in ("schema", "any")
            # NOTE(review): the table component of `ref` is not compared
            # with self.table here — confirm intended.
            return name in (
                self.catalog,
                self.schema,
                self.table,
            ) and kind in (
                "table",
                "view",
                "catalog",
                "schema",
                "any",
            )

        # More than three parts never matches
        return False

    def contains_hierarchical_ref(self, ref: str, kind: str) -> bool:
        """
        Return True if `ref` names a component of this SQLRef.

        For kind "table"/"view", `ref` is compared with the table name; for
        kind "catalog", with the catalog or the schema. Any other kind
        returns False. The comparison is exact (no lowercasing).
        """
        if kind in ("table", "view"):
            return ref == self.table
        if kind == "catalog":
            return ref == self.catalog or ref == self.schema
        return False


def find_sql_refs(sql_statement: str) -> set[SQLRef]:
    """
    Find table and schema references in a SQL statement.

    Args:
        sql_statement: The SQL statement to parse.

    Returns:
        A set of unique SQLRefs, one for each table reference in the statement.
        Eg. SELECT * FROM schema1.test_table INNER JOIN schema2.test_table2
        would return two SQLRefs, one for the first table and one for the second.
        If sqlglot cannot tokenize or parse the statement, the error is logged
        and an empty set is returned.

    Note:
        When providing only a single qualification,
        DuckDB will interpret as either a catalog or a schema, as long as there are no conflicts.

        Eg. SELECT * FROM my_db.my_table, my_db can be a catalog or schema. If a catalog exists,
        then it would resolve to my_db.main.my_table.

        At the moment, we don't know this, so my_db is treated as a schema.

        Names that look like URLs ("://") or end in a common file extension
        (compared case-insensitively) are not reported.
    """

    # Use sqlglot to parse ast (https://github.com/tobymao/sqlglot/blob/main/posts/ast_primer.md)

    DependencyManager.sqlglot.require(why="SQL parsing")

    from sqlglot import exp, parse
    from sqlglot.errors import ParseError, TokenError
    from sqlglot.optimizer.scope import build_scope

    def get_ref_from_table(table: exp.Table) -> Optional[SQLRef]:
        # The variables might be empty strings, if they are, we set them to None
        table_name = table.name or None
        schema_name = table.db or None
        catalog_name = table.catalog or None

        if table_name is None:
            LOGGER.warning("Table name cannot be found in the SQL statement")
            return None

        # Check if the table name looks like a URL or has a file extension.
        # These are often not actual table references, so we skip them.
        # Note that they can be valid table names, but we skip them to avoid circular deps
        # Extensions are compared case-insensitively (e.g. "DATA.CSV").
        if "://" in table_name or table_name.lower().endswith(
            COMMON_FILE_EXTENSIONS
        ):
            return None

        return SQLRef(
            table=table_name, schema=schema_name, catalog=catalog_name
        )

    try:
        with _loggers.suppress_warnings_logs("sqlglot"):
            expression_list = parse(sql_statement, dialect="duckdb")
    except (ParseError, TokenError) as e:
        # TokenError (e.g. an unterminated string literal) is not a subclass
        # of ParseError, so it must be caught explicitly.
        LOGGER.error("Unable to parse SQL. Error: %s", e)
        return set()

    refs: set[SQLRef] = set()

    for expression in expression_list:
        if expression is None:
            continue

        # DML statements: every table mentioned is a reference
        if bool(expression.find(exp.Update, exp.Insert, exp.Delete)):
            for table in expression.find_all(exp.Table):
                if ref := get_ref_from_table(table):
                    refs.add(ref)

        # build_scope only works for select statements
        if root := build_scope(expression):
            for scope in root.traverse():  # type: ignore
                for _node, source in scope.selected_sources.values():
                    if isinstance(source, exp.Table):
                        if ref := get_ref_from_table(source):
                            refs.add(ref)

    return refs
