import re
import textwrap
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path

import libcst as cst
from tree_sitter import Language, Node, Tree
#from tree_sitter_languages import get_language, get_parser
from tree_sitter_language_pack import get_language, get_parser

from synthegrator.lang_specs.lang_spec import (
    LangSpec,
    LsFunctionDef,
    PrefixComment,
    PythonFunctionParser,
)
from synthegrator.lang_specs.lang_util import truncate_when_decreases_indent
from synthegrator.sandboxing import PY_DEFAULT_DOCKER_ENV, DockerExecutionContext
from synthegrator.util import lazy_prop


def assert_spec_is_python(spec: LangSpec):
    """
    Guard that ``spec`` is a :class:`PythonLangSpec`.

    :raises TypeError: if ``spec`` is any other kind of language spec.
    """
    if isinstance(spec, PythonLangSpec):
        return
    msg = "Only python language supported"
    raise TypeError(msg)


class PythonLangSpec(LangSpec):
    """Language spec for Python, backed by the tree-sitter ``python`` grammar."""

    @staticmethod
    def get_tree_sitter_lang_name() -> str:
        """Name of the tree-sitter grammar used for this language."""
        return "python"

    def truncate_function_completion(self, completion: str) -> str:
        """
        Trim a generated function completion.

        Delegates to ``truncate_when_decreases_indent`` (presumably this cuts the
        completion once indentation drops back out of the function body).
        """
        return truncate_when_decreases_indent(completion)

    def get_method_name(self, method_src) -> str:
        """
        Return the name of the single function defined in ``method_src``.

        Both ``def`` and ``async def`` functions are accepted. Decorators are
        allowed because they belong to the same top-level statement.

        :raises SyntaxError: propagated from ``ast.parse`` for invalid source.
        :raises ValueError: if the source does not contain exactly one top-level statement.
        :raises TypeError: if that statement is not a function definition.
        """
        import ast

        tree = ast.parse(method_src)
        if not isinstance(tree, ast.Module):
            msg = "Tree must be of ast.Module type"
            raise TypeError(msg)
        if not tree.body:
            msg = "Method source contains no statements"
            raise ValueError(msg)
        if len(tree.body) > 1:
            msg = "More than one tree body"
            raise ValueError(msg)
        func_def = tree.body[0]
        # ``async def`` parses to AsyncFunctionDef, which is not a FunctionDef subclass.
        if not isinstance(func_def, (ast.FunctionDef, ast.AsyncFunctionDef)):
            msg = "Function definition must be of ast.FunctionDef or ast.AsyncFunctionDef type"
            raise TypeError(msg)
        return func_def.name

    def get_lang_md_name(self) -> str:
        """Language tag used for markdown code fences."""
        return "python"

    @staticmethod
    def split_before_last_method_name(code: str) -> tuple[str, str]:
        """
        Split ``code`` just before the last line that looks like a method declaration.

        A declaration line is one that starts with ``def`` at column 0 and ends
        with ``:``. Indented methods, ``async def`` lines and lines with trailing
        whitespace after the colon are not recognised.

        :returns: ``(before, method)``. ``before`` is everything preceding the
            declaration line, with a trailing newline when it is non-empty.
            ``method`` is the declaration line and everything after it.
        :raises RuntimeError: if no declaration line is found.
        """
        # A method declaration is a line that starts with "def" and ends with ":".
        lines = code.split("\n")
        for i, line in enumerate(reversed(lines)):
            if line.startswith("def") and line.endswith(":"):
                before = "\n".join(lines[: -i - 1])
                if len(lines[: -i - 1]) > 0:
                    before += "\n"
                method = "\n".join(lines[-i - 1 :])
                return before, method
        msg = "No method declaration found"
        raise RuntimeError(msg)

    @staticmethod
    def get_signature_text(method: str) -> str:
        """
        Return the text of ``method`` that precedes its body, stripped.

        This includes any decorators, the ``def`` line and the trailing ``:``.

        :raises TypeError: if the first statement is not a function definition.
        """
        module = cst.parse_module(method)
        func_def = module.body[0]
        if not isinstance(func_def, cst.FunctionDef):
            msg = "Function definition not of cst.FunctionDef type"
            raise TypeError(msg)

        body_code = module.code_for_node(func_def.body)
        # return everything before the body
        return method[: method.index(body_code)].strip()

    @staticmethod
    def get_default_docker_env() -> DockerExecutionContext:
        """Default sandbox environment for executing Python code."""
        return PY_DEFAULT_DOCKER_ENV

    @classmethod
    def find_functions(
        cls,
        src: str | Tree,
        include_class_functions: bool = True,
        include_nested_functions: bool = True,
    ) -> Iterable["LsFunctionDefPython"]:
        """
        Finds all functions in the source code.

        :param src: Source text, or an already-parsed tree-sitter tree.
        :param include_class_functions: If true, will include functions defined inside classes.
        :param include_nested_functions: If true, will include functions defined inside other functions.
        :raises NotImplementedError: if either include flag is False. Because this
            is a generator, the error surfaces when iteration starts, not at call time.
        """
        if not include_nested_functions:
            raise NotImplementedError
        if not include_class_functions:
            raise NotImplementedError
        tree = cls.get_tree_sitter_tree(src) if isinstance(src, str) else src
        captures = _py_function_def_query().captures(tree.root_node)
        for node in captures.get("function", []):
            yield LsFunctionDefPython(node=node, tree=tree)

    @classmethod
    def get_prefix_comments(cls, src: str | Tree) -> Iterable[PrefixComment]:
        """
        Yield a :class:`PrefixComment` for each function whose body starts with a string.

        :param src: Source text, or an already-parsed tree-sitter tree.
        """
        tree = cls.get_tree_sitter_tree(src) if isinstance(src, str) else src
        captures = _py_prefix_comments_query().captures(tree.root_node)
        for captured_node in captures.get("docstring", []):
            yield PrefixComment(
                doc_string_node=captured_node,
                # string -> expression_statement -> block -> function_definition
                function_node=captured_node.parent.parent.parent,
            )

    @staticmethod
    def get_tree_sitter_lang() -> Language:
        """The shared tree-sitter Python ``Language`` object."""
        return PY_LANG_TS


