from collections.abc import Iterable
from dataclasses import dataclass
from functools import cache

from tree_sitter import Language, Node, Tree
#from tree_sitter_languages import get_language
from tree_sitter_language_pack import get_language

from synthegrator.lang_specs.lang_spec import LangSpec, LsFunctionDef, PrefixComment
from synthegrator.sandboxing import JAVA_DEFAULT_DOCKER_ENV, DockerExecutionContext

# Java tree-sitter grammar, loaded once at import time. Returned by
# JavaLangSpec.get_tree_sitter_lang and used to compile _java_func_def_query.
JAVA_LANG_TS = get_language("java")


class JavaLangSpec(LangSpec):
    """Language spec for Java, backed by the tree-sitter Java grammar."""

    @staticmethod
    def get_tree_sitter_lang_name() -> str:
        """Name of the tree-sitter grammar used for this language."""
        return "java"

    def truncate_function_completion(self, completion: str) -> str:
        """
        Trim a generated Java method body down to the first complete body.

        The completion is appended to a synthetic class/method prefix and parsed.
        The body of the first method found (expected to be the synthetic
        ``dummy`` method, assuming captures come back in document order) is
        returned. Text after that body's closing ``}`` is therefore not part of
        the result.

        :param completion: Model output intended to continue a method body,
            i.e. the text that follows the method's opening ``{``.
        :return: The truncated body, right-stripped. If no method can be found
            in the parse, ``completion`` is returned unchanged.
        """
        with_dummy = "public class Dummy {\n    public void dummy() {\n" + completion
        func = next(iter(self.find_functions(with_dummy)), None)
        if func is None:
            return completion
        body_src = func.get_body_src()
        if body_src.startswith("{") and not completion.startswith("{"):
            # The opening "{" (and the newline after it) came from the dummy
            # prefix rather than from the completion, so drop them.
            # NOTE(review): when the completion itself starts with "{", the
            # dummy's "{\n" is kept in the result -- confirm this is intended.
            body_src = body_src[1:].removeprefix("\n")
        return body_src.rstrip()

    def get_method_name(self, method_src) -> str:
        """Not implemented for Java."""
        raise NotImplementedError

    def get_lang_md_name(self) -> str:
        """Language identifier, presumably used for markdown code fences."""
        return "java"

    @staticmethod
    def split_before_last_method_name(code: str) -> tuple[str, str]:
        """Not implemented for Java."""
        raise NotImplementedError

    @staticmethod
    def get_signature_text(method: str) -> str:
        """Not implemented for Java."""
        raise NotImplementedError

    @staticmethod
    def get_default_docker_env() -> DockerExecutionContext:
        """Return ``JAVA_DEFAULT_DOCKER_ENV``, the default sandbox for running Java."""
        return JAVA_DEFAULT_DOCKER_ENV

    @classmethod
    def find_functions(
        cls,
        src: str | Tree,
        include_class_functions: bool = True,
        include_nested_functions: bool = True,
    ) -> Iterable["LsFunctionDefJava"]:
        """
        Find all method declarations in the source code.

        The arguments are validated eagerly. An unsupported option therefore
        raises at call time rather than on first iteration, which is what would
        happen if this were a generator function.

        :param src: Java source text, or an already-parsed tree-sitter tree.
        :param include_class_functions: If true, will include functions defined
            inside classes. Only ``True`` is currently supported.
        :param include_nested_functions: If true, will include functions defined
            inside other functions. Only ``True`` is currently supported.
        :return: A lazy iterable of ``LsFunctionDefJava``, one per
            ``method_declaration`` node.
        :raises NotImplementedError: If either include flag is ``False``.
        """
        if not include_nested_functions:
            raise NotImplementedError("Excluding nested functions is not supported for Java")
        if not include_class_functions:
            raise NotImplementedError("Excluding class functions is not supported for Java")
        tree = cls.get_tree_sitter_tree(src) if isinstance(src, str) else src
        captures = _java_func_def_query().captures(tree.root_node)
        return (
            LsFunctionDefJava(node=node, tree=tree)
            for node in captures.get("function", [])
        )

    @classmethod
    def get_prefix_comments(cls, src: str | Tree) -> Iterable[PrefixComment]:
        """Not implemented for Java."""
        raise NotImplementedError

    @staticmethod
    def get_tree_sitter_lang() -> Language:
        """Return the module-level tree-sitter Java grammar."""
        return JAVA_LANG_TS

    @classmethod
    def get_comment_line_start(cls) -> str:
        """Token that starts a single-line Java comment."""
        return "//"


