from __future__ import annotations

import collections
import collections.abc
import dataclasses
import functools
import logging
import os
import re
import shlex
import subprocess
from collections.abc import Iterable, Iterator, Mapping, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast

from mkdocstrings import BaseHandler, CollectionError, HandlerOptions

from . import inventory
from .items import DocConstant, DocItem, DocLocation, DocMapping, DocMethod, DocModule, DocType

try:
    from mkdocs.exceptions import PluginError
except ImportError:
    PluginError = SystemExit  # type: ignore[assignment, misc]

# Module logger, named under the "mkdocs.plugins." prefix (presumably so that MkDocs' own
# logging configuration and verbosity apply to it — confirm against MkDocs' logging setup).
log = logging.getLogger(f"mkdocs.plugins.{__name__}")

if TYPE_CHECKING:
    # Generic item type used in `DocView._filter`'s annotations; only the type checker needs it,
    # since annotations in this module are not evaluated at runtime.
    _D = TypeVar("_D", bound=DocItem)


class CrystalCollector(BaseHandler):
    def __init__(
        self,
        crystal_docs_flags: Sequence[str] = (),
        source_locations: Mapping[str, str] | None = None,
    ):
        """Create a "collector", reading docs from `crystal doc` in the current directory.

        Normally this should not be instantiated.

        When using mkdocstrings-crystal within MkDocs, a plugin can access the instance as `config.plugins['mkdocstrings'].get_handler('crystal')`.

        See [Extras](extras.md).

        Args:
            crystal_docs_flags: Extra command-line arguments for `crystal docs`. Each one is
                expanded with `str.format_map` against the Crystal toolchain info
                (e.g. `{crystal_src}`, `{crystal_version}`).
            source_locations: Maps source directories to URL templates, used to link
                documented items to their source code. `None` (the default) means no links.
        """
        if source_locations is None:
            source_locations = {}

        command = [
            "crystal",
            "docs",
            "--format=json",
            "--project-name=",
            "--project-version=",
        ]
        if source_locations:
            command.append("--source-refname=master")
        command += (s.format_map(_crystal_info) for s in crystal_docs_flags)
        log.debug("Running `%s`", " ".join(shlex.quote(arg) for arg in command))

        # The process starts right away; its JSON output is only read (and the process
        # waited for) when `root` is first accessed.
        self._proc = subprocess.Popen(command, stdout=subprocess.PIPE)

        # For an unambiguous prefix match: add a trailing separator, and try the longest
        # (most specific) paths first. Sorting by length rather than by the number of "/"
        # also works where `os.sep` is a backslash (Windows).
        self._source_locations = sorted(
            (
                _SourceDestination(os.path.relpath(src) + os.sep, dest)
                for src, dest in source_locations.items()
            ),
            key=lambda d: len(d.src_path),
            reverse=True,
        )

    @cached_property
    def root(self) -> DocRoot:
        """The top-level namespace, represented as a fake module."""
        try:
            # Leaving the `with` block closes the pipe and waits for the process to exit,
            # so `returncode` is set by the time the `finally` clause runs.
            with self._proc:
                stdout = self._proc.stdout
                assert stdout is not None
                module = inventory.read(stdout)
            module.__class__ = DocRoot
            assert isinstance(module, DocRoot)
            module.source_locations = self._source_locations
            return module
        finally:
            # A failing `crystal docs` is reported in preference to any parsing error or
            # result, since the latter are most likely consequences of the former.
            if self._proc.returncode:
                args = cast("Sequence[str]", self._proc.args)
                cmd = " ".join(shlex.quote(arg) for arg in args)
                raise PluginError(f"Command `{cmd}` exited with status {self._proc.returncode}")

    def collect(self, identifier: str, options: HandlerOptions) -> DocView:
        """[Find][mkdocstrings_handlers.crystal.items.DocItem.lookup] an item by its identifier.

        Raises:
            CollectionError: When an item by that identifier couldn't be found.
        """
        item: DocItem = self.root
        # "::" designates the top-level namespace itself.
        if identifier != "::":
            item = item.lookup(identifier)
        return DocView(item, options)


