import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Iterable
from dataclasses import dataclass
from typing import Any

from tree_sitter import Node, Tree
#from tree_sitter_languages import get_language, get_parser
from tree_sitter_language_pack import get_language, get_parser
#import tree_sitter_python

from synthegrator.lang_specs.lang_util import TreeSitterUtil


class LangSpec(ABC):
    """
    Abstract description of a programming language.

    Concrete subclasses supply the language-specific pieces (names, completion
    truncation, prefix-comment extraction); this base class adds tree-sitter
    based helpers on top of them.

    Two specs compare equal exactly when they are instances of the same
    concrete class, and they hash accordingly.
    """

    # Parser created on first use by `get_tree_sitter_parser`.
    _cls_tree_sitter_parser = None

    def __eq__(self, other):
        if not isinstance(other, LangSpec):
            return NotImplemented
        return type(self) is type(other)

    def __hash__(self) -> int:
        # Defining `__eq__` without `__hash__` makes instances unhashable.
        # Hash on the concrete class so equal specs share a hash and can be
        # used as dict keys / set members.
        return hash(type(self))

    @abstractmethod
    def truncate_function_completion(self, completion: str) -> str:
        """
        Return a truncated form of `completion`.

        The truncation rule is defined by each concrete language spec.
        """
        ...

    def get_space_indent_size(self) -> int:
        """Returns the number of spaces per indent level (4 unless overridden)."""
        return 4

    def is_expression_compatable_with(self, other: "LangSpec") -> bool:
        """Returns True only when `other` is of exactly the same concrete class."""
        return type(self) is type(other)

    def get_method_name(self, method_src) -> str:
        """
        Returns the name of the method in `method_src`.

        Not implemented in the base class; subclasses that support it override it.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def split_before_last_method_name(prompt: str) -> tuple[str, str]:
        """Splits right before the last method declaration in a string."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_default_docker_env() -> str:
        """
        Returns an identifier of the default docker environment for the language.

        NOTE(review): the exact format is defined by implementers -- confirm.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_lang_md_name() -> str:
        """
        Returns the language's "md name"; this is the value that
        `dict_serialize` stores under `"lang_md_name"`.

        NOTE(review): presumably the markdown code-fence name -- confirm.
        """
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def get_tree_sitter_lang_name() -> str:
        """Returns the language name handed to `get_parser` to obtain a parser."""
        raise NotImplementedError

    # Treesitter-based manipulation
    @classmethod
    def get_tree_sitter_parser(cls) -> Any:
        """
        Returns a tree-sitter parser for this language, creating it on first
        use and caching it on the class afterwards.

        The return value is whatever `get_parser` produces (a parser object,
        not a string).
        """
        if cls._cls_tree_sitter_parser is None:
            cls._cls_tree_sitter_parser = get_parser(
                cls.get_tree_sitter_lang_name(),
            )
        return cls._cls_tree_sitter_parser

    @classmethod
    def get_tree_sitter_tree(cls, src) -> Tree:
        """
        Parses `src` and returns the resulting tree.

        A fresh parser is requested for every call, and the source is handed
        to it as UTF-8 encoded bytes.

        :param src: The source code text to parse
        """
        parser = get_parser(cls.get_tree_sitter_lang_name())
        result: Tree = parser.parse(bytes(src, "utf8"))
        return result

    @classmethod
    @abstractmethod
    def get_prefix_comments(cls, src: str | Tree) -> Iterable["PrefixComment"]:
        """
        Returns the prefix comments found in `src`, which may be given either
        as source text or as an already parsed tree.
        """
        raise NotImplementedError

    def get_syntax_errors(self, content: str) -> list["SyntaxError"]:
        """
        Parses `content` and returns one record per node the parser flags
        with `has_error`, in pre-order (parents before their children).

        Each record holds the start row and column of the flagged node.

        :param content: The source code text to check
        """
        tree = self.get_tree_sitter_tree(content)
        errors = []

        # An explicit stack instead of recursion, so deeply nested trees
        # cannot exhaust the interpreter's recursion limit.
        pending = [tree.root_node]
        while pending:
            node = pending.pop()
            if node.has_error:
                # Get the start position of the node
                start_line, start_column = node.start_point
                errors.append(SyntaxError("Syntax error", start_line, start_column))
            # Push in reverse so the leftmost child is popped (visited) first.
            pending.extend(reversed(node.children))

        return errors

    def check_no_syntax_errors(self, content: str) -> bool:
        """Returns True when `get_syntax_errors` reports nothing for `content`."""
        return len(self.get_syntax_errors(content)) == 0

    def dict_serialize(self) -> dict[str, Any]:
        """Returns a dict holding the language's md name under `"lang_md_name"`."""
        return {
            "lang_md_name": self.get_lang_md_name(),
        }

    @classmethod
    def get_comment_line_start(cls):
        """
        A symbol that represents a comment line. It may not
        be the only way to start a comment line in the language,
        but is one valid way.
        """
        return "#"