@lazy_prop
def _py_prefix_comments_query():
    """
    Tree-sitter query that captures, as ``@docstring``, a string expression
    that is the first statement inside a function's body block. The ``.``
    anchor restricts the match to the first child.
    """
    query_source = """
        (function_definition
          (block .
            (expression_statement
              (string) @docstring)))
        """
    return PY_LANG_TS.query(query_source)


@lazy_prop
def _py_function_def_query():
    """Tree-sitter query that captures every ``function_definition`` node as ``@function``."""
    query_source = """
        (function_definition) @function
        """
    return PY_LANG_TS.query(query_source)


@lazy_prop
def _py_docstring_query_in_func():
    """
    Tree-sitter query that captures the string of an expression statement as
    ``@docstring``, filtered by a ``#not-has-children?`` predicate.

    NOTE(review): ``#not-has-children?`` does not appear to be one of
    tree-sitter's built-in predicates, so whether it filters anything depends on
    how the installed bindings treat unknown predicates. Newer grammars also
    give ``string`` child nodes. Confirm this query still matches as intended.
    It is not referenced elsewhere in this file.
    """
    query_source = """
        (expression_statement
          (string) @docstring
          (#not-has-children? @docstring)
        )
        """
    return PY_LANG_TS.query(query_source)


@lazy_prop
def _docstring_pattern():
    """
    Compiled regex that matches, at the very start of a string, optional
    whitespace followed by a literal delimited by three double quotes or three
    single quotes.

    The match is non-greedy, so it stops at the first closing delimiter.
    Group 1 is the quoted literal including its delimiters. String prefixes
    such as ``r`` or ``u`` are not matched.
    """
    pattern = re.compile(r'^\s*("""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\')')
    return pattern



