from __future__ import annotations

import re
from typing import Any, Self, override
from collections.abc import Callable, Hashable

import numpy as np
import numpy.typing as npt
import pandas as pd
import zarr
import zarr.abc.store
import zarr.storage
from pandas.core.generic import NDFrame
from pandas.core.internals.array_manager import BaseArrayManager
from zarr.core.common import AccessModeLiteral
from zarr.storage import StoreLike

import ezarr
from ezarr.dataframe.manager import EZArrayManager
from ezarr.types import SupportsEZReadWrite


class EZDataFrame(pd.DataFrame, SupportsEZReadWrite):  # EZObject[pd.DataFrame],
    """
    Proxy for pandas DataFrames whose data is stored in zarr stores.

    The zarr group associated with a frame is kept on the instance as
    `_data_store`.
    """

    # Register `_data_store` next to the attribute names pandas itself declares as internal.
    # NOTE(review): presumably this makes pandas treat `_data_store` as a plain instance
    # attribute instead of a column — confirm against pandas' subclassing documentation.
    _internal_names: list[str] = ["_data_store"] + pd.DataFrame._internal_names
    _internal_names_set: set[str] = {"_data_store"} | pd.DataFrame._internal_names_set

    # zarr group backing this frame; assigned in __init__.
    _data_store: zarr.Group

    @property
    @override
    def _constructor(self) -> Callable[..., EZDataFrame]:
        def inner(df: Any) -> EZDataFrame:
            if not isinstance(df, EZDataFrame):
                raise ValueError("EZDataFrame constructor not properly called")

            return df

        return inner

    @override
    def _constructor_from_mgr(self, mgr: BaseArrayManager, axes: list[pd.Index]) -> pd.DataFrame:
        if not isinstance(mgr, EZArrayManager):
            df: pd.DataFrame = pd.DataFrame._from_mgr(mgr, axes=axes)  # pyright: ignore[reportAttributeAccessIssue]

        else:
            df = EZDataFrame._from_mgr(mgr, axes=axes)  # pyright: ignore[reportAttributeAccessIssue]

        if isinstance(self, pd.DataFrame):
            # This would also work `if self._constructor is DataFrame`, but
            #  this check is slightly faster, benefiting the most-common case.
            return df

        elif type(self).__name__ == "GeoDataFrame":
            # Shim until geopandas can override their _constructor_from_mgr
            #  bc they have different behavior for Managers than for DataFrames
            return self._constructor(mgr)

        # We assume that the subclass __init__ knows how to handle a
        #  pd.DataFrame object.
        return self._constructor(df)

    def _constructor_sliced_from_mgr(self, mgr: BaseArrayManager, axes: list[pd.Index[Any]]) -> pd.Series[Any]:
        ser: pd.Series = pd.Series._from_mgr(mgr, axes)  # pyright: ignore[reportAttributeAccessIssue]
        ser._name = None  # caller is responsible for setting real name

        if isinstance(self, pd.DataFrame):
            return ser

        return self._constructor_sliced(ser)

    def __init__(
        self,
        data: pd.DataFrame
        | dict[Hashable, npt.NDArray[Any]]
        | ezarr.EZDict[Any]
        | npt.NDArray[Any]
        | zarr.Array
        | tuple[EZArrayManager, zarr.Group]
        | None = None,
        index: zarr.Array | pd.Index | None = None,
        columns: zarr.Array | pd.Index | None = None,
    ):
        match data:
            case ezarr.EZDict():
                assert index is None, (
                    "Cannot provide the `index` parameter when also providing the `data` parameter as EZDict."
                )
                assert columns is None, (
                    "Cannot provide the `columns` parameter when also providing the `data` parameter as EZDict."
                )

                _index: pd.Index = pd.Index(data["index"].copy())
                _columns: pd.Index = pd.Index(data["arrays"].keys())
                arrays: list[zarr.Array] = [arr for arr in data["arrays"].values()]

                store = data.group
                mgr = EZArrayManager(store, arrays, [_index, _columns])

            case pd.DataFrame():
                _index = data.index if index is None else pd.Index(index)
                _columns = data.columns if columns is None else pd.Index(columns)

                store = zarr.create_group({})
                arrays = [
                    store.create_array(str(name), data=np.array(arr))
                    for name, arr in data.to_dict(orient="list").items()
                ]

                mgr = EZArrayManager(store, arrays, [_index, _columns])

            case zarr.Array(_async_array):
                _index = pd.RangeIndex(start=0, stop=data.shape[0]) if index is None else pd.Index(index)
                _columns = pd.RangeIndex(start=0, stop=data.shape[1]) if columns is None else pd.Index(columns)
                arrays = [col for col in data.T]

                store = zarr.open_group(_async_array.store)
                mgr = EZArrayManager(store, arrays, [_index, _columns])

            case np.ndarray():
                _index = pd.RangeIndex(start=0, stop=data.shape[0]) if index is None else pd.Index(index)
                _columns = pd.RangeIndex(start=0, stop=data.shape[1]) if columns is None else pd.Index(columns)
                arrays = [col for col in data.T]

                store = zarr.create_group({})
                mgr = EZArrayManager(store, arrays, [_index, _columns])

            case dict():
                assert columns is None, (
                    "Cannot provide the `columns` parameter when also providing the `data` parameter as dict."
                )

                _columns = pd.Index(data.keys())
                arrays = [zarr.create_array(name, data=np.asarray(arr)) for name, arr in data.items()]
                _index = pd.RangeIndex(start=0, stop=arrays[0].shape[0]) if index is None else pd.Index(index)

                store = zarr.create_group({})
                mgr = EZArrayManager(store, arrays, [_index, _columns])

            case None:
                _index = pd.RangeIndex(0) if index is None else pd.Index(index)
                _columns = pd.RangeIndex(0) if columns is None else pd.Index(columns)
                arrays = []

                store = zarr.create_group({})
                mgr = EZArrayManager(store, arrays, [_index, _columns])

            case (EZArrayManager() as mgr, zarr.Group() as store):
                pass

            case _:  # pyright: ignore[reportUnnecessaryComparison]
                raise TypeError(f"Invalid type '{type(data)}' for 'data' argument.")  # pyright: ignore[reportUnreachable]

        object.__setattr__(self, "_data_store", store)
        NDFrame.__init__(self, mgr)  # pyright: ignore[reportCallIssue]

    @override
    def __finalize__(self, other: EZDataFrame, method: str | None = None, **kwargs: Any) -> pd.DataFrame:  # pyright: ignore[reportIncompatibleMethodOverride]
        super().__finalize__(other, method, **kwargs)

        if method == "copy":
            return other

        return self

    @staticmethod
    def _get_mode(store: zarr.abc.store.Store) -> str:
        """
        Return a short label describing the kind of zarr store `store` is.

        Wrapping stores that expose their inner store as `.store` are labelled
        with a prefix followed by the label of the inner store. Unrecognized
        stores are labelled with their class name.
        """
        # The order of the checks matters: a class listed earlier wins.
        absent = object()

        if isinstance(store, zarr.storage.FsspecStore):
            return "FSSpec"

        if isinstance(store, zarr.storage.GpuMemoryStore):
            return "GPU"

        if isinstance(store, zarr.storage.LocalStore):
            return "Local"

        if isinstance(store, zarr.storage.LoggingStore):
            wrapped: Any = getattr(store, "store", absent)
            # Without a `.store` attribute, fall through to the remaining checks.
            if wrapped is not absent:
                return f"Logging:{EZDataFrame._get_mode(wrapped)}"

        if isinstance(store, zarr.storage.MemoryStore):
            return "RAM"

        if isinstance(store, zarr.storage.ObjectStore):
            return "Object"

        if isinstance(store, zarr.storage.WrapperStore):
            wrapped = getattr(store, "store", absent)
            if wrapped is not absent:
                return f"Wrapper:{EZDataFrame._get_mode(wrapped)}"

        if isinstance(store, zarr.storage.ZipStore):
            return "Zip"

        return type(store).__name__

    @override
    def __repr__(self) -> str:
        repr_ = repr(self.iloc[:5].copy())
        if self.empty:
            return repr_

        re.sub(r"\n\n\[.*\]$", "", repr_)
        return (
            repr_
            + f"\n[{self._get_mode(self._data_store.store)}]\n[{len(self.index)} rows x {len(self.columns)} columns]"
        )

    @override
    def __setitem__(self, key, value) -> None:
        return super().__setitem__(key, value)

    @override
    def __ez_write__(self, values: ezarr.EZDict[Any]) -> None:
        if values.group.store == self._data_store.store:
            return

        _index = self.index.values
        if _index.dtype == object:
            _index = _index.astype(str)

        values["index"] = _index
        values["arrays"] = {str(k): ezarr.EZList.defer(v) for k, v in self.to_dict(orient="list").items()}
        values["arrays"].attrs["columns_order"] = self.columns.to_list()

    @classmethod
    @override
    def __ez_read__(cls, values: ezarr.EZDict[Any]) -> Self:
        arrays = values["arrays"]
        columns_order = values["arrays"].attrs["columns_order"]

        mgr = EZArrayManager(
            values.group,
            [arr for (_, arr) in sorted(arrays.group.arrays(), key=lambda name_arr: columns_order.index(name_arr[0]))],
            [pd.Index(values["index"][:]), pd.Index(columns_order)],
        )

        return cls((mgr, values.group))

    @property
    def data(self) -> ezarr.EZDict[Any] | None:
        return self._data_store

    @classmethod
    def open(
        cls, store: StoreLike | None = None, *, name: str, mode: AccessModeLiteral = "a", path: str | None = None
    ) -> Self:
        """
        Open this object from a store.

        Args:
            store: Store, path to a directory or name of a zip file.
            name: name for the object, to use inside the store.
            mode: Persistence mode: 'r' means read only (must exist); 'r+' means
                read/write (must exist); 'a' means read/write (create if doesn't
                exist); 'w' means create (overwrite if exists); 'w-' means create
                (fail if exists).
            path: path within the store to open.
        """
        path = f"{path.rstrip('/')}/{name}" if path else name
        return cls.__ez_read__(ezarr.EZDict(zarr.open_group(store, mode=mode, path=path)))

    def save(
        self,
        store: StoreLike,
        *,
        name: str,
        mode: AccessModeLiteral = "a",
        path: str | None = None,
        overwrite: bool = False,
    ) -> None:
        """
        Save this object to a local file system.

        Args:
            store: Store, path to a directory or name of a zip file.
            name: name for the object, to use inside the store.
            mode: Persistence mode: 'r' means read only (must exist); 'r+' means
                read/write (must exist); 'a' means read/write (create if doesn't
                exist); 'w' means create (overwrite if exists); 'w-' means create
                (fail if exists).
            path: path within the store where the object will be saved.
            overwrite: overwrite object if a group with name `name` already exists ? (default: False)
        """
        path = f"{path.rstrip('/')}/{name}" if path else name

        if not overwrite:
            mode = "w-"

        self.__ez_write__(ezarr.EZDict(zarr.open_group(store, mode=mode, path=path)))