@dataclass(frozen=True)
class SyntaxError:
    message: str
    line_idx: int | None
    column_idx: int | None


def lang_spec_for_path(path: str) -> LangSpec | None:
    """
    Picks a language spec for a file path based on its extension.

    Only paths ending in `.py` are recognised at the moment.

    :param path: The file path (or file name) to look at
    :return: A `PythonLangSpec` for `.py` paths, otherwise None
    """
    # TODO: This should be specific for the problem
    if path.endswith(".py"):
        # Imported at call time rather than at module import.
        # NOTE(review): presumably to avoid a circular import -- confirm.
        from synthegrator.lang_specs.lang_spec_python import PythonLangSpec

        return PythonLangSpec()
    return None


class PythonFunctionParser:
    """
    Parses source text holding a single Python function and derives several
    views of it.

    After construction the following attributes are populated (they keep
    their initial values when the source contains no function definition):

    - `function_without_doc`: result of `TreeSitterUtil.copy_tree_excluding_node`
      for the docstring statement.
    - `function_definition_without_doc`: result of the same call for the
      whole body block.
    - `function_definition_with_doc`: result of
      `TreeSitterUtil.copy_tree_excluding_nodes` for the non-docstring body
      statements, right-stripped.
    - `body_expressions`: the body statements other than the docstring.
    - `body_span_begin` / `body_span_end`: row where the first body statement
      starts and row where the last one ends. If the body is only a
      docstring, both are the row where the docstring ends.
    - `has_docstring` / `docstring_node`: whether the body starts with a
      string statement, and that statement node.

    All node byte offsets refer to the UTF-8 encoding of the source, because
    that is what the parser is given.
    """

    # TODO: Merge with `Function` class.
    def __init__(self, function_source: str) -> None:
        """
        :param function_source: Source text expected to hold exactly one
            function and no root-level class
        :raises ValueError: see `parse_source`
        """
        self.language = get_language("python")
        self.parser = get_parser("python")
        self.function_source = function_source
        self.function_without_doc = None
        self.function_definition_with_doc = None
        self.function_definition_without_doc = None
        self.body_span_begin = -1
        self.body_span_end = -1
        self.has_docstring = False
        self.docstring_node: Node | None = None
        # Set up front so the attribute exists even when no function is found.
        self.body_expressions: list[Node] = []
        self.parse_source()

    def parse_source(self) -> None:
        """
        Parses `self.function_source` and fills in the derived attributes.

        :raises ValueError: if a class is defined at the root, if more than
            one function definition is found, or if the function has no body
            block
        """
        tree = self.parser.parse(bytes(self.function_source, "utf8"))
        root_node = tree.root_node

        classes = [
            node for node in root_node.children if node.type == "class_definition"
        ]
        if classes:
            msg = "Class was defined at root! This class handles functions only."
            raise ValueError(
                msg,
            )

        for idx, function_definition in enumerate(
            self.get_function_definitions(root_node),
        ):
            if idx > 0:
                msg = "Second function encountered!"
                raise ValueError(msg)
            function_body = self.get_last_block(function_definition)
            docstring_node = self.get_docstring_node(function_body)
            if docstring_node is not None:
                self.has_docstring = True
                self.docstring_node = docstring_node

            self.function_without_doc = TreeSitterUtil.copy_tree_excluding_node(
                root_node,
                docstring_node,
                self.function_source,
            )

            self.function_definition_without_doc = (
                TreeSitterUtil.copy_tree_excluding_node(
                    root_node,
                    function_body,
                    self.function_source,
                )
            )

            self.body_expressions = self.get_non_docstring_nodes(function_body)

            if len(self.body_expressions) == 0:
                # The body holds nothing but the docstring, so the body span
                # collapses onto the row where the docstring ends.
                self.body_span_begin = docstring_node.end_point[0]
                self.body_span_end = docstring_node.end_point[0]
            else:
                self.body_span_begin = self.body_expressions[0].start_point[0]
                self.body_span_end = self.body_expressions[-1].end_point[0]

            self.function_definition_with_doc = (
                TreeSitterUtil.copy_tree_excluding_nodes(
                    root_node,
                    self.body_expressions,
                    self.function_source,
                ).rstrip()
            )

    def print_node_source(self, node: Node, source_code: str):
        """
        Logs (at INFO level) the source text of `node` and, recursively, of
        every descendant.

        :param node: The node to start from
        :param source_code: The text the tree was parsed from
        """
        # Node offsets are byte offsets, so slice the encoded text.
        self._log_node_source(node, source_code.encode("utf8"))

    def _log_node_source(self, node: Node, source_bytes: bytes) -> None:
        """Logs the text of `node` and its descendants from encoded source."""
        if node.type:  # Skip nodes whose type is empty
            logging.info(
                source_bytes[node.start_byte : node.end_byte].decode(
                    "utf8",
                    errors="replace",
                ),
            )

            # Iterate through the children and call the function recursively
            for child in node.children:
                self._log_node_source(child, source_bytes)

    def get_docstring_node(self, function_node) -> Node | None:
        """
        Returns the docstring statement of a body block, or None.

        The docstring is recognised only when the very first child of
        `function_node` is an expression statement whose first child is a
        string.

        :param function_node: The body block of a function
        """
        if not function_node.children:
            return None
        first_node = function_node.children[0]
        if (
            first_node.type == "expression_statement"
            and first_node.children
            and first_node.children[0].type == "string"
        ):
            return first_node
        return None

    @staticmethod
    def get_function_definitions(node: Node) -> Generator[Node, Any, None]:
        """
        Yields `node` itself if it is a function definition, then every
        function definition that is a direct child of `node` or a direct
        child of a decorated definition under `node`.

        Deeper nesting is not searched.
        """
        if node.type == "function_definition":
            yield node
        for child in node.children:
            if child.type == "function_definition":
                yield child
            elif child.type == "decorated_definition":
                for c in child.children:
                    if c.type == "function_definition":
                        yield c

    def get_non_docstring_nodes(self, function_block: Node) -> list[Node]:
        """Returns the children of `function_block` other than its docstring."""
        docstring = self.get_docstring_node(function_block)
        return [
            node
            for node in function_block.children
            if not TreeSitterUtil.nodes_are_equal(node, docstring)
        ]

    def get_last_block(self, node: Node) -> Node:
        """
        Returns the last direct child of `node` whose type is `block`.

        :raises ValueError: if `node` has no such child
        """
        for i_node in reversed(node.children):
            if i_node.type == "block":
                return i_node
        msg = "No block found. Empty function?"
        raise ValueError(msg)

    def copy_tree_excluding_node(
        self,
        node,
        skip: Node,
        source_code: str,
    ) -> str:
        """Same as `copy_tree_excluding_nodes` with a single node to skip."""
        return self.copy_tree_excluding_nodes(node, [skip], source_code)

    def copy_tree_excluding_nodes(
        self,
        node: Node,
        skip: list[Node],
        source_code: str,
    ) -> str:
        """
        Rebuilds the source text spanned by `node`, leaving out every
        descendant that matches one of the nodes in `skip`.

        :param node: Root of the subtree to copy
        :param skip: Nodes to leave out (compared with
            `TreeSitterUtil.nodes_are_equal`)
        :param source_code: The text the tree was parsed from
        :return: The text of `node` without the skipped nodes
        """
        # Node offsets are byte offsets into the UTF-8 encoding, so the
        # slicing has to happen on bytes; slicing the `str` directly goes
        # wrong as soon as the source holds a non-ASCII character.
        return self._copy_excluding(node, skip, source_code.encode("utf8"))

    def _copy_excluding(
        self,
        node: Node,
        skip: list[Node],
        source_bytes: bytes,
    ) -> str:
        """Recursive worker of `copy_tree_excluding_nodes` on encoded source."""
        pieces: list[str] = []

        # Byte offset up to which the source has been consumed so far.
        cursor = node.start_byte
        skipped = False
        for child in node.children:
            if any(TreeSitterUtil.nodes_are_equal(child, _skip) for _skip in skip):
                skipped = True
            else:
                # Add the source code between this child and the previous child (or the start of the parent).
                pieces.append(
                    source_bytes[cursor : child.start_byte].decode(
                        "utf8",
                        errors="replace",
                    ),
                )

                # Recursively handle the child.
                pieces.append(self._copy_excluding(child, skip, source_bytes))

            # Move past this child whether it was copied or skipped.
            cursor = child.end_byte

        new_source_code = "".join(pieces)

        # When something was skipped at this level, leading whitespace of
        # what has been collected so far is dropped.
        if skipped:
            new_source_code = new_source_code.lstrip()

        # Add the source code after the last child.
        new_source_code += source_bytes[cursor : node.end_byte].decode(
            "utf8",
            errors="replace",
        )

        return new_source_code


