"""The main purpose of the :py:mod:`ssb_timeseries.fs` module is to allow file based IO that works on both a local file system and Google Cloud Storage."""

from __future__ import annotations

import functools
import glob
import json
import os
import shutil
from _collections_abc import Callable
from pathlib import Path

import narwhals
import pyarrow
import pyarrow.dataset
import pyarrow.parquet as pq
import tomli
import tomli_w
from dapla import FileClient
from narwhals.typing import IntoFrame

from ssb_timeseries.dataframes import to_arrow
from ssb_timeseries.types import F
from ssb_timeseries.types import PathStr

# mypy: disable-error-code="arg-type, type-arg, no-any-return, no-untyped-def, import-untyped, attr-defined, type-var, index, return-value"


def path_to_str(path: PathStr) -> PathStr:
    """Return *path* as a normalised string.

    Rendering through ``pathlib.Path`` gives platform-consistent strings (this
    is what makes the automated tests pass on Windows), but pathlib collapses
    the ``gs://`` scheme separator to ``gs:/``; the double slash is restored
    afterwards.
    """
    rendered = os.fspath(Path(path))
    return rendered.replace("gs:/", "gs://")


def wrap_return_as_str(func: F) -> F:
    """Decorate *func* so that its return value is passed through path_to_str().

    The wrapper keeps the wrapped function's name, docstring and signature
    metadata via ``functools.wraps``.
    """

    @functools.wraps(func)
    def normalised(*args, **kwargs):
        return path_to_str(func(*args, **kwargs))

    return normalised


def remove_prefix(path: PathStr) -> str:
    """Strip the GCS scheme from *path*.

    Some os.* functions shorten ``gs://<path>`` to ``gs:/<path>``. Double
    slashes are collapsed first, so both spellings reduce to ``<path>``.
    """
    collapsed = str(path).replace("//", "/")
    without_scheme = collapsed.replace("gs:/", "")
    return without_scheme


def is_gcs(path: PathStr) -> bool:
    """Return True when *path* points to Google Cloud Storage.

    Both ``gs://`` and the pathlib-normalised ``gs:/`` forms are recognised.
    """
    scheme = "gs:/"
    return str(path)[: len(scheme)] == scheme


def is_local(path: PathStr) -> bool:
    """Return True unless *path* starts with the GCS scheme (``gs:/`` or ``gs://``)."""
    scheme = "gs:/"
    return str(path)[: len(scheme)] != scheme


def fs_type(path: PathStr) -> str:
    """Return the filesystem type of *path*.

    Args:
        path: Local path or GCS path.

    Returns:
        ``"gcs"`` when ``is_gcs(path)`` is true, otherwise ``"local"``.
    """
    # is_gcs and is_local are exact complements, so a single check suffices.
    return "gcs" if is_gcs(path) else "local"


def exists(path: PathStr) -> bool:
    """Check if a given (local or GCS) path exists.

    Args:
        path: Local path or GCS path (``gs://...``). Falsy values such as
            ``None`` or ``""`` are reported as non-existent.

    Returns:
        True if the path exists on its filesystem, otherwise False.
    """
    # Check for falsy input before touching pathlib: Path(None) raises TypeError.
    if not path:
        return False

    # OSError: [Errno 36] File name too long
    # New (temporary) dataset names generated by functions can become very long.
    # That error then occurs when metadata maintenance functions check whether
    # files exist, so the final path component is truncated before the lookup.
    MAX_FILENAME_LENGTH = 254
    p = Path(path)
    if len(p.name) > MAX_FILENAME_LENGTH:
        # Rebuilding via pathlib collapses "gs://" to "gs:/"; path_to_str
        # restores the scheme so the GCS client gets a well-formed path.
        path = path_to_str(p.parent / p.name[:MAX_FILENAME_LENGTH])

    if is_gcs(path):
        fs = FileClient.get_gcs_file_system()
        return fs.exists(path)
    return Path(path).exists()


@wrap_return_as_str
def existing_subpath(path: PathStr) -> PathStr:
    """Return the longest leading part of *path* that exists.

    If *path* itself exists it is returned. Otherwise, parent directories are
    walked upwards until one exists. Existence is checked with
    ``pathlib.Path.exists``, i.e. against the local filesystem.
    NOTE(review): GCS paths are not looked up in the bucket here.

    The result is normalised to a string by the decorator.
    """
    if Path(path).exists():
        return str(path)
    candidate = Path(path).parent
    while not candidate.exists():
        candidate = candidate.parent
    return candidate


