import asyncio
import hashlib
import json
import os
import platform
import re
import subprocess
import sys
import typing
from base64 import urlsafe_b64encode
from collections import defaultdict
from logging import getLogger
from pathlib import Path
from typing import Dict, List, Optional, Union, cast
from urllib.parse import urlparse

from importlib_metadata import Distribution, PackagePath, PathDistribution
from packaging.version import InvalidVersion
from packaging.version import parse as parse_version
from rich.progress import Progress
from typing_extensions import Literal

from coiled.types import CondaPackage, CondaPlaceHolder, PackageInfo
from coiled.utils import COILED_LOCAL_PACKAGE_PREFIX, parse_file_uri, recurse_importable_python_files

# Module-level logger used by the scanning helpers in this file.
logger = getLogger("coiled.package_sync")
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably a cache used elsewhere -- confirm before relying on or removing it.
subdir_datas = {}
# Version of the running interpreter as returned by
# platform.python_version_tuple(), i.e. a tuple of strings.
PYTHON_VERSION = platform.python_version_tuple()


class ResilientDistribution(PathDistribution):
    """PathDistribution variant with a more forgiving way of listing installed files"""

    def _read_files_egginfo_installed(self):
        """
        Yield the entries of installed-files.txt in the same
        CSV-parsable shape as RECORD lines: every entry is expressed
        relative to the site-packages directory, uses POSIX separators
        and is wrapped in double quotes (file names may contain
        literal commas).

        pip writes installed-files.txt when it installs a package;
        other installers may not. When the file exists it is trusted
        as-is. Nothing is yielded when the file is missing or empty.
        """
        listing = self.read_text("installed-files.txt")
        # Entries are recorded relative to the .egg-info directory, which
        # is only reachable through PathDistribution's self._path.
        egg_info_dir = self._path
        if listing and egg_info_dir:
            site_packages = cast(Path, self.locate_file("")).resolve()
            for entry in listing.splitlines():
                target = (egg_info_dir / entry).resolve()
                # os.path.relpath inserts ".." segments where needed, whereas
                # Path.relative_to raises for targets outside site-packages.
                relative = Path(os.path.relpath(target, site_packages))
                yield f'"{relative.as_posix()}"'


def normalize_name(name: str):
    """Return *name* lower-cased with every hyphen turned into an underscore.

    Best-effort canonical form so the same project compares equal whether
    its name came from PyPI or from conda metadata.
    """
    lowered = name.lower()
    return "_".join(lowered.split("-"))


def normalize_version(version: str):
    """Best-effort normalization of a version string shared by conda and pip.

    conda does not require PEP 440, so the input is returned untouched when
    it cannot be normalized safely.
    """
    # Fewer than three dot-separated parts: leave it alone. This keeps
    # non-PEP440 versions like 2023c from being parsed as 2023rc0.
    if version.count(".") < 2:
        return version
    try:
        # e.g. 23.04.00 becomes 23.4.0
        normalized = str(parse_version(version))
    except InvalidVersion:
        # Unparseable input such as 1.7.1dev.rapidsai23.04 is returned as given
        return version
    return normalized