@dataclass(frozen=True)
class LsFunctionDef:
    """
    Language-independent wrapper around one function definition.

    It declares the operations shared between languages; every method here
    raises `NotImplementedError`, so the behaviour has to come from
    overriding implementations.
    """

    node: Node  # A node of type `function_definition`
    tree: Tree  # The tree the node came from

    def get_body_node(self) -> Node:
        """Returns the node holding the body of the function."""
        raise NotImplementedError

    # @frozen_method_cache
    def get_body_src(
        self,
        include_body_indent: bool = True,
        include_trailing_new_line: bool = True,
        include_prefix_comment: bool = True,
    ) -> str:
        """
        Returns the source of the function body (potentially including the
        docstring).

        The body node itself does not contain the body indent (the leading
        whitespace); that whitespace belongs to the function declaration.
        Pass `include_body_indent` to have it included in the result.

        :param include_body_indent: Whether to include the body indent
        :param include_trailing_new_line: Whether to include the trailing new line
        :param include_prefix_comment: If False, excludes the prefix comment
            (docstring in Python) from the body
        """
        raise NotImplementedError

    def get_preceding_src_code(self):
        """Returns the source code that comes before the function declaration."""
        raise NotImplementedError

    def get_src_code_after(self):
        """Returns the source code that comes after the function."""
        raise NotImplementedError

    # @frozen_method_cache
    def get_function_name(self) -> str:
        """Returns the name identifier of the function."""
        raise NotImplementedError

    # @frozen_method_cache
    def get_declaration_str(self, include_body_indent: bool = False) -> str:
        r"""
        Returns the full declaration of the function (the `def`, signature,
        args, and `:`).

        :param include_body_indent: If true, the result keeps the indent that
            leads into the body. For example "def foo():\n    pass" would
            return "def foo():\n    " with the indent. If false it would
            return "def foo():\n"
        """
        raise NotImplementedError

    # @frozen_method_cache
    def get_full_function_src(self, zero_base_indent: bool = False) -> str:
        """
        Returns the entire source code of the function.

        :param zero_base_indent: If true, the leading indentation is removed
        """
        raise NotImplementedError

    # @frozen_method_cache
    def get_prefix_comment_str(self, as_plain_text: bool = False) -> str | None:
        """
        Returns the prefix comment (a docstring in Python).

        :param as_plain_text: By default the language specific annotations
            (such as the triple quotes in Python) are kept. If true, the
            dedented plain text is returned instead.
        """
        raise NotImplementedError


@dataclass(slots=True, frozen=True)
class PrefixComment:
    doc_string_node: Node
    function_node: Node

    @property
    def doc_string_str(self) -> str:
        return self.doc_string_node.text.decode()