def touch(path: PathStr) -> PathStr:
    """Touch the file at *path* on the local filesystem or GCS and return *path*.

    Local paths get any missing parent directories created first and are then
    touched with ``Path.touch()``. GCS paths are delegated to the GCS
    filesystem client's ``touch()``.
    NOTE(review): fsspec-style ``touch`` may truncate an existing object by
    default; confirm this matches the local behaviour callers expect.
    """
    if not is_gcs(path):
        mk_parent_dir(path)
        Path(path).touch()
        return path
    gcs = FileClient.get_gcs_file_system()
    gcs.touch(path)
    return path


@wrap_return_as_str
def path(*args: PathStr) -> str:
    """Join *args* into a single path, returned as a string.

    The decorator normalises the result, so GCS paths keep their ``gs://``
    prefix even though pathlib collapses it while joining.
    """
    head, tail = args[0], args[1:]
    return Path(head).joinpath(*tail)


def mkdir(path: PathStr) -> None:
    """Create directory *path*, including missing parents, if it is local.

    For GCS paths nothing is done. Creating an existing local directory is
    not an error.
    """
    # Not good enough in general: it is hard to distinguish between
    # directories and files that do not exist yet (see mk_parent_dir).
    if is_gcs(path):
        return
    Path(path).mkdir(parents=True, exist_ok=True)


def mk_parent_dir(path: PathStr) -> None:
    """Ensure the parent directory of *path* exists, for local paths.

    For GCS paths nothing is done.
    """
    # A single mkdir for both file and directory paths is hard to get right,
    # because non-existent files and directories look the same. Use this for
    # file paths, and mkdir() when the last path component is a directory.
    if is_gcs(path):
        return
    parent = Path(path).parent
    parent.mkdir(parents=True, exist_ok=True)


def file_count(path: PathStr, create: bool = False) -> int:
    """Return the number of entries in *path*, on the local filesystem or GCS.

    Args:
        path: Directory to count entries in.
        create: Passed on to ls(); if True, a missing local directory is created.
    """
    entries = ls(path, create=create)
    return len(entries)


def ls(path: str, pattern: str = "*", create: bool = False) -> list[str]:
    """List entries of *path* matching a glob *pattern*, locally or on GCS.

    Args:
        path: Directory to list.
        pattern: Glob pattern joined onto ``path``.
        create: If True and ``path`` is local, create it first.

    Returns:
        The glob matches.
    """
    search = os.path.join(path, pattern)
    if not is_gcs(path):
        if create:
            mkdir(path)
        return glob.glob(search)
    gcs = FileClient.get_gcs_file_system()
    return gcs.glob(search)


def cp(from_path: PathStr, to_path: PathStr) -> None:
    """Copy a file, choosing the backend from the source and target locations.

    Local-to-local copies use ``shutil.copy2``. Any combination involving GCS
    goes through the GCS filesystem client: ``put`` for local to GCS, ``get``
    for GCS to local, and ``copy`` for GCS to GCS. Missing parent directories
    of a local target are created first.

    Args:
        from_path: The path to the source file.
        to_path: The path to the destination file.
    """
    route = (fs_type(from_path), fs_type(to_path))
    gcs = FileClient.get_gcs_file_system() if "gcs" in route else None
    if is_local(to_path):
        mk_parent_dir(to_path)

    if route == ("local", "local"):
        shutil.copy2(from_path, to_path)
    elif route == ("local", "gcs"):
        gcs.put(from_path, to_path)
    elif route == ("gcs", "local"):
        gcs.get(from_path, to_path)
    else:  # ("gcs", "gcs")
        gcs.copy(from_path, to_path)