class PythonModuleParser:
    """
    Holds a Python module's source and its tree-sitter parse.

    NOTE(review): the per-function attributes initialised below
    (``function_without_doc``, ``function_definition_with_doc``, ``body_span_*``,
    ``has_docstring``, ...) are never populated in this file. They look carried
    over from a function-level parser, so confirm whether a subclass or caller
    is expected to fill them.
    """

    def __init__(self, module_source: str) -> None:
        self.language = get_language("python")
        self.parser = get_parser("python")
        self.module_source = module_source
        self.function_without_doc = None
        self.function_definition_with_doc = None
        self.function_definition_without_doc = None
        self.body_span_begin = -1
        self.body_span_end = -1
        self.has_docstring = False
        # Set by parse_source(); initialised first so the attribute exists even
        # if a subclass overrides parse_source without assigning it.
        self.tree = None
        self.parse_source()

    def parse_source(self) -> None:
        """
        Parse ``module_source`` into ``self.tree`` with the tree-sitter parser.

        ``__init__`` calls this hook, but it was previously never defined, so
        constructing the class always raised AttributeError. Subclasses may
        override it to extract more information.
        """
        self.tree = self.parser.parse(self.module_source.encode("utf-8"))


# Tree-sitter Language object for Python, created once when this module is
# imported. The query builders and PythonLangSpec.get_tree_sitter_lang() above
# reference it only inside their function bodies, so it only has to exist by
# the time they are first called. That is why it can be defined after them.
PY_LANG_TS: Language = get_language("python")


def _print_function_stub(func_src: str) -> None:
    """Print a function's declaration (with its docstring) followed by a ``...`` placeholder body."""
    func_parse = PythonFunctionParser(func_src)
    # function_definition_with_doc is treated as optional: print an empty
    # declaration rather than failing with ``None + str``.
    print((func_parse.function_definition_with_doc or "") + "\n    ...")


def _main():
    """
    Smoke test: print a stub outline of this very file.

    For each top-level function and each class method, it prints the
    declaration plus docstring. For each class, it also prints the header and
    any comments directly in the class body.
    """
    langspec_py = PythonLangSpec()
    source_code = Path(__file__).read_text(encoding="utf-8")
    tree = langspec_py.get_tree_sitter_tree(source_code)
    root_node = tree.root_node

    if root_node.type != "module":
        msg = "Module not provided."
        raise RuntimeError(msg)

    function_types = {"decorated_definition", "function_definition"}
    for node in root_node.children:
        if node.type in function_types:
            _print_function_stub(node.text.decode())
        if node.type == "class_definition":
            # Tree-sitter offsets are byte offsets. Slice the node's own bytes
            # rather than the decoded str, which is wrong for non-ASCII source.
            # The header runs from the ``class`` keyword through the ``:`` (second-to-last child).
            header_start = node.children[0].start_byte - node.start_byte
            header_end = node.children[-2].end_byte - node.start_byte
            print(node.text[header_start:header_end].decode())
            # The last child is the class body block.
            for cnode in node.children[-1].children:
                if cnode.type in function_types:
                    _print_function_stub(cnode.text.decode())
                if cnode.type == "comment":
                    print(cnode.text.decode())