async def scan_conda(prefix: Path, progress: Optional[Progress] = None) -> typing.Dict[str, List[PackageInfo]]:
    """Describe the conda packages recorded under ``prefix / "conda-meta"``.

    Args:
        prefix: Root of the environment to inspect.
        progress: Optional progress display; when given, packages are
            processed through ``progress.track`` as they complete.

    Returns:
        Mapping of ``normalize_name(conda_name or name)`` to the list of
        matching package records (sorted by ``(name, conda_name)``,
        descending). An empty dict when there is no ``conda-meta`` directory.
    """
    conda_meta = prefix / "conda-meta"
    if not conda_meta.is_dir():
        return {}

    conda_packages = []
    for metafile in conda_meta.iterdir():
        if metafile.suffix != ".json":
            continue
        # Close every metadata file promptly instead of leaking the handle,
        # and read bytes so json detects the encoding itself rather than
        # depending on the locale's default text encoding.
        with metafile.open("rb") as fh:
            conda_packages.append(CondaPackage(json.load(fh), prefix=prefix))

    packages: List[PackageInfo] = []
    if progress:
        for task in progress.track(
            asyncio.as_completed([handle_conda_package(pkg) for pkg in conda_packages]),
            description=f"Scanning {len(conda_packages)} conda packages",
            total=len(conda_packages),
        ):
            r = await task
            if r:
                packages.append(r)
    else:
        packages = [pkg for pkg in await asyncio.gather(*[handle_conda_package(pkg) for pkg in conda_packages]) if pkg]

    # it's possible for multiple similar packages to be "installed"
    # eg importlib-metadata & importlib_metadata
    # we have to check later which one is actually being imported
    result: Dict[str, List[PackageInfo]] = defaultdict(list)
    for pkg in sorted(packages, key=lambda pkg: (pkg["name"], pkg["conda_name"]), reverse=True):
        result[normalize_name(pkg["conda_name"] or pkg["name"])].append(pkg)
    return result


async def handle_conda_package(pkg: CondaPackage) -> Optional[PackageInfo]:
    """Turn one conda-meta record into a package record.

    The Python distribution name is taken from the first ``METADATA`` file
    listed in ``pkg.files`` when there is one, otherwise the conda name is
    used.

    Returns:
        The package record, or ``None`` when the package lists a
        ``METADATA`` file whose directory no longer exists.
    """
    # Are there conda packages that install multiple python packages?
    metadata_location = next((pkg.prefix / Path(fp).parent for fp in pkg.files if fp.endswith("METADATA")), None)
    name = pkg.name
    if metadata_location is not None:
        if not metadata_location.exists():
            # a file for this package no longer exists
            # likely pip installed a new version
            # removing the conda installed version
            return None
        # metadata_location already includes pkg.prefix; joining the prefix a
        # second time would produce a wrong path whenever the prefix is relative.
        dist = ResilientDistribution(metadata_location)  # type: ignore
        metadata = dist.metadata
        # Use .get(): indexing a missing key raises KeyError on recent
        # importlib_metadata releases, which would defeat the fallback below.
        declared_name = metadata.get("Name") if metadata is not None else None
        name = declared_name or pkg.name
    return {
        "channel": pkg.channel,
        "path": None,
        "channel_url": pkg.channel_url,
        "source": "conda",
        "conda_name": pkg.name,
        "subdir": pkg.subdir,
        "name": name,
        "version": pkg.version,
        "wheel_target": None,
    }