def mv(from_path: PathStr, to_path: PathStr) -> None:
    """Move file from one location to another.

    This function handles moving files between local and GCS paths,
    automatically selecting the correct backend.

    Args:
        from_path: The path to the source file.
        to_path: The path to the destination file.
    """
    from_type = fs_type(from_path)
    to_type = fs_type(to_path)

    if is_gcs(from_path) | is_gcs(to_path):
        fs = FileClient.get_gcs_file_system()
    if is_local(to_path):
        mk_parent_dir(to_path)

    match (from_type, to_type):
        case ("local", "local"):
            shutil.move(from_path, to_path)
        case ("local", "gcs"):
            fs.put(from_path, to_path)
        case ("gcs", "local"):
            fs.get(from_path, to_path)
        case ("gcs", "gcs"):
            fs.move(from_path, to_path)


def rm(path: PathStr) -> None:
    """Delete a single file, either local or on GCS.

    Not recursive; use rmtree() to remove directories.

    Args:
        path: The path to the file to be removed.
    """
    if not is_gcs(path):
        os.remove(path)
        return
    gcs = FileClient.get_gcs_file_system()
    gcs.rm(path)


def rmtree(
    path: str,
) -> None:
    """Recursively remove a directory and everything below it.

    Works for both local and GCS paths. Previously the GCS branch silently did
    nothing.

    Args:
        path: Local directory or GCS path (``gs://...``) to remove.
    """
    if is_gcs(path):
        fs = FileClient.get_gcs_file_system()
        # recursive=True removes everything under the prefix, mirroring
        # shutil.rmtree for local directories.
        fs.rm(path, recursive=True)
    else:
        shutil.rmtree(path)


@wrap_return_as_str
def same_path(*args: PathStr) -> PathStr:
    """Return the common leading path of two or more paths.

    The paths must be on the same filesystem, which can be either local or GCS.
    For GCS paths the ``gs:/`` scheme is stripped before comparison, so
    ``gs://bucket/a/x`` and ``gs://bucket/a/y`` give ``/bucket/a`` (without
    the scheme).

    Args:
        *args: Paths as ``str`` or ``pathlib.Path``.

    Raises:
        ValueError: From ``os.path.commonpath``, e.g. when no paths are given
            or absolute and relative paths are mixed.
    """
    # TO DO: add support for Windows style paths?
    # ... regex along the lines of: [A-Z\:|\\\\]
    # Normalise through path_to_str first. Calling .replace() directly on a
    # pathlib.Path would invoke Path.replace (a file rename), not str.replace,
    # and a Path renders "gs://" as "gs:/".
    paths = [path_to_str(a).replace("gs:/", "") for a in args]
    return os.path.commonpath(paths)


def find(
    search_path: PathStr,
    equals: str = "",
    contains: str = "",
    pattern: str = "",
    search_sub_dirs: bool = True,
    full_path: bool = False,
    replace_root: bool = False,
) -> list[str]:
    """Find files and subdirectories with names matching a pattern.

    Works for both local and GCS filesystems. The name filter is chosen in
    this order of precedence: ``contains`` (glob ``*<contains>*``), ``equals``
    (exact name), ``pattern`` (glob), and finally ``*`` (everything).

    Args:
        search_path: Directory to search in (local path or ``gs://...``).
        equals: Match entries whose name equals this string.
        contains: Match entries whose name contains this substring.
        pattern: Glob pattern for entry names, used when neither ``equals``
            nor ``contains`` is given.
        search_sub_dirs: If True, match entries exactly one directory level
            below ``search_path`` (``<search_path>/*/<pattern>``); entries
            directly in ``search_path`` are then not included. If False, match
            entries directly in ``search_path``.
        full_path: If True, return the matched paths; otherwise return only
            the final path component (the entry name) of each match.
        replace_root: If True, replace the leading ``search_path`` part of each
            match with ``"root"`` (after stripping any ``gs:/`` scheme). This
            is only visible in the output when ``full_path`` is True.

    Returns:
        List of matched paths or names, as strings.
    """
    if contains:
        pattern = f"*{contains}*"
    elif equals:
        pattern = equals
    elif not pattern:
        pattern = "*"

    if search_sub_dirs:
        search_str = path(search_path, "*", pattern)
    else:
        search_str = path(search_path, pattern)

    # Note: ``path`` is the module-level join helper (a function), so the
    # filesystem check must look at ``search_path``.
    if is_gcs(search_path):
        fs = FileClient.get_gcs_file_system()
        found = fs.glob(search_str)
    else:
        found = glob.glob(search_str)

    if replace_root:
        # Normalise the root the same way the search string was built, and
        # strip any GCS scheme on both sides so the prefixes line up.
        root = remove_prefix(path(search_path))
        found = [remove_prefix(f).replace(root, "root", 1) for f in found]

    if full_path:
        return found
    return [Path(f).name for f in found]


