import threading
import io
import pathlib
import tarfile
import typing
from collections import defaultdict
from collections.abc import Iterable, Iterator
from typing import Union

import xxhash


def _defaultdict_recursive():
    """
    Build an auto-vivifying tree: a defaultdict whose missing keys are filled
    with fresh trees of the same kind, so ``t["a"]["b"]["c"] = x`` just works.
    """
    factory = _defaultdict_recursive
    tree = defaultdict(factory)
    return tree


class ProjFsNode:
    """
    Base class for nodes of the in-memory project filesystem tree.

    Every node carries a ``dirty`` flag fixed at construction time. Indexing,
    mutation and sizing must be supplied by subclasses.
    """

    def __init__(self, dirty: bool):
        self._dirty = dirty

    def __getitem__(self, item):
        raise NotImplementedError

    @property
    def dirty(self):
        """The dirty flag supplied at construction (read-only)."""
        return self._dirty

    def set_node(
        self,
        fn: str,
        value: "ProjFsNode",
    ) -> "ProjFsNode":
        """Return a new node with ``value`` placed at path ``fn``; subclasses override."""
        raise NotImplementedError

    def get_hash(self):
        """
        Return a hex xxh64 digest of ``_name`` followed by ``_content``,
        memoized in ``_hash`` after the first call.

        NOTE(review): this base implementation reads ``_name``/``_content``,
        which the base class itself never sets; subclasses are expected to
        provide them or to override this method.
        """
        try:
            return self._hash
        except AttributeError:
            pass
        digest = xxhash.xxh64()
        for chunk in (self._name, self._content):
            digest.update(chunk)
        self._hash = digest.hexdigest()
        return self._hash

    def get_size(self, recursive: bool = False) -> int:
        """
        Get the size of this node in bytes.

        Args:
            recursive: If True and this is a directory, include all children recursively.
                      If False, only count this node's direct content.

        Returns:
            Size in bytes
        """
        raise NotImplementedError

    def __hash__(self):
        # Delegate to the (memoized) content digest.
        return hash(self.get_hash())