@dataclasses.dataclass
class _SourceDestination:
    """Links files under a local directory to a URL built from a template."""

    # Directory prefix matched against `DocLocation.filename`; built with a trailing separator.
    src_path: str
    # `str.format`-style URL template, e.g. using `{file}`, `{line}`, `{shard_version}`.
    dest_url: str

    def substitute(self, location: DocLocation) -> str:
        """Expand `dest_url` for `location`.

        Available placeholders: `file` (the file name with `src_path` removed), `line`, any
        attribute of this object (such as `shard_version`), and the Crystal toolchain info
        (`crystal_version`, `crystal_src`).

        Raises:
            PluginError: If the template references a placeholder that can't be resolved.
        """
        data = {"file": location.filename[len(self.src_path) :], "line": location.line}
        try:
            return self.dest_url.format_map(
                collections.ChainMap(data, _DictAccess(self), _crystal_info)  # type: ignore[arg-type]
            )
        except KeyError as e:
            raise PluginError(
                f"The source_locations template {self.dest_url!r} did not resolve correctly: {e}"
            ) from e

    @property
    def shard_version(self) -> str:
        """The `version:` of the nearest `shard.yml` in or above the `src_path` directory."""
        # `src_path` ends with a separator, so `dirname` yields the directory itself.
        return self._shard_version(os.path.dirname(self.src_path))

    @classmethod
    @functools.cache
    def _shard_version(cls, path: str) -> str:
        file_path = _find_above(path, "shard.yml")
        with open(file_path, "rb") as f:
            m = re.search(rb"^version: *(\S+)", f.read(), flags=re.MULTILINE)
        if not m:
            raise PluginError(f"`version:` not found in {file_path!r}")
        # YAML allows the scalar to be quoted (`version: "1.2.3"`); the quotes are not
        # part of the version and would end up in generated URLs.
        return m[1].strip(b"\"'").decode()


def _find_above(path: str, filename: str) -> str:
    """Find `filename` in the directory `path` or in the closest of its ancestors.

    The search goes through the real parent directories, up to the filesystem root. For a
    relative `path`, this includes the current directory and the directories above it.

    Returns:
        The path of the found file, relative to the current directory if `path` was relative.

    Raises:
        PluginError: If neither `path` nor any of its ancestors contains the file.
    """
    directory = os.path.abspath(path)
    while True:
        candidate = os.path.join(directory, filename)
        if os.path.isfile(candidate):
            # Keep relative inputs relative, matching how paths are handled elsewhere.
            return candidate if os.path.isabs(path) else os.path.relpath(candidate)
        parent = os.path.dirname(directory)
        if parent == directory:  # Reached the filesystem root.
            break
        directory = parent
    raise PluginError(f"{filename!r} not found anywhere above {os.path.abspath(path)!r}")


class _CrystalInfo:
    """Facts about the installed Crystal toolchain, queried on first access and then cached."""

    @cached_property
    def crystal_version(self) -> str:
        """Output of `crystal env CRYSTAL_VERSION`, without trailing whitespace."""
        output = subprocess.check_output(
            ["crystal", "env", "CRYSTAL_VERSION"], encoding="ascii"
        )
        return output.rstrip()

    @cached_property
    def crystal_src(self) -> str:
        """Relative path of the first `CRYSTAL_PATH` entry that contains `prelude.cr`.

        Raises:
            PluginError: If no entry contains `prelude.cr`.
        """
        raw = subprocess.check_output(["crystal", "env", "CRYSTAL_PATH"], text=True).rstrip()
        with_prelude = (
            entry
            for entry in raw.split(os.pathsep)
            if os.path.isfile(os.path.join(entry, "prelude.cr"))
        )
        found = next(with_prelude, None)
        if found is None:
            raise PluginError(f"Crystal sources not found anywhere in CRYSTAL_PATH={raw!r}")
        return os.path.relpath(found)


class _DictAccess:
    """Expose an object's attributes through the mapping protocol (e.g. for `str.format_map`)."""

    def __init__(self, obj: Any):
        self.obj = obj

    def __getitem__(self, key: str) -> Any:
        """Return attribute `key` of the wrapped object.

        Raises:
            KeyError: If the attribute is missing, so that `str.format_map` and
                `collections.ChainMap` treat it as a missing key. The original
                `AttributeError` is kept as the cause.
        """
        try:
            return getattr(self.obj, key)
        except AttributeError as e:
            raise KeyError(f"Missing key: {e}") from e


# The Crystal toolchain info (`crystal_version`, `crystal_src`) as a mapping usable in
# `str.format_map` templates; each value is only queried when a template first needs it.
_crystal_info = _DictAccess(obj=_CrystalInfo())


class DocRoot(DocModule):
    """The top-level namespace; additionally links source locations to URLs."""

    # Tried in order; the first entry whose `src_path` is a prefix of the file name is used.
    source_locations: list[_SourceDestination]

    def update_url(self, location: DocLocation) -> DocLocation:
        """Fill in `location.url` from the first matching source location, if any.

        Returns:
            The same `location` object, possibly modified in place.
        """
        filename = location.filename or ""
        matches = (dest for dest in self.source_locations if filename.startswith(dest.src_path))
        chosen = next(matches, None)
        if chosen is not None:
            location.url = chosen.substitute(location)
        return location