def read_text(path: PathStr, file_format: str = "") -> dict:
    """Read a JSON or TOML text file from a local or GCS path into a dict.

    Args:
        path: Local path or GCS path (``gs://...``) of the file.
        file_format: ``"json"`` or ``"toml"`` (case-insensitive; a leading dot
            is ignored). If empty, the format is inferred from the file suffix.

    Returns:
        The parsed file content.
    """
    if not file_format:
        file_format = Path(path).suffix
    # Path.suffix includes the leading dot (".json"), while the reader selector
    # matches bare format names ("json"), so strip it.
    read_func = _text_reader(file_format.lstrip("."))
    if is_gcs(path):
        fs = FileClient.get_gcs_file_system()
        with fs.open(path, "r") as file:
            return read_func(file)
    else:
        with open(path) as file:
            return read_func(file)


def write_text(path: PathStr, content: str | dict, file_format: str = "") -> None:
    """Write content as JSON or TOML to a local or GCS path.

    Args:
        path: Local path or GCS path (``gs://...``) of the target file. Missing
            parent directories of a local target are created.
        content: A dict, or a string in the target format, which is parsed
            before being written.
        file_format: ``"json"`` or ``"toml"`` (case-insensitive; a leading dot
            is ignored). If empty, the format is inferred from the file suffix.
    """
    if not file_format:
        file_format = Path(path).suffix
    # Path.suffix includes the leading dot (".json"); the writer selector
    # matches bare format names ("json").
    write = _text_writer(file_format.lstrip("."))

    if is_gcs(path):
        fs = FileClient.get_gcs_file_system()
        with fs.open(path, "w") as file:
            write(file, content)
    else:
        mk_parent_dir(path)
        with open(path, "w") as file:
            write(file, content)


def _text_reader(file_format: str) -> Callable:
    """Helper function for selecting a function for reading files in the specified format."""
    match file_format.lower():
        case "json":

            def reader(file: PathStr) -> dict:
                return json.load(file)

        case "toml":

            def reader(file: PathStr) -> dict:
                return tomli.load(file)

    return reader


def _text_writer(file_format: str) -> Callable[[PathStr, dict], None]:
    """Return a function for writing a dict to text file in specified format.

    Args:
        file_format (str): Format of the text file, either "json" or "toml".

    Raises:
        ValueError: If the file format is not supported.
    """
    match file_format.lower():
        case "json":

            def writer(file: PathStr, content: dict | str) -> None:
                if isinstance(content, str):
                    content = json.loads(content)
                return json.dump(content, file, indent=4, ensure_ascii=False)

        case "toml":

            def writer(file: PathStr, content: dict | str) -> None:
                if isinstance(content, str):
                    content = tomli.loads(content)
                return tomli_w.dump(file, content)

        case _:
            raise ValueError(f"Format {file_format} is not supported.")

    return writer


def read_json(path: PathStr) -> dict:
    """Load and return the JSON document at *path* (local filesystem or GCS)."""
    if is_gcs(path):
        opener = FileClient.get_gcs_file_system().open
    else:
        opener = open
    with opener(path, "r") as handle:
        return json.load(handle)


def write_json(path: PathStr, content: str | dict) -> None:
    """Write content as pretty-printed JSON to a local or GCS path.

    Args:
        path: Local path or GCS path (``gs://...``) of the target file. Missing
            parent directories of a local target are created.
        content: A dict, or a JSON string. A string is parsed first on both
            filesystems, so the file contains the encoded object rather than
            a quoted JSON string literal.

    Raises:
        json.JSONDecodeError: If ``content`` is a string that is not valid JSON.
    """
    # Parse before opening the target, so invalid input never truncates or
    # creates the file, and so local and GCS writes behave the same.
    if isinstance(content, str):
        content = json.loads(content)

    if is_gcs(path):
        fs = FileClient.get_gcs_file_system()
        with fs.open(path, "w") as file:
            json.dump(content, file, indent=4, ensure_ascii=False)
    else:
        mk_parent_dir(path)
        with open(path, "w") as file:
            json.dump(content, file, indent=4, ensure_ascii=False)