class ProjectDir(ProjFsNode):
    """
    Represents a directory. It is a persistent data structure (mutations to
    it return a new instance).

    Children live in ``_contents``, keyed by entry name. Methods that take a
    path accept "/"-separated paths relative to this directory.
    """

    _contents: dict[str, ProjFsNode]

    def __init__(
        self,
        contents: dict[str, ProjFsNode],
        dirty: bool = False,
    ) -> None:
        super().__init__(dirty=dirty)
        self._contents = contents

    @classmethod
    def construct_with_files(
        cls,
        files: Iterable[tuple[str, bytes | str]],
    ) -> "ProjectDir":
        """
        Creates a new (clean) ProjectDir with the given files added. Works with
        nested paths: every "/"-separated prefix becomes a subdirectory.

        If the same file path appears more than once, the last content wins.
        """
        # Nested defaultdicts mirror the directory hierarchy; leaves hold the
        # raw file contents.
        root = _defaultdict_recursive()

        for path, content in files:
            *dir_parts, file_name = path.split("/")
            node = root
            for part in dir_parts:
                node = node[part]
            node[file_name] = content

        def create_dir(tree) -> "ProjectDir":
            # Recursively convert the nested mapping into ProjectDir/ProjectFile nodes.
            contents = {}
            for name, val in tree.items():
                if isinstance(val, dict):
                    contents[name] = create_dir(val)
                else:
                    contents[name] = ProjectFile(name, val)
            return cls(contents)

        return create_dir(root)

    @classmethod
    def construct_with_one_file(
        cls,
        path: str,
        content: bytes | str,
    ) -> "ProjectDir":
        """Constructs a ProjectDir with a single file."""
        return cls.construct_with_files([(path, content)])

    def __getitem__(self, fn: str) -> Union["ProjectDir", "ProjectFile"]:
        """Look up a node by "/"-separated path; raises KeyError if missing."""
        if "/" in fn:
            d, rest = split_fn(fn)
            return self._contents[d][rest]
        return self._contents[fn]

    def __contains__(self, fn: str) -> bool:
        """
        Return True if a node exists at the "/"-separated path ``fn``.

        Returns False (rather than raising) when an intermediate component is
        missing or is a file.
        """
        if "/" not in fn:
            return fn in self._contents
        first, rest = split_fn(fn)
        child = self._contents.get(first)
        if child is None or isinstance(child, ProjectFile):
            # Missing entry, or a file used as a directory component: a file
            # has no children. Delegating to it would end up in
            # ProjectFile.__getitem__, which raises RuntimeError.
            return False
        return rest in child

    def __iter__(self):
        """Iterate over the direct child nodes (not their names)."""
        yield from self._contents.values()

    def set_node(
        self,
        fn: str,
        value: ProjFsNode,
        clean_value_ok: bool = False,
    ) -> "ProjectDir":
        """
        Returns a new ProjectDir with the given file set to the given value.

        Missing intermediate directories along ``fn`` are created as empty,
        clean directories.

        :param fn: "/"-separated path, relative to this directory, of the node
            to set.
        :param value: The node to store at ``fn``.
        :param clean_value_ok: By default attempting to set a node with a value
            marked as clean will switch it to a dirty. If it is a directory
            it will raise an error. Set True if setting a clean value is
            intentional. The flag applies at every nesting level.
        :raises ValueError: If ``value`` is a clean non-file node and
            ``clean_value_ok`` is False, or if ``value`` is a ProjectFile whose
            name differs from the last component of ``fn``.
        :returns: A new ProjectDir. ``self`` is left unmodified. The result is
            dirty if ``self`` was dirty or the updated child is dirty.
        """
        if not value.dirty and not clean_value_ok:
            if isinstance(value, ProjectFile):
                value = ProjectFile(value.name, value.content, dirty=True)
            else:
                msg = (
                    "Attempting to set a node with a clean value. Expected it "
                    "to now be dirty unless clean_value_ok=True"
                )
                raise ValueError(msg)
        new_contents = self._contents.copy()
        first, rest = split_fn(fn)
        if rest:
            # Descend, creating an empty directory for a missing component.
            child = self._contents.get(first)
            if child is None:
                child = ProjectDir({})
            if isinstance(child, ProjectDir):
                new_contents[first] = child.set_node(
                    rest,
                    value,
                    clean_value_ok=clean_value_ok,
                )
            else:
                # Non-directory nodes keep the base-class set_node contract.
                new_contents[first] = child.set_node(rest, value)
        else:
            if isinstance(value, ProjectFile) and value.name != first:
                msg = "Node value name does not match first"
                raise ValueError(msg)
            new_contents[first] = value
        # Keep an existing dirty mark: inserting a clean child must not hide
        # dirty siblings from walk(only_consider_dirty=True).
        return ProjectDir(
            new_contents,
            dirty=self.dirty or new_contents[first].dirty,
        )

    def set_all(
        self,
        new_files: "ProjectDir",
        dirty: bool = True,
    ) -> "ProjectDir":
        """
        Sets all files and directories from another ProjectDir into this one
        recursively, overwriting existing content.

        Directories present on both sides are merged recursively. Any other
        entry from ``new_files`` replaces the existing one. The same ``dirty``
        policy applies at every nesting level.

        :raises ValueError: If ``dirty`` is False and a node being copied in
            (at any depth) is dirty.
        """
        new_contents = self._contents.copy()

        for name, node in new_files._contents.items():
            existing = new_contents.get(name)
            if isinstance(node, ProjectDir) and isinstance(existing, ProjectDir):
                new_contents[name] = existing.set_all(node, dirty=dirty)
            else:
                if not dirty and node.dirty:
                    msg = (
                        "Attempting to set a node with a dirty value. "
                        "Expected it to now be clean unless dirty=True"
                    )
                    raise ValueError(msg)
                new_contents[name] = node

        return ProjectDir(new_contents, dirty=dirty)

    def set_file_contents(
        self,
        path: str,
        content: bytes | str,
        dirty: bool = True,
    ) -> "ProjectDir":
        """
        Returns a new ProjectDir with the given file set to the given value.

        ``dirty`` controls the new file's dirty flag. ``dirty=False`` is
        honoured rather than being overridden by ``set_node``.
        """
        fn = path.split("/")[-1]
        new_file = ProjectFile(fn, content, dirty=dirty)
        return self.set_node(path, new_file, clean_value_ok=not dirty)

    def get_only_file(self, only_consider_dirty: bool = False) -> "ProjectFile":
        """
        Gets the single file in this directory (or descendent files). If
        there is not exactly one file, raises an AssertionError.
        """
        files = list(
            self.walk(include_dirs=False, only_consider_dirty=only_consider_dirty),
        )
        if len(files) != 1:
            msg = f"Expected exactly one file, got {len(files)}"
            raise AssertionError(msg)
        return files[0][1]

    def walk(
        self,
        include_dirs: bool = False,
        only_consider_dirty: bool = False,
    ) -> Iterator[tuple[str, ProjFsNode]]:
        """
        Walks the directory tree depth-first, yielding (path, node) pairs.

        Paths are "/"-joined relative to this directory. If
        ``only_consider_dirty`` is set, clean nodes are skipped, and for clean
        directories so is their whole subtree.
        """
        for name, node in self._contents.items():
            if only_consider_dirty and not node.dirty:
                continue
            if isinstance(node, ProjectDir):
                if include_dirs:
                    yield name, node
                for sub_path, sub_node in node.walk(
                    include_dirs=include_dirs,
                    only_consider_dirty=only_consider_dirty,
                ):
                    yield f"{name}/{sub_path}", sub_node
            else:
                yield name, node

    def pretty_print(self, indent=0, recursive=True, file: typing.IO | None = None):
        """Pretty print the directory contents, one entry per line, indented by depth."""
        for name, node in self._contents.items():
            print(f"{' ' * indent}{name}{' (Dirty)' if node.dirty else ''}", file=file)
            if recursive and isinstance(node, ProjectDir):
                node.pretty_print(indent=indent + 2, recursive=recursive, file=file)

    def pretty_str(self):
        """Return the ``pretty_print`` output as a string."""
        file_str = io.StringIO()
        self.pretty_print(file=file_str)
        return file_str.getvalue()

    def contains_only_leafs(self) -> bool:
        """True if every direct child is a ProjectFile (also True when empty)."""
        return all(isinstance(v, ProjectFile) for v in self._contents.values())

    def _add_to_tar(self, tar, path=""):
        """
        Recursively add every file under this directory to ``tar``.

        Entries are named by their dictionary keys (matching ``walk`` and
        ``__getitem__``). Directories get no explicit tar entries, so empty
        directories are not represented.
        """
        for name, node in self._contents.items():
            if isinstance(node, ProjectFile):
                tarinfo = tarfile.TarInfo(name=path + name)
                tarinfo.size = len(node.content)
                tar.addfile(tarinfo, io.BytesIO(node.content))
            elif isinstance(node, ProjectDir):
                node._add_to_tar(tar, path + name + "/")
            else:
                msg = f"Unexpected type {type(node)}"
                raise TypeError(msg)

    def convert_to_tar(self, gzip: bool = False) -> bytes:
        """Converts the directory to a tarball (gzip-compressed if ``gzip``)."""
        tar_data = io.BytesIO()
        with tarfile.open(
            fileobj=tar_data,
            mode="w:gz" if gzip else "w",
        ) as tar:
            self._add_to_tar(tar)
        return tar_data.getvalue()

    @classmethod
    def from_path(cls, path: str | pathlib.Path) -> "ProjectDir":
        """
        Load a directory tree from disk into a clean ProjectDir.

        Regular files become ProjectFile nodes holding their raw bytes, and
        subdirectories are loaded recursively. Entries that are neither are
        skipped, e.g. broken symlinks. ``is_file``/``is_dir`` follow symlinks,
        so symlinked entries are read through their targets.
        NOTE(review): a symlink cycle would recurse without bound.
        """
        root = pathlib.Path(path)
        dir_content = {}
        for item in root.iterdir():
            if item.is_file():
                dir_content[item.name] = ProjectFile(
                    name=item.name,
                    content=item.read_bytes(),
                )
            elif item.is_dir():
                dir_content[item.name] = cls.from_path(item)
        return cls(contents=dir_content)

    def expand_glob(self, glob: str) -> Iterator[str]:
        """
        Expands a glob pattern to a list of files.

        Only literal (wildcard-free) paths are supported so far. An existing
        path yields itself, and a missing one yields nothing. Any "*" raises
        NotImplementedError.
        """
        if "*" not in glob and glob in self:
            yield glob
            return
        if "**" in glob:
            msg = "Recursive glob not implemented"
            raise NotImplementedError(msg)
        if "*" in glob:
            msg = "Non-recursive glob not implemented"
            raise NotImplementedError(msg)

    def __eq__(self, other):
        # Structural equality on contents; the dirty flag is ignored.
        if not isinstance(other, ProjectDir):
            return False
        return self._contents == other._contents

    def get_hash(self):
        """
        Return a memoized hex xxh64 digest over (name, child hash) pairs in
        sorted name order, so it does not depend on insertion order.
        """
        if (not hasattr(self, "_hash")) or self._hash is None:
            hasher = xxhash.xxh64()
            for name, node in sorted(self._contents.items()):
                hasher.update(name)
                hasher.update(node.get_hash())
            self._hash = hasher.hexdigest()
        return self._hash

    def get_size(self, recursive: bool = False) -> int:
        """
        Get the size of this directory in bytes.

        Args:
            recursive: If True, include all children recursively.
                      If False, return 0 (directories have no direct content).

        Returns:
            Size in bytes
        """
        if not recursive:
            return 0  # Directories themselves have no size
        return sum(node.get_size(recursive=True) for node in self._contents.values())

    def __hash__(self):
        # Must be explicit: defining __eq__ would otherwise set __hash__ to None.
        return hash(self.get_hash())

    def __repr__(self):
        return f"ProjectDir({len(self._contents)} children; {self.get_hash()}{', dirty' if self.dirty else ''})"