async def handle_dist(dist: Distribution, locations: List[Path]) -> Optional[Union[PackageInfo, CondaPlaceHolder]]:
    # Sometimes the dist name is blank (seemingly only on Windows?)
    if not dist.name:
        return
    installer = dist.read_text("INSTALLER") or ""
    installer = installer.rstrip()
    # dist._path can sometimes be a zipp.Path or something else
    dist_path = Path(str(dist._path))  # type: ignore
    if installer == "conda":
        return CondaPlaceHolder(name=dist.name, path=dist_path)
    elif dist_path.parent.suffix == ".egg":
        # egg files can be a directory OR a zip file
        # the zipp implementation of Path always uses
        # linux style seperators so we strip them too
        return {
            "name": dist.name,
            "path": dist_path.parent,
            "source": "pip",
            "channel": None,
            "subdir": None,
            "channel_url": None,
            "conda_name": None,
            "version": dist.version,
            "wheel_target": str(dist_path.parent).rstrip(os.sep + "/"),
        }
    else:
        direct_url_metadata = dist.read_text("direct_url.json")
        if direct_url_metadata:
            url_metadata = json.loads(direct_url_metadata)
            url = url_metadata.get("url")
            if not url:
                # no url in this file
                # invalid PEP-610 so don't do anything
                pass
            elif url_metadata.get("vcs_info"):
                # PEP-610 Source is VCS
                vcs_info = url_metadata.get("vcs_info")
                vcs: Literal["git", "hg", "bzr", "svn"] = vcs_info["vcs"]
                commit = vcs_info["commit_id"]
                url = url_metadata["url"]
                pip_url = f"{vcs}+{url}@{commit}"
                return {
                    "name": dist.name,
                    "path": dist_path,
                    "source": "pip",
                    "channel": None,
                    "channel_url": None,
                    "subdir": None,
                    "conda_name": None,
                    "version": dist.version,
                    "wheel_target": pip_url,
                }
            elif str((Path("pypoetry") / "artifacts")) in url_metadata["url"]:
                # if the install source is actually the pre 1.2 poetry cache location
                # they this is actually just normal a pypi
                # and we can ignore direct_url.json
                pass
            elif url_metadata.get("archive_info") is not None:
                # PEP-610 - Source is an archive/wheel, somewhere!
                p = urlparse(url)
                if p.scheme == "file":
                    url = str(parse_file_uri(url))
                return {
                    "name": dist.name,
                    "path": dist_path,
                    "source": "pip",
                    "channel": None,
                    "channel_url": None,
                    "subdir": None,
                    "conda_name": None,
                    "version": dist.version,
                    "wheel_target": url,
                }
            elif url_metadata.get("dir_info") is not None:
                # PEP-610 - Source is a local directory
                path = parse_file_uri(url)
                return {
                    "name": dist.name,
                    "path": path,
                    "source": "pip",
                    "channel": None,
                    "channel_url": None,
                    "subdir": None,
                    "conda_name": None,
                    "version": dist.version,
                    "wheel_target": str(path),
                }
        egg_links = []
        for location in locations:
            egg_link_pth = location / Path(dist.name).with_suffix(".egg-link")
            if egg_link_pth.is_file():
                egg_links.append(location / Path(dist.name).with_suffix(".egg-link"))
        if egg_links:
            return {
                "name": dist.name,
                "path": dist_path.parent,
                "source": "pip",
                "channel": None,
                "channel_url": None,
                "subdir": None,
                "conda_name": None,
                "version": dist.version,
                "wheel_target": str(dist_path.parent),
            }
        return {
            "name": dist.name,
            "path": dist_path,
            "source": "pip",
            "channel": None,
            "channel_url": None,
            "subdir": None,
            "conda_name": None,
            "version": dist.version,
            "wheel_target": None,
        }


def _is_hash_match(dist: Distribution, pkg_paths: Dict[str, PackagePath], path: str):
    """Tell whether the file on disk for *path* still has its recorded hash.

    *path* is looked up in *pkg_paths*; the entry is resolved relative to the
    parent directory of the distribution's metadata directory, hashed with
    the recorded algorithm and compared (urlsafe base64, padding stripped)
    with the recorded value. Returns ``False`` when there is no entry, no
    recorded hash, or no such file on disk.
    """
    install_root = Path(str(dist._path)).parent  # type: ignore
    entry = pkg_paths.get(path)
    if entry is None or entry.hash is None:  # type: ignore
        return False
    recorded = entry.hash  # type: ignore
    hasher = getattr(hashlib, recorded.mode)
    target = install_root / entry
    if not (target.exists() and target.is_file()):
        return False
    with target.open("rb") as fh:
        digest = hasher(fh.read()).digest()
    encoded = urlsafe_b64encode(digest).strip(b"=").decode()
    return encoded == recorded.value