@dataclass(frozen=True)
class LsFunctionDefPython(LsFunctionDef):
    """
    A wrapper around a tree-sitter ``function_definition`` node (``self.node``,
    inside ``self.tree``) with useful functions common between languages.

    Note: this is Python specific. When language specs are expanded, other
    languages will need subclasses following this API.

    For reference, here is the treesitter grammar for function definitions:
    https://github.com/tree-sitter/tree-sitter-python/blob/c01fb4e/src/grammar.json#L1582-L1671
    """

    def get_body_node(self) -> Node:
        """Returns the node that corresponds to the body (the ``block``) of the function"""
        # Prefer the grammar's named ``body`` field. Fall back to the last child,
        # where the body sits in the grammar, if the field is missing (e.g. on an
        # error-recovered node).
        body = self.node.child_by_field_name("body")
        if body is not None:
            return body
        return self.node.children[-1]

    # @frozen_method_cache
    def get_body_src(
        self,
        include_body_indent: bool = True,
        include_trailing_new_line: bool = True,
        include_prefix_comment: bool = True,
    ) -> str:
        """
        Source of the function body, by default including the docstring.

        The body node starts at its first statement, so the whitespace that
        indents that statement is not part of the body node.
        ``include_body_indent`` controls whether that whitespace is kept.

        :param include_body_indent: If True, keep the whitespace preceding the
            first body statement.
        :param include_trailing_new_line: If True, append a newline to the result.
        :param include_prefix_comment: If False, strip a leading triple-quoted
            docstring (plus any whitespace before it and the blank newlines
            directly after it) from the body.
        """
        # A declaration that excludes the body indent ends right after ":\n", so
        # slicing it off leaves the indent at the front of the body (and vice versa).
        declaration_src = self.get_declaration_str(
            include_body_indent=not include_body_indent,
        )
        src = self.get_full_function_src()[len(declaration_src) :]

        if not include_prefix_comment:
            match = _docstring_pattern().match(src)
            if match:
                # The match also covers whitespace before the docstring. Only
                # "\n" is stripped afterwards, so the next line keeps its indent.
                src = src[match.end() :].lstrip("\n")

        if include_trailing_new_line:
            src += "\n"
        return src

    def get_preceding_src_code(self):
        """
        Returns the source code before the function.

        NOTE(review): this slices the root node's text by the function's
        absolute byte offset. That assumes the root node starts at byte 0 of the
        document, which should be confirmed for sources with leading whitespace.
        """
        return self.tree.root_node.text[: self.node.start_byte].decode()

    def get_src_code_after(self):
        """Returns the source code after the function (same root-node slicing as ``get_preceding_src_code``)."""
        return self.tree.root_node.text[self.node.end_byte :].decode()

    # @frozen_method_cache
    def get_function_name(self) -> str:
        """Returns the name identifier of the function"""
        name_node = self.node.child_by_field_name("name")
        if name_node is not None:
            return name_node.text.decode()
        # Positional fallback: children are ``[async] def name ...``
        has_async = self.node.children[0].type == "async"
        return self.node.children[2 if has_async else 1].text.decode()

    # @frozen_method_cache
    def get_declaration_str(self, include_body_indent: bool = False) -> str:
        r"""
        Returns the full declaration of the function. That is the ``def``,
        name, parameters, return annotation and ``:``, taken from the start of
        the function up to its body.

        :param include_body_indent: If True, the result keeps the whitespace that
            precedes the first body statement. For example "def foo():\n    pass"
            returns "def foo():\n    ". If False it returns "def foo():\n", or
            "def foo():" when no newline follows the colon (a one-line function).
        """
        body_node = self.get_body_node()
        start_byte = self.node.start_byte
        end_byte = body_node.start_byte
        text = self.tree.root_node.text[start_byte:end_byte].decode()
        if include_body_indent:
            return text
        # The last colon before the body ends the signature. Colons in
        # annotations come earlier, so rfind skips them.
        last_colon_idx = text.rfind(":")
        newline_idx = text.find("\n", last_colon_idx)
        if newline_idx != -1:
            # Keep everything through the first newline, so trailing spaces or a
            # "\r\n" line ending after the colon do not leak into the body.
            return text[: newline_idx + 1]
        return text[: last_colon_idx + 1]

    # @frozen_method_cache
    def get_full_function_src(self, zero_base_indent: bool = False) -> str:
        """Returns the entire source code of the function, optionally dedented to zero base indentation"""
        src = self.node.text.decode()
        if not zero_base_indent:
            return src
        return textwrap.dedent(src)

    # @frozen_method_cache
    def get_prefix_comment_str(self, as_plain_text: bool = False) -> str | None:
        """
        Gets the prefix comment (a docstring in Python), or None if there is none.

        Only a triple-quoted string at the start of the body is recognised.
        String prefixes such as ``r`` are not matched.

        :param as_plain_text: Normally the result includes the language-specific
            annotations (such as the triple quotes in Python). If true, returns
            the dedented plain text instead.

        Body spec (a bit tricky, so we will just do it with string regex):
        https://github.com/tree-sitter/tree-sitter-python/blob/c01fb4e3/src/grammar.json#L2178-L2213
        """
        body_src = self.get_body_src(include_trailing_new_line=False)
        if not body_src:
            return None

        match = _docstring_pattern().match(body_src)
        if match is None:
            return None
        docstring = match.group(1)
        if as_plain_text:
            # remove the three-character quote delimiters on each side, then dedent
            return textwrap.dedent(docstring[3:-3])
        return docstring


# Running this module directly executes the _main() smoke test, which prints an
# outline of this file's own functions and classes.
if __name__ == "__main__":
    _main()