class ProjectFile(ProjFsNode):
    """
    Immutable file: a name plus content held as ``bytes``.

    ``str`` content is encoded as UTF-8 on construction.
    """

    def __init__(self, name: str, content: bytes | str, dirty: bool = False) -> None:
        super().__init__(dirty=dirty)
        self._name = name
        if isinstance(content, str):
            self._content = content.encode("utf-8")
        elif isinstance(content, bytes):
            self._content = content
        else:
            msg = f"Unexpected type {type(content)}"
            raise TypeError(msg)

    @property
    def name(self):
        return self._name

    @property
    def content(self) -> bytes:
        return self._content

    @property
    def content_str(self) -> str:
        """The content decoded as UTF-8 (raises UnicodeDecodeError if invalid)."""
        return self._content.decode("utf-8")

    def __getitem__(self, item):
        msg = "Cannot index into a file"
        raise RuntimeError(msg)

    def __eq__(self, other):
        # Equality on name and content; the dirty flag is ignored.
        if not isinstance(other, ProjectFile):
            return False
        return self._name == other._name and self._content == other._content

    def __hash__(self):
        # Defining __eq__ without __hash__ makes Python set __hash__ to None on
        # this class, overriding the base-class __hash__ and making files
        # unhashable. Hash the same name+content that __eq__ compares.
        return hash(self.get_hash())

    def get_hash(self):
        """Return a memoized hex xxh64 digest of the name followed by the content."""
        if (not hasattr(self, "_hash")) or self._hash is None:
            hasher = xxhash.xxh64()
            hasher.update(self._name)
            hasher.update(self._content)
            self._hash = hasher.hexdigest()
        return self._hash

    def get_size(self, recursive: bool = False) -> int:
        """
        Get the size of this file in bytes.

        Args:
            recursive: Ignored for files (files are always leaf nodes).

        Returns:
            Size of file content in bytes
        """
        return len(self._content)

    def __repr__(self):
        return (
            f"ProjectFile({self._name!r}; {len(self._content)} bytes"
            f"{', dirty' if self.dirty else ''})"
        )