# def metadata_from_parquet(filename: PathStr) -> dict:
#     """Read metadata from parquet file."""
#     meta = pq.read_metadata(filename)
#     decoded_schema = base64.b64decode(meta.metadata[b"ARROW:schema"])
#     return pyarrow.ipc.read_schema(pyarrow.BufferReader(decoded_schema))


# def metadata_to_parquet(metadata: dict, filename: PathStr) -> None:
#     """Write metadata to parquet file."""
#     schema = pyarrow.schema(metadata)
#     table = pyarrow.Table.from_pandas(metadata, schema=schema)
#     pq.write_table(table, filename)


def read_parquet(
    path: PathStr,
    lazy: bool = False,  # TODO: Later steps need lazy-proofing before we can switch defaults.
    implementation: str = "pyarrow",
    **kwargs,
) -> narwhals.typing.Frame:
    """Load a Parquet file (local or GCS) as a Narwhals dataframe.

    Args:
        path: Location of the Parquet file.
        lazy: Return a lazy frame (via ``narwhals.scan_parquet``) instead of an
            eager one (via ``narwhals.read_parquet``). Defaults to False.
        implementation: Backend used for reading; defaults to "pyarrow".
        **kwargs: Extra keyword arguments forwarded to the backend.

    Returns:
        A Narwhals dataframe, lazy or eager depending on ``lazy``.
    """
    load = narwhals.scan_parquet if lazy else narwhals.read_parquet
    return load(path, backend=implementation, **kwargs)


def write_parquet(
    data: pyarrow.Table | IntoFrame,
    path: PathStr,
    schema: pyarrow.Schema | None = None,
    **kwargs,
) -> None:
    """Write a dataframe to a Parquet file on the local filesystem or GCS.

    The input is first converted with ``to_arrow(data, schema)``, which is
    relied on for schema validation. If the result is a PyArrow Table, it is
    written with ``pyarrow.parquet.write_table`` using the GCS filesystem
    client or a local PyArrow filesystem. Missing parent directories of a
    local target are created.

    Args:
        data: The dataframe to write (a PyArrow Table or any
            Narwhals-compatible dataframe).
        path: The destination path for the Parquet file.
        schema: Optional PyArrow schema to validate against before writing.
        **kwargs: Extra keyword arguments forwarded to ``pq.write_table``.
    """
    table = to_arrow(data, schema)  # conversion doubles as schema validation
    if not is_gcs(path):
        fs = pyarrow.fs.LocalFileSystem()
        mk_parent_dir(path)
    else:
        fs = FileClient.get_gcs_file_system()

    if not isinstance(table, pyarrow.Table):
        # TODO: figure out how to do schema validation here.
        # NOTE(review): this fallback writes the original ``data`` through
        # Narwhals and does not use ``fs`` or ``kwargs``. Confirm whether
        # to_arrow() can return anything other than a Table.
        narwhals.from_native(data).write_parquet(path)
        return

    pq.write_table(table, where=path, filesystem=fs, **kwargs)
    # To make schema validation work / keep IO pure pyarrow it may be better to
    # go back to pyarrow.dataset.write_dataset(data, path, filesystem=fs,
    # format="parquet", schema=schema, **kwargs), optionally with hive
    # partitioning on "as_of_utc".


# def update_parquet_metadata(
#     path: PathStr,
#     tags: dict | None = None,
#     schema: pyarrow.Schema = None,
# ) -> None:
#     """TODO: Add faster pyarrrow implementations enforcing type based schemas."""
#     table = pq.read_table(path)
#     existing_metadata = table.schema.metadata

#     if schema:
#         table = table.cast(schema)

#     byte_encoded_tags = json.dumps(tags).encode("utf8")
#     merged_metadata = {
#         **existing_metadata,
#         **{"metadata": byte_encoded_tags},
#     }

#     # is this covered by cast(schema) and replace_schema_metadata?
#     # convert_data = table.cast(table.schema)
#     pq.write_table(
#         table,
#         table.replace_schema_metadata(merged_metadata),
#         path,
#     )