def _expand_search_paths(locations: List[Path]) -> List[str]:
    """Return *locations* plus every existing path named in their .pth/.egg-link files."""
    paths: List[str] = [str(location) for location in locations]
    for location in locations:
        for fp in location.iterdir():
            if fp.suffix not in (".pth", ".egg-link"):
                continue
            for line in fp.read_text().split("\n"):
                entry = line.rstrip()
                # Blank lines carry no path; joining them would just yield the
                # location itself again (possibly under its resolved name).
                if not entry or entry == ".":
                    continue
                if line.startswith("#"):
                    continue
                # Only "import" followed by a space or tab marks an executable
                # line; a bare "import" prefix would also swallow real paths
                # such as "important_project/src".
                if line.startswith(("import ", "import\t")):
                    continue
                p = location / Path(entry)
                full_path = str(p.resolve())
                if p.exists() and full_path not in paths:
                    paths.append(full_path)
    return paths


def _hashed_package_files(dist: Distribution) -> Dict[str, PackagePath]:
    """Map ``str(path)`` to the file entries of *dist* that carry a recorded hash.

    Compiled ``.pyc`` files and files that live directly inside the
    distribution's own metadata directory are left out.
    """
    metadata_dir_name = dist._path.name  # type: ignore
    return {
        str(f): f
        for f in (dist.files or [])
        if (
            not f.name.endswith(".pyc")
            and f.parent.name != metadata_dir_name
            and f.hash is not None  # type: ignore
        )
    }


async def scan_pip(
    locations: List[Path], progress: Optional[Progress] = None
) -> typing.Dict[str, Union[PackageInfo, CondaPlaceHolder]]:
    """Discover the Python distributions reachable from *locations*.

    Args:
        locations: Candidate search directories; entries that are not
            existing directories are ignored. Paths named in ``.pth`` and
            ``.egg-link`` files found there are searched as well.
        progress: Optional progress display; when given, distributions are
            processed through ``progress.track`` as they complete.

    Returns:
        Mapping of normalized distribution name to its package record (or
        conda placeholder). When several distributions share a name, the
        recorded file hashes are compared with the files on disk to decide
        which one to keep.
    """
    # distributions returns ALL distributions
    # even ones that are not active
    # this is a trick so we only get the distribution
    # that is last in stack
    locations = [location for location in locations if location.exists() and location.is_dir()]
    paths = _expand_search_paths(locations)
    # can't use ResilientDistribution here properly without monkey patching it
    dists: List[Distribution] = list(Distribution.discover(path=list(paths)))
    packages = []
    if progress:
        for task in progress.track(
            asyncio.as_completed([handle_dist(dist, locations) for dist in dists]),
            total=len(dists),
            description=f"Scanning {len(dists)} python packages",
        ):
            packages.append(await task)
    else:
        packages = await asyncio.gather(*(handle_dist(dist, locations) for dist in dists))

    # Resolve duplicate packages
    pkgs_by_name = {}
    for pkg in packages:
        if not pkg:
            continue
        pkg_name = normalize_name(pkg["name"])
        # For duplicate .dist-info directories, we need to check which
        # version is actually importable
        existing_pkg = pkgs_by_name.get(pkg_name)
        if existing_pkg is None:
            pkgs_by_name[pkg_name] = pkg
            continue

        # Compare hashes to actual files
        new_dist = ResilientDistribution(pkg["path"])
        old_dist = ResilientDistribution(existing_pkg["path"])
        new_dist_path = new_dist._path
        old_dist_path = old_dist._path
        new_is_egg_info = new_dist_path.name.endswith(".egg-info")  # type: ignore
        old_is_egg_info = old_dist_path.name.endswith(".egg-info")  # type: ignore
        new_is_dist_info = new_dist_path.name.endswith(".dist-info")  # type: ignore
        old_is_dist_info = old_dist_path.name.endswith(".dist-info")  # type: ignore
        # The new entry is an .egg-info/.dist-info while the existing one is
        # not of that same kind: keep the existing entry.
        if (new_is_egg_info and not old_is_egg_info) or (new_is_dist_info and not old_is_dist_info):
            continue
        # Only reachable when the new path is neither .egg-info nor
        # .dist-info while the existing one is: take the new entry.
        if (not new_is_egg_info and old_is_egg_info) or (not new_is_dist_info and old_is_dist_info):
            pkgs_by_name[pkg_name] = pkg
            continue

        if new_is_egg_info and old_is_egg_info:
            # This should never happen
            logger.debug("Found two egg-info directories with the same name: %s and %s", new_dist_path, old_dist_path)

        new_pkg_paths = _hashed_package_files(new_dist)
        old_pkg_paths = _hashed_package_files(old_dist)
        old_paths = set(old_pkg_paths.keys())
        new_paths = set(new_pkg_paths.keys())
        same_hashes = {
            path
            for path in old_paths.intersection(new_paths)
            if old_pkg_paths[path].hash.mode == new_pkg_paths[path].hash.mode  # type: ignore
            and old_pkg_paths[path].hash.value == new_pkg_paths[path].hash.value  # type: ignore
        }
        paths_to_check = new_paths.union(old_paths) - same_hashes

        for path in paths_to_check:
            # Since we are only checking files that have different hashes,
            # we can just assume the new version is correct on the first
            # match.
            if _is_hash_match(new_dist, new_pkg_paths, path):
                pkgs_by_name[pkg_name] = pkg
                break
            if not _is_hash_match(old_dist, old_pkg_paths, path):
                logger.debug("Encountered path that does not match either version: %s", path)

    return pkgs_by_name