@dataclass(frozen=True)
class LsFunctionDefJava(LsFunctionDef):
    """
    Java implementation of ``LsFunctionDef``.

    Wraps a tree-sitter ``method_declaration`` node (as produced by
    ``JavaLangSpec.find_functions``). Only body extraction is implemented so
    far; the remaining accessors raise ``NotImplementedError``.
    """

    def get_body_node(self) -> Node:
        """
        Return the node that corresponds to the body of the method.

        The grammar's ``body`` field is preferred. When the node has no such
        field (for example a declaration that ends in ``;``), the last child
        node is returned instead, matching the previous behaviour.
        """
        body = self.node.child_by_field_name("body")
        return body if body is not None else self.node.children[-1]

    # @frozen_method_cache
    def get_body_src(
        self,
        include_body_indent: bool = True,
        include_trailing_new_line: bool = True,
    ) -> str:
        """
        Return the source text of the method body.

        The text includes the opening ``{`` of the body.

        These parameters don't really make sense for Java. Need to sort through that.

        :param include_body_indent: Only ``True`` is supported.
        :param include_trailing_new_line: Only ``True`` is supported.
        :raises NotImplementedError: If either flag is ``False``.
        """
        if not include_body_indent:
            raise NotImplementedError
        if not include_trailing_new_line:
            raise NotImplementedError
        return self.get_body_node().text.decode()

    def get_preceding_src_code(self):
        """Returns the source code before the function declaration (not implemented for Java)."""
        raise NotImplementedError

    def get_src_code_after(self):
        """Returns the source code after the function (not implemented for Java)."""
        raise NotImplementedError

    # @frozen_method_cache
    def get_function_name(self) -> str:
        """Returns the name identifier of the method (not implemented for Java)."""
        raise NotImplementedError

    # @frozen_method_cache
    def get_declaration_str(self, include_body_indent: bool = False) -> str:
        r"""
        Returns the full declaration of the method: modifiers, return type,
        name and parameters, without the body. Not implemented for Java.

        :param include_body_indent: If true, the result may carry the body's
            leading indent after the declaration (in Python terms,
            "def foo():\n    pass" would give "def foo():\n    "). If false,
            the indent is omitted ("def foo():\n").
        """
        raise NotImplementedError

    # @frozen_method_cache
    def get_full_function_src(self, zero_base_indent: bool = False) -> str:
        """Returns the entire source code of the method, optionally without leading indentation (not implemented for Java)."""
        raise NotImplementedError

    # @frozen_method_cache
    def get_prefix_comment_str(self, as_plain_text: bool = False) -> str | None:
        """
        Gets the prefix comment (for Java, presumably a preceding Javadoc/comment block).
        Not implemented for Java.

        :param as_plain_text: Normally will include the language specific comment
            markers. If true, will return the dedented plain text.
        """
        raise NotImplementedError


@cache
def _java_func_def_query():
    """
    Return the tree-sitter query that captures every Java ``method_declaration``
    node under the capture name ``function``.

    The query depends only on the module-level grammar, so it is compiled once
    and memoized. Without this, every ``find_functions`` call would recompile it.
    """
    return JAVA_LANG_TS.query(
        """
        (method_declaration) @function
        """
    )