def split_fn(fn) -> tuple[str, str]:
    """
    Split a "/"-separated path at its first separator.

    Returns ``(head, tail)``. ``tail`` is "" when ``fn`` contains no "/".
    """
    head, _sep, tail = fn.partition("/")
    return head, tail


def create_project_dir_from_docker_img(
    image_name: str,
    path: str = "/",
    tag: str = "latest",
) -> ProjectDir:
    """
    Extract a specific path from a Docker image and return as ProjectDir.

    Args:
        image_name: Docker image name (e.g., "nginx", "python", "nginx:latest",
            "localhost:5000/app", "app@sha256:...")
        path: Path within the image to extract (default: root "/"). If it names
            a single file, the result is a ProjectDir containing just that file.
        tag: Image tag (default: "latest") - ignored if image_name already
            contains a tag or a digest

    Returns:
        ProjectDir containing the extracted filesystem

    Raises:
        subprocess.CalledProcessError: If docker commands fail
        FileNotFoundError: If the specified path doesn't exist in the image

    Example:
        # Extract entire filesystem
        fs = create_project_dir_from_docker_img("nginx:latest")

        # Extract specific directory
        app = create_project_dir_from_docker_img("my-app", "/app/src", "v1.0")
    """
    import subprocess
    import tempfile
    import warnings

    # Only the last "/"-component of an image reference can carry a ":tag".
    # A ":" earlier on is a registry port (e.g. "localhost:5000/app"), and
    # "@" introduces a digest, which also pins the image.
    last_component = image_name.rsplit("/", 1)[-1]
    if "@" in image_name or ":" in last_component:
        full_image_name = image_name
    else:
        full_image_name = f"{image_name}:{tag}"

    # Create temporary container (doesn't start it)
    # Use a dummy command in case the image doesn't have a default CMD
    try:
        result = subprocess.run(
            ["docker", "create", full_image_name, "sleep", "0.01"],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        raise subprocess.CalledProcessError(
            e.returncode,
            e.cmd,
            output=(
                f"Failed to create container from image {full_image_name}. "
                f"Make sure the image exists and Docker is running.\n{e.stderr}"
            ),
            stderr=e.stderr,
        ) from e

    container_id = result.stdout.strip()

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = pathlib.Path(temp_dir)

            try:
                subprocess.run(
                    ["docker", "cp", f"{container_id}:{path}", str(temp_path)],
                    check=True,
                    capture_output=True,
                )
            except subprocess.CalledProcessError as e:
                # Surface docker's own explanation; cp can fail for reasons
                # other than a missing path.
                detail = e.stderr.decode(errors="replace").strip() if e.stderr else ""
                msg = f"Path '{path}' not found in image {full_image_name}"
                if detail:
                    msg = f"{msg}: {detail}"
                raise FileNotFoundError(msg) from e

            # Container paths are POSIX regardless of the host OS.
            copied_name = pathlib.PurePosixPath(path).name
            if not copied_name:
                # Root ("/"): docker cp container:/ copies the contents, not the
                # root dir itself.
                return ProjectDir.from_path(temp_path)

            # docker cp container:/some/path creates temp_dir/path_basename
            extracted_path = temp_path / copied_name
            if not extracted_path.exists():
                raise FileNotFoundError(
                    f"Expected extracted path {extracted_path} not found"
                )
            if extracted_path.is_file():
                # A single file was requested; wrap it in a one-entry directory.
                return ProjectDir(
                    {copied_name: ProjectFile(copied_name, extracted_path.read_bytes())}
                )
            return ProjectDir.from_path(extracted_path)

    finally:
        # Best-effort cleanup: warn instead of failing, so a cleanup problem
        # never masks the real result or exception.
        try:
            subprocess.run(
                ["docker", "rm", container_id],
                check=True,
                capture_output=True,
            )
        except subprocess.CalledProcessError as e:
            warnings.warn(
                f"Failed to remove temporary container {container_id}: {e}",
                RuntimeWarning,
                stacklevel=2,
            )