def _git_origin_url(repo_dir: Path) -> str:
    """Return the https-style URL of the current branch's upstream remote, or "" if unknown."""
    encoding = getattr(sys.stdout, "encoding", None) or "utf-8"
    try:
        branch = subprocess.check_output(
            ["git", "branch", "--show-current"], cwd=repo_dir, encoding=encoding
        ).strip()
        remote = subprocess.check_output(
            ["git", "branch", "--list", branch, "--format=%(upstream:remotename)"],
            cwd=repo_dir,
            encoding=encoding,
        ).strip()
        origin_url = subprocess.check_output(
            ["git", "remote", "get-url", remote], cwd=repo_dir, encoding=encoding
        ).strip()
    except (subprocess.CalledProcessError, OSError):
        # OSError covers a missing or non-executable git binary, which must
        # not abort the whole scan.
        origin_url = ""
    if origin_url.startswith("git@"):
        origin_url = "https://" + origin_url[4:].replace(":", "/")
    if origin_url.endswith(".git"):
        origin_url = origin_url[:-4]
    return origin_url.rstrip("/")


async def scan_prefix(
    prefix: Optional[Path] = None,
    base_prefix: Optional[Path] = None,
    progress: Optional[Progress] = None,
    locations: Optional[List[Path]] = None,
) -> typing.List[PackageInfo]:
    """Scan an environment and return one record per conda, pip and local package.

    Args:
        prefix: Environment root; defaults to the resolved ``sys.prefix``.
        base_prefix: Base interpreter root; defaults to the resolved
            ``sys.base_prefix``.
        progress: Optional progress display handed to the conda and pip scans.
        locations: Import search directories; defaults to the resolved
            entries of ``sys.path`` with the current directory put first
            when it is not already listed.

    Returns:
        Package records sorted by ``name``. Directories in *locations* that
        lie outside both prefixes, are not the path of a scanned package and
        contain importable Python files are added as local packages named
        ``COILED_LOCAL_PACKAGE_PREFIX + <cleaned directory name>``.
    """
    # TODO: private conda channels
    # TODO: detect pre-releases and only set --pre flag for those packages (for conda)

    if not prefix:
        prefix = Path(sys.prefix).resolve()
    if not base_prefix:
        base_prefix = Path(sys.base_prefix).resolve()
    if not locations:
        locations = [Path(p).resolve() for p in sys.path]
        cwd = Path.cwd().resolve()
        if cwd not in locations:
            locations.insert(0, cwd)
    conda_env_future = asyncio.create_task(scan_conda(prefix=prefix, progress=progress))
    # only pass locations to support testing, otherwise we should be using sys.path
    pip_env_future = asyncio.create_task(scan_pip(locations=locations, progress=progress))
    conda_env = await conda_env_future
    pip_env = await pip_env_future
    filtered_conda = {}
    # the pip list is the "truth" of what is imported for python deps
    for name, packages in conda_env.items():
        # if a package exists in the pip list but is not a conda place holder
        # then the conda package wont be imported and should be discarded
        found = False
        # we need to check both the conda name and the pypi name for each conda package
        # that claims to have this pypi name; pip_env is keyed by normalized
        # names, so the pypi names have to be normalized before the lookup
        for possible_name in [name] + [normalize_name(p["name"]) for p in packages]:
            pip_package = pip_env.get(possible_name)
            if not pip_package:
                continue
            found = True
            if isinstance(pip_package, CondaPlaceHolder):
                # find the conda package that actually matches with what is importable;
                # fall back to the first candidate instead of letting next()
                # raise StopIteration inside this coroutine
                wanted_name = normalize_name(pip_package["name"])
                filtered_conda[name] = next(
                    (p for p in packages if normalize_name(p["name"]) == wanted_name),
                    packages[0],
                )
                break
            wanted_version = normalize_version(pip_package["version"])
            same_version = next(
                (p for p in packages if normalize_version(p["version"]) == wanted_version),
                None,
            )
            if same_version is not None:
                # if the versions match, we can fall back to using the conda version
                pip_env.pop(possible_name, None)
                filtered_conda[name] = same_version
                break
        if not found:
            # a non python package and safe to include
            filtered_conda[name] = packages[0]
    # remove conda placeholders
    pip_env = {pkg_name: pkg for pkg_name, pkg in pip_env.items() if not isinstance(pkg, CondaPlaceHolder)}
    results = sorted(
        list(pip_env.values()) + list(filtered_conda.values()),
        key=lambda pkg: pkg["name"],
    )
    # get set of urls for all packages that were installed via pip
    pkg_urls = set()
    for pkg in results:
        url = pkg["wheel_target"]
        # Each scheme needs its own membership test: a bare non-empty string
        # on the left of `or` is always true.
        if url and ("http://" in url or "https://" in url):
            url = re.sub(r"^(git|hg|svn|bzr)\+", "", url)
            url = re.sub(r"@\w+$", "", url)
            pkg_urls.add(url.rstrip("/"))

    # Handle modules that are not installed via pip or conda
    pkg_paths = {pkg["path"].resolve() for pkg in results if pkg["path"]}
    extra_paths = {p for p in locations if prefix not in p.parents and base_prefix not in p.parents} - pkg_paths
    for extra_path in extra_paths:
        if not extra_path.is_dir():
            continue
        if not any(recurse_importable_python_files(extra_path)):
            continue
        # Skip directories that are the same as a package that was installed via pip
        if (extra_path / ".git").exists() and _git_origin_url(extra_path) in pkg_urls:
            continue

        # No flag is passed positionally: the fourth positional argument of
        # re.sub is `count`, so a flag there would silently cap the number of
        # replacements. Unicode matching is already the default for str patterns.
        cleaned_name = re.sub(r"[^\w\d.]+", "_", extra_path.name)
        if cleaned_name.startswith("_"):
            cleaned_name = cleaned_name[1:]
        results.append(
            {
                "name": COILED_LOCAL_PACKAGE_PREFIX + cleaned_name,
                "path": extra_path,
                "source": "pip",
                "version": "0.0.0",
                "channel_url": None,
                "channel": None,
                "subdir": None,
                "conda_name": None,
                "wheel_target": str(extra_path),
            }
        )
    return sorted(results, key=lambda pkg: pkg["name"])