class DocView:
    """A proxy over a `DocItem` that applies the handler's display options.

    Attribute access is forwarded to the wrapped item. Non-empty `DocMapping` attributes are
    filtered through the `file_filters` option, and `types` is emptied when `nested_types` is
    disabled.
    """

    def __init__(self, item: DocItem, config: Mapping[str, Any]):
        self.item = item
        self.config = config

    def __getattr__(self, name: str):
        # Special (dunder) names are probed by Python protocols such as `copy`, `pickle`, or
        # `hasattr(obj, "__html__")` during markup escaping: don't forward those. Also never
        # forward our own fields, which would recurse forever on a half-constructed instance
        # (e.g. during copying, before `__dict__` is filled).
        if name in ("item", "config") or (name.startswith("__") and name.endswith("__")):
            raise AttributeError(name)
        try:
            val = getattr(self.item, name)
            if isinstance(val, DocMapping) and val:
                if name == "types" and not self.config["nested_types"]:
                    return DocMapping(())
                return type(self)._filter(
                    self.config["file_filters"], val, type(self)._get_locations
                )
            return val
        except AttributeError as e:
            # Re-raised as a different type so that a failed lookup is not silently treated as
            # an absent attribute by callers that catch AttributeError (presumably templates).
            raise RuntimeError(e) from e

    def walk_types(self) -> Iterator[DocType]:
        """Yield this item's (filtered) types, each followed by all types nested within it.

        NOTE(review): only the top level goes through this view's filters; deeper levels come
        from `DocType.walk_types` directly — confirm this is intended.
        """
        types: DocMapping[DocType] = self.types
        for typ in types:
            yield typ
            yield from typ.walk_types()

    @classmethod
    def _get_locations(cls, obj: DocItem) -> Sequence[str]:
        """Return the file names that `obj` is defined in; constants use their parent's."""
        if isinstance(obj, DocConstant):
            parent = obj.parent
            if not parent:
                return ()
            obj = parent
        if isinstance(obj, DocType):
            return [loc.filename for loc in obj.locations]
        elif isinstance(obj, DocMethod):
            if not obj.location:
                return ()
            return (obj.location.filename,)
        else:
            raise TypeError(obj)

    @classmethod
    def _filter(
        cls,
        filters: Sequence[str] | bool,  # noqa: FBT001
        mapp: DocMapping[_D],
        getter: Callable[[_D], Sequence[str]],
    ) -> DocMapping[_D]:
        """Keep only the items of `mapp` whose file locations pass `filters`.

        Args:
            filters: `True` keeps everything and `False` keeps nothing. Otherwise, a non-empty
                list of regular expressions; a leading `!` makes a pattern exclude instead of
                include, and the last pattern that matches decides.
            mapp: The items to filter.
            getter: Returns the file names of an item, which the patterns are searched in.

        Raises:
            CollectionError: If `filters` is not a non-empty list of valid regex strings.
        """
        if filters is False:
            return DocMapping(())
        if filters is True:
            return mapp
        # A bare string is a Sequence too, but iterating it would yield single characters.
        if (
            isinstance(filters, (str, bytes))
            or not isinstance(filters, Sequence)
            or not filters
            or not all(isinstance(filt, str) for filt in filters)
        ):
            raise CollectionError(
                f"Expected a non-empty list of strings as filters, not {filters!r}"
            )
        for filt in filters:
            pattern = filt[1:] if filt.startswith("!") else filt
            try:
                re.compile(pattern)
            except re.error as e:
                raise CollectionError(
                    f"Invalid regular expression {filt!r} in filters: {e}"
                ) from e

        return DocMapping([item for item in mapp if _apply_filter(filters, getter(item))])


def _apply_filter(
    filters: Iterable[str],
    tags: Sequence[str],
) -> bool:
    """Decide whether `tags` pass the ordered list of `filters`.

    Each filter is a regular expression searched in every tag; a leading `!` marks it as an
    exclusion. All filters are evaluated in order, and the last one that matches some tag
    determines the result. If no filter matches, the result is `False`.
    """
    verdict = False
    for raw in filters:
        negated = raw.startswith("!")
        pattern = raw[1:] if negated else raw
        if any(re.search(pattern, tag) for tag in tags):
            verdict = not negated
    return verdict
