# Copyright 2022-2023 XProbe Inc.
# derived from copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import operator

import numpy as np
import pandas as pd

from .. import opcodes as OperandDef
from ..core import OutputType
from ..core.operand import MapReduceOperand, OperandStage
from ..serialization.serializables import (
    AnyField,
    BoolField,
    FieldTypes,
    Int32Field,
    KeyField,
    ListField,
)
from .core import SERIES_CHUNK_TYPE
from .operands import DataFrameOperandMixin, DataFrameShuffleProxy
from .utils import (
    build_split_idx_to_origin_idx,
    filter_dtypes,
    filter_index_value,
    hash_dtypes,
    hash_index,
    is_index_value_identical,
    parse_index,
    split_monotonic_index_min_max,
    validate_axis,
)


class DataFrameIndexAlign(MapReduceOperand, DataFrameOperandMixin):
    """Map/reduce operand that re-partitions chunks so two operands line up.

    Map stage (``execute_map``): each axis of the input chunk is handled
    according to ``index_shuffle_size`` / ``column_shuffle_size``:

    * ``-1``   -- keep every label on that axis (no filter, no shuffle);
    * ``None`` -- keep only labels inside the ``*_min`` / ``*_max`` bounds;
    * other N  -- hash-partition the labels into N pieces with
      ``hash_index`` and store one piece per shuffle target.

    Reduce stage (``execute_reduce``): gather the pieces routed to this
    reducer and concatenate them back into a single chunk.
    """

    _op_type_ = OperandDef.DATAFRAME_INDEX_ALIGN

    # Bounds used to filter rows when ``index_shuffle_size`` is None; each
    # ``*_close`` flag tells whether the matching bound is inclusive.
    index_min = AnyField("index_min")
    index_min_close = BoolField("index_min_close")
    index_max = AnyField("index_max")
    index_max_close = BoolField("index_max_close")
    # Row handling in the map stage: -1 keeps all rows, None filters by the
    # bounds above, any other value N hash-partitions rows into N pieces.
    index_shuffle_size = Int32Field("index_shuffle_size", default=None)
    # Same as the index fields above, for the column axis.
    column_min = AnyField("column_min")
    column_min_close = BoolField("column_min_close")
    column_max = AnyField("column_max")
    column_max_close = BoolField("column_max_close")
    column_shuffle_size = Int32Field("column_shuffle_size", default=None)
    # Per-piece dtypes computed by ``hash_dtypes`` in ``build_map_chunk_kw``
    # when columns are shuffled; reducers read them to rebuild their dtypes.
    column_shuffle_segments = ListField("column_shuffle_segments", FieldTypes.series)

    # The (single) input chunk of this operand, set in ``_set_inputs``.
    input = KeyField("input")

    def __init__(
        self, index_min_max=None, column_min_max=None, output_types=None, **kw
    ):
        """Create the operand.

        :param index_min_max: optional 4-tuple
            ``(min, min_close, max, max_close)`` unpacked into the
            ``index_*`` bound fields.
        :param column_min_max: optional 4-tuple with the same layout,
            unpacked into the ``column_*`` bound fields.
        :param output_types: forwarded as ``_output_types``.
        """
        if index_min_max is not None:
            kw.update(
                dict(
                    index_min=index_min_max[0],
                    index_min_close=index_min_max[1],
                    index_max=index_min_max[2],
                    index_max_close=index_min_max[3],
                )
            )
        if column_min_max is not None:
            kw.update(
                dict(
                    column_min=column_min_max[0],
                    column_min_close=column_min_max[1],
                    column_max=column_min_max[2],
                    column_max_close=column_min_max[3],
                )
            )
        super().__init__(_output_types=output_types, **kw)

    @property
    def index_min_max(self):
        """Index bounds as ``(min, min_close, max, max_close)``, or None if unset."""
        if getattr(self, "index_min", None) is None:
            return None
        return (
            self.index_min,
            self.index_min_close,
            self.index_max,
            self.index_max_close,
        )

    @property
    def column_min_max(self):
        """Column bounds as ``(min, min_close, max, max_close)``, or None if unset."""
        if getattr(self, "column_min", None) is None:
            return None
        return (
            self.column_min,
            self.column_min_close,
            self.column_max,
            self.column_max_close,
        )

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        self.input = self._inputs[0]

    def build_map_chunk_kw(self, inputs, **kw):
        """Fill in metadata kwargs for a map-stage output chunk.

        ``index_value``, ``columns_value``, ``dtype`` and ``name`` already
        supplied in ``kw`` are kept; missing ones are derived from
        ``inputs[0]`` (filtered by the min-max bounds when they are set).
        For DataFrames ``dtypes`` is recomputed whenever ``columns_value``
        is derived here.

        Side effect: when ``column_shuffle_size`` is set and columns are
        derived here, the per-piece dtypes are stored on
        ``self.column_shuffle_segments`` for the reducers.
        """
        if kw.get("index_value", None) is None and inputs[0].index_value is not None:
            input_index_value = inputs[0].index_value
            index_min_max = self.index_min_max
            if index_min_max is not None:
                kw["index_value"] = filter_index_value(input_index_value, index_min_max)
            else:
                kw["index_value"] = parse_index(
                    inputs[0].index_value.to_pandas(),
                    input_index_value,
                    type(self).__name__,
                )
        if self.output_types[0] == OutputType.dataframe:
            if (
                kw.get("columns_value", None) is None
                and getattr(inputs[0], "columns_value", None) is not None
            ):
                input_columns_value = inputs[0].columns_value
                input_dtypes = inputs[0].dtypes
                column_min_max = self.column_min_max
                if column_min_max is not None:
                    kw["columns_value"] = filter_index_value(
                        input_columns_value, column_min_max, store_data=True
                    )
                else:
                    kw["columns_value"] = parse_index(
                        inputs[0].columns_value.to_pandas(),
                        input_columns_value,
                        type(self).__name__,
                    )
                kw["dtypes"] = input_dtypes[kw["columns_value"].to_pandas()]
                column_shuffle_size = self.column_shuffle_size
                if column_shuffle_size is not None:
                    # Remember how the columns are partitioned so each
                    # reducer can assemble its own dtypes.
                    self.column_shuffle_segments = hash_dtypes(
                        input_dtypes, column_shuffle_size
                    )
        else:
            if (
                kw.get("dtype", None) is None
                and getattr(inputs[0], "dtype", None) is not None
            ):
                kw["dtype"] = inputs[0].dtype
            if (
                kw.get("name", None) is None
                and getattr(inputs[0], "name", None) is not None
            ):
                kw["name"] = inputs[0].name
        return kw

    def build_reduce_chunk_kw(self, inputs, index, **kw):
        """Fill in metadata kwargs for a reduce-stage output chunk.

        ``inputs[0]`` is the shuffle proxy chunk; its inputs are the
        map-stage chunks whose metadata is used to derive ``index_value``,
        ``columns_value``/``dtypes`` (DataFrame) or ``dtype``/``name``
        (Series). Values already supplied in ``kw`` are kept.
        """
        kw["index"] = index
        if (
            kw.get("index_value", None) is None
            and inputs[0].inputs[0].index_value is not None
        ):
            index_align_map_chunks = inputs[0].inputs
            if index_align_map_chunks[0].op.index_min_max is not None:
                # shuffle on columns, all the DataFrameIndexAlignMap has the same index
                kw["index_value"] = filter_index_value(
                    index_align_map_chunks[0].index_value,
                    index_align_map_chunks[0].op.index_min_max,
                )
            else:
                # shuffle on index
                kw["index_value"] = parse_index(
                    index_align_map_chunks[0].index_value.to_pandas(),
                    [c.key for c in index_align_map_chunks],
                    type(self).__name__,
                )
        if self.output_types[0] == OutputType.dataframe:
            if (
                kw.get("columns_value", None) is None
                and getattr(inputs[0].inputs[0], "columns_value", None) is not None
            ):
                index_align_map_chunks = inputs[0].inputs
                if index_align_map_chunks[0].op.column_min_max is not None:
                    # shuffle on index
                    kw["columns_value"] = filter_index_value(
                        index_align_map_chunks[0].columns_value,
                        index_align_map_chunks[0].op.column_min_max,
                        store_data=True,
                    )
                    kw["dtypes"] = index_align_map_chunks[0].dtypes[
                        kw["columns_value"].to_pandas()
                    ]
                else:
                    # shuffle on columns: take this reducer's column piece
                    # from every map chunk in the same row of input chunks
                    all_dtypes = [
                        c.op.column_shuffle_segments[index[1]]
                        for c in index_align_map_chunks
                        if c.index[0] == index_align_map_chunks[0].index[0]
                    ]
                    kw["dtypes"] = pd.concat(all_dtypes)
                    kw["columns_value"] = parse_index(
                        kw["dtypes"].index, store_data=True
                    )
        else:
            if (
                kw.get("dtype", None) is None
                and getattr(inputs[0].inputs[0], "dtype", None) is not None
            ):
                kw["dtype"] = inputs[0].inputs[0].dtype
            if (
                kw.get("name", None) is None
                and getattr(inputs[0].inputs[0], "name", None) is not None
            ):
                kw["name"] = inputs[0].inputs[0].name
        return kw

    @classmethod
    def execute_map(cls, ctx, op):
        """Filter or hash-partition the input chunk and store the result(s).

        With no shuffle on any axis the filtered data is stored under the
        output chunk key. Otherwise one piece per shuffle target is stored
        under ``(chunk.key, target_index)`` together with the index of the
        current chunk.
        """
        # TODO(QIN): add GPU support here
        df = ctx[op.inputs[0].key]

        # filters[0] holds row selectors, filters[1] column selectors; an axis
        # with a single selector is not shuffled.
        filters = [[], []]

        chunk = op.outputs[0]
        if op.index_shuffle_size == -1:
            # no shuffle and no min-max filter on index
            filters[0].append(slice(None, None, None))
        elif op.index_shuffle_size is None:
            # no shuffle on index
            comp_op = operator.ge if op.index_min_close else operator.gt
            index_cond = comp_op(df.index, op.index_min)
            comp_op = operator.le if op.index_max_close else operator.lt
            index_cond = index_cond & comp_op(df.index, op.index_max)
            filters[0].append(index_cond)
        else:
            # shuffle on index
            shuffle_size = op.index_shuffle_size
            filters[0].extend(hash_index(df.index, shuffle_size))

        if chunk.ndim == 1:
            if len(filters[0]) == 1:
                # no shuffle
                ctx[chunk.key] = df.loc[filters[0][0]]
            else:
                for index_idx, index_filter in enumerate(filters[0]):
                    ctx[chunk.key, (index_idx,)] = (
                        ctx.get_current_chunk().index,
                        df.loc[index_filter],
                    )
            return

        # The three cases are mutually exclusive; using ``elif`` (as for the
        # index above) keeps -1 from also falling into the hash-shuffle branch.
        if op.column_shuffle_size == -1:
            # no shuffle and no min-max filter on columns
            filters[1].append(slice(None, None, None))
        elif op.column_shuffle_size is None:
            # no shuffle on columns
            comp_op = operator.ge if op.column_min_close else operator.gt
            columns_cond = comp_op(df.columns, op.column_min)
            comp_op = operator.le if op.column_max_close else operator.lt
            columns_cond = columns_cond & comp_op(df.columns, op.column_max)
            filters[1].append(columns_cond)
        else:
            # shuffle on columns
            shuffle_size = op.column_shuffle_size
            filters[1].extend(hash_index(df.columns, shuffle_size))

        if all(len(it) == 1 for it in filters):
            # no shuffle
            ctx[chunk.key] = df.loc[filters[0][0], filters[1][0]]
        elif len(filters[0]) == 1:
            # shuffle on columns
            for column_idx, column_filter in enumerate(filters[1]):
                shuffle_index = (chunk.index[0], column_idx)
                ctx[chunk.key, shuffle_index] = (
                    ctx.get_current_chunk().index,
                    df.loc[filters[0][0], column_filter],
                )
        elif len(filters[1]) == 1:
            # shuffle on index
            for index_idx, index_filter in enumerate(filters[0]):
                shuffle_index = (index_idx, chunk.index[1])
                ctx[chunk.key, shuffle_index] = (
                    ctx.get_current_chunk().index,
                    df.loc[index_filter, filters[1][0]],
                )
        else:
            # full shuffle: target (i, j) receives row piece i x column piece j
            shuffle_index_size = op.index_shuffle_size
            shuffle_column_size = op.column_shuffle_size
            out_idxes = itertools.product(
                range(shuffle_index_size), range(shuffle_column_size)
            )
            out_index_columns = itertools.product(*filters)
            for out_idx, out_index_column in zip(out_idxes, out_index_columns):
                index_filter, column_filter = out_index_column
                ctx[chunk.key, out_idx] = (
                    ctx.get_current_chunk().index,
                    df.loc[index_filter, column_filter],
                )

    @classmethod
    def execute_reduce(cls, ctx, op: "DataFrameIndexAlign"):
        """Concatenate the mapper pieces for this reducer into one chunk.

        Pieces are keyed by the index of the map chunk that produced them;
        for DataFrames pieces in the same row are joined along columns first,
        then rows are stacked in sorted order.
        """
        chunk = op.outputs[0]
        input_idx_to_df = dict(op.iter_mapper_data(ctx))
        row_idxes = sorted({idx[0] for idx in input_idx_to_df})
        if chunk.ndim == 2:
            col_idxes = sorted({idx[1] for idx in input_idx_to_df})

        ress = []
        for row_idx in row_idxes:
            if chunk.ndim == 2:
                row_dfs = []
                for col_idx in col_idxes:
                    row_dfs.append(input_idx_to_df[row_idx, col_idx])
                row_df = pd.concat(row_dfs, axis=1)
            else:
                row_df = input_idx_to_df[(row_idx,)]

            ress.append(row_df)

        ctx[chunk.key] = pd.concat(ress, axis=0)

    @classmethod
    def execute(cls, ctx, op):
        """Dispatch to the map or reduce implementation by ``op.stage``."""
        if op.stage == OperandStage.map:
            cls.execute_map(ctx, op)
        else:
            cls.execute_reduce(ctx, op)


class _AxisMinMaxSplitInfo(object):
    def __init__(
        self, left_split, left_increase, right_split, right_increase, dummy=False
    ):
        self._left_split = left_split
        self._right_split = right_split
        self._dummy = dummy

        self._left_split_idx_to_origin_idx = build_split_idx_to_origin_idx(
            self._left_split, left_increase
        )
        self._right_split_idx_to_origin_idx = build_split_idx_to_origin_idx(
            self._right_split, right_increase
        )

    def isdummy(self):
        return self._dummy

    def get_origin_left_idx(self, idx):
        return self._left_split_idx_to_origin_idx[idx][0]

    def get_origin_left_split(self, idx):
        left_idx, left_inner_idx = self._left_split_idx_to_origin_idx[idx]
        return self._left_split[left_idx][left_inner_idx]

    def get_origin_right_idx(self, idx):
        return self._right_split_idx_to_origin_idx[idx][0]

    def get_origin_right_split(self, idx):
        right_idx, right_inner_idx = self._right_split_idx_to_origin_idx[idx]
        return self._right_split[right_idx][right_inner_idx]


class _MinMaxSplitInfo(object):
    def __init__(self, row_min_max_split_info=None, col_min_max_split_info=None):
        self.row_min_max_split_info = row_min_max_split_info
        self.col_min_max_split_info = col_min_max_split_info

    def all_axes_can_split(self):
        return (
            self.row_min_max_split_info is not None
            and self.col_min_max_split_info is not None
        )

    def one_axis_can_split(self):
        return (self.row_min_max_split_info is None) ^ (
            self.col_min_max_split_info is None
        )

    def no_axis_can_split(self):
        return (
            self.row_min_max_split_info is None and self.col_min_max_split_info is None
        )

    def __getitem__(self, i):
        return [self.row_min_max_split_info, self.col_min_max_split_info][i]

    def __setitem__(self, axis, axis_min_max_split_info):
        assert axis in {0, 1}
        if axis == 0:
            self.row_min_max_split_info = axis_min_max_split_info
        else:
            self.col_min_max_split_info = axis_min_max_split_info

    def get_row_left_idx(self, out_idx):
        return self.row_min_max_split_info.get_origin_left_idx(out_idx)

    def get_row_left_split(self, out_idx):
        return self.row_min_max_split_info.get_origin_left_split(out_idx)

    def get_col_left_idx(self, out_idx):
        return self.col_min_max_split_info.get_origin_left_idx(out_idx)

    def get_col_left_split(self, out_idx):
        return self.col_min_max_split_info.get_origin_left_split(out_idx)

    def get_row_right_idx(self, out_idx):
        return self.row_min_max_split_info.get_origin_right_idx(out_idx)

    def get_row_right_split(self, out_idx):
        return self.row_min_max_split_info.get_origin_right_split(out_idx)

    def get_col_right_idx(self, out_idx):
        return self.col_min_max_split_info.get_origin_right_idx(out_idx)

    def get_col_right_split(self, out_idx):
        return self.col_min_max_split_info.get_origin_right_split(out_idx)

    def get_axis_idx(self, axis, left_or_right, out_idx):
        if axis == 0:
            if left_or_right == 0:
                return self.get_row_left_idx(out_idx)
            else:
                assert left_or_right == 1
                return self.get_row_right_idx(out_idx)
        else:
            assert axis == 1
            if left_or_right == 0:
                return self.get_col_left_idx(out_idx)
            else:
                assert left_or_right == 1
                return self.get_col_right_idx(out_idx)

    def get_axis_split(self, axis, left_or_right, out_idx):
        if axis == 0:
            if left_or_right == 0:
                return self.get_row_left_split(out_idx)
            else:
                assert left_or_right == 1
                return self.get_row_right_split(out_idx)
        else:
            assert axis == 1
            if left_or_right == 0:
                return self.get_col_left_split(out_idx)
            else:
                assert left_or_right == 1
                return self.get_col_right_split(out_idx)


def _get_chunk_index_min_max(index_chunks):
    chunk_index_min_max = []
    for chunk in index_chunks:
        min_val = chunk.min_val
        min_val_close = chunk.min_val_close
        max_val = chunk.max_val
        max_val_close = chunk.max_val_close
        if min_val is None or max_val is None:
            chunk_index_min_max.append((None, True, None, True))
        else:
            chunk_index_min_max.append((min_val, min_val_close, max_val, max_val_close))
    return chunk_index_min_max


def _get_monotonic_chunk_index_min_max(index, index_chunks):
    """Collect per-chunk bounds of a monotonic axis.

    :param index: axis value of the whole operand; only its
        ``is_monotonic_decreasing`` flag is read.
    :param index_chunks: per-chunk axis values, in chunk order.
    :return: ``(bounds, increase)``. ``bounds`` comes from
        ``_get_chunk_index_min_max``, reversed when the axis is decreasing.
        ``increase`` is False for a decreasing axis and True otherwise.
        Returns ``None`` when two adjacent chunks of a non-decreasing axis
        share a closed boundary value, i.e. their ranges overlap.
    """
    bounds = _get_chunk_index_min_max(index_chunks)
    if index.is_monotonic_decreasing:
        return bounds[::-1], False

    for prev_bounds, next_bounds in zip(bounds, bounds[1:]):
        # Adjacent ranges overlap only when both touching ends are closed
        # and the boundary values are identical.
        prev_max, prev_max_close = prev_bounds[2:]
        next_min, next_min_close = next_bounds[:2]
        if prev_max_close and next_min_close and prev_max == next_min:
            return None
    return bounds, True


def _need_align_map(
    input_chunk,
    index_min_max,
    column_min_max,
    dummy_index_splits=False,
    dummy_column_splits=False,
):
    """Return True when ``input_chunk`` needs an align map op.

    An axis needs one when its value is unknown or its ``min_max`` differs
    from the requested bounds. For DataFrame chunks, an axis whose split is
    a dummy is not checked. Series chunks only check the index.
    """

    def _mismatch(axis_value, expected_min_max):
        if axis_value is None:
            return True
        return bool(axis_value.min_max != expected_min_max)

    if isinstance(input_chunk, SERIES_CHUNK_TYPE):
        return _mismatch(input_chunk.index_value, index_min_max)

    if not dummy_index_splits and _mismatch(input_chunk.index_value, index_min_max):
        return True
    if not dummy_column_splits and _mismatch(
        input_chunk.columns_value, column_min_max
    ):
        return True
    return False


def _is_index_identical(left, right):
    if len(left) != len(right):
        return False
    for left_item, right_item in zip(left, right):
        if left_item.key != right_item.key:
            return False
    return True


def _axis_need_shuffle(left_axis, right_axis, left_axis_chunks, right_axis_chunks):
    """Return True when this axis cannot be aligned without a hash shuffle.

    Identically chunked axes never need a shuffle. Otherwise a shuffle is
    required as soon as either side has more than one chunk along an axis
    that is neither monotonic increasing nor decreasing.
    """
    if _is_index_identical(left_axis_chunks, right_axis_chunks):
        return False
    sides = ((left_axis, left_axis_chunks), (right_axis, right_axis_chunks))
    return any(
        not axis.is_monotonic_increasing_or_decreasing and len(axis_chunks) > 1
        for axis, axis_chunks in sides
    )


def _calc_axis_splits(left_axis, right_axis, left_axis_chunks, right_axis_chunks):
    """Decide how one axis of two operands is aligned.

    :param left_axis: axis value (index or columns) of the left operand.
    :param right_axis: axis value of the right operand.
    :param left_axis_chunks: per-chunk axis values of the left operand.
    :param right_axis_chunks: per-chunk axis values of the right operand.
    :return: ``(splits, shuffle_nsplits)``. When the axis can be aligned by
        min-max splitting, ``splits`` is an ``_AxisMinMaxSplitInfo`` and
        ``shuffle_nsplits`` is None. When a hash shuffle is needed,
        ``splits`` is None and ``shuffle_nsplits`` is a list of ``np.nan``
        with one entry per output chunk (the larger chunk count of the two
        sides).
    """

    def _shuffle_result():
        out_chunk_size = max(len(left_axis_chunks), len(right_axis_chunks))
        return None, [np.nan for _ in range(out_chunk_size)]

    if _axis_need_shuffle(left_axis, right_axis, left_axis_chunks, right_axis_chunks):
        return _shuffle_result()

    # no need to do shuffle on this axis
    if _is_index_identical(left_axis_chunks, right_axis_chunks):
        left_chunk_index_min_max = _get_chunk_index_min_max(left_axis_chunks)
        right_splits = left_splits = [[c] for c in left_chunk_index_min_max]
        right_increase = left_increase = None
    elif len(left_axis_chunks) == 1 and len(right_axis_chunks) == 1:
        left_splits = [_get_chunk_index_min_max(left_axis_chunks)]
        # NOTE(review): an "increase" flag is filled from
        # ``is_monotonic_decreasing``; with a single split per side this
        # presumably does not affect the position mapping -- confirm against
        # ``build_split_idx_to_origin_idx``.
        left_increase = left_axis_chunks[0].is_monotonic_decreasing
        right_splits = [_get_chunk_index_min_max(right_axis_chunks)]
        right_increase = right_axis_chunks[0].is_monotonic_decreasing
    else:
        left_monotonic = _get_monotonic_chunk_index_min_max(
            left_axis, left_axis_chunks
        )
        right_monotonic = _get_monotonic_chunk_index_min_max(
            right_axis, right_axis_chunks
        )
        if left_monotonic is None or right_monotonic is None:
            # Adjacent chunk ranges overlap on a shared closed boundary, so
            # min-max splitting cannot assign each label to one output chunk.
            # Fall back to the generic hash shuffle instead of crashing on
            # the ``None`` result.
            return _shuffle_result()
        left_chunk_index_min_max, left_increase = left_monotonic
        right_chunk_index_min_max, right_increase = right_monotonic
        left_splits, right_splits = split_monotonic_index_min_max(
            left_chunk_index_min_max,
            left_increase,
            right_chunk_index_min_max,
            right_increase,
        )
    splits = _AxisMinMaxSplitInfo(
        left_splits, left_increase, right_splits, right_increase
    )
    return splits, None


def _build_dummy_axis_split(chunk_shape):
    """Build a dummy split info for an axis that is not really aligned.

    Chunk ``i`` gets the synthetic closed range ``[i, i + 1]``, and the left
    and right sides receive the same ranges. The result is flagged with
    ``dummy=True``.
    """
    dummy_min_max = [(i, True, i + 1, True) for i in range(chunk_shape)]
    if len(dummy_min_max) != 1:
        left_splits, right_splits = split_monotonic_index_min_max(
            dummy_min_max, True, dummy_min_max, True
        )
    else:
        left_splits = [dummy_min_max]
        right_splits = [dummy_min_max]
    return _AxisMinMaxSplitInfo(left_splits, True, right_splits, True, dummy=True)


def _gen_series_chunks(splits, out_shape, left_or_right, series):
    """Build the aligned output chunks of one Series operand.

    :param splits: split info indexed by axis; ``splits[0]`` is None when
        the index must be hash-shuffled.
    :param out_shape: output chunk shape; only ``out_shape[0]`` is used.
    :param left_or_right: 0 when ``series`` is the left operand, 1 for right.
    :param series: the tiled Series whose chunks are aligned.
    :return: list of ``out_shape[0]`` output chunks.
    """
    n_out = out_shape[0]

    if splits[0] is None:
        # Hash-shuffle every input chunk into ``n_out`` pieces (map) ...
        mappers = []
        for src_chunk in series.chunks:
            map_op = DataFrameIndexAlign(
                stage=OperandStage.map,
                sparse=src_chunk.issparse(),
                index_shuffle_size=n_out,
                output_types=[OutputType.series],
            )
            map_kw = map_op.build_map_chunk_kw(
                [src_chunk], shape=(np.nan,), index=src_chunk.index
            )
            mappers.append(map_op.new_chunk([src_chunk], **map_kw))

        proxy = DataFrameShuffleProxy(output_types=[OutputType.series]).new_chunk(
            mappers, shape=()
        )

        # ... then gather piece ``pos`` from every mapper (reduce).
        reducers = []
        for pos in range(n_out):
            reduce_op = DataFrameIndexAlign(
                stage=OperandStage.reduce,
                n_reducers=n_out,
                i=pos,
                sparse=proxy.issparse(),
                output_types=[OutputType.series],
            )
            reduce_kw = reduce_op.build_reduce_chunk_kw(
                [proxy], index=(pos,), shape=(np.nan,)
            )
            reducers.append(reduce_op.new_chunk([proxy], **reduce_kw))
        return reducers

    # Min-max splitting: each output chunk is cut from one input chunk.
    aligned = []
    for pos in range(n_out):
        src_idx = splits.get_axis_idx(0, left_or_right, pos)
        bounds = splits.get_axis_split(0, left_or_right, pos)
        src_chunk = series.cix[(src_idx,)]
        if not _need_align_map(src_chunk, bounds, None):
            # The input chunk already covers exactly these bounds.
            aligned.append(src_chunk)
            continue
        align_op = DataFrameIndexAlign(
            stage=OperandStage.map,
            index_min_max=bounds,
            column_min_max=None,
            dtype=src_chunk.dtype,
            sparse=series.issparse(),
            output_types=[OutputType.series],
        )
        align_kw = align_op.build_map_chunk_kw(
            [src_chunk], shape=(np.nan,), index=(pos,)
        )
        aligned.append(align_op.new_chunk([src_chunk], **align_kw))
    return aligned


def _gen_dataframe_chunks(splits, out_shape, left_or_right, df):
    """Build the aligned output chunks of one DataFrame operand.

    :param splits: ``_MinMaxSplitInfo`` for both axes; ``None`` on an axis
        means that axis must be hash-shuffled.
    :param out_shape: output chunk shape ``(n_row_chunks, n_col_chunks)``.
    :param left_or_right: 0 when ``df`` is the left operand, 1 for right.
    :param df: the tiled DataFrame whose chunks are aligned.
    :return: list of output chunks.

    Three strategies are used depending on how many axes can be split:

    * both axes: each output chunk is cut (min-max filter) from one input
      chunk, or the input chunk is reused when it already matches;
    * one axis: per position on the split axis, the input chunks of that
      row/column are filtered on the split axis and hash-shuffled on the
      other axis (map -> shuffle proxy -> reduce);
    * no axis: every input chunk is hash-shuffled on both axes.
    """
    out_chunks = []
    if splits.all_axes_can_split():
        # no shuffle for all axes
        # -1 tells the map op to keep the whole axis (dummy split), None to
        # apply the min-max filter.
        kw = {
            "index_shuffle_size": -1 if splits[0].isdummy() else None,
            "column_shuffle_size": -1 if splits[1].isdummy() else None,
        }
        for out_idx in itertools.product(*(range(s) for s in out_shape)):
            row_idx = splits.get_axis_idx(0, left_or_right, out_idx[0])
            col_idx = splits.get_axis_idx(1, left_or_right, out_idx[1])
            index_min_max = splits.get_axis_split(0, left_or_right, out_idx[0])
            column_min_max = splits.get_axis_split(1, left_or_right, out_idx[1])
            chunk = df.cix[row_idx, col_idx]
            if _need_align_map(
                chunk,
                index_min_max,
                column_min_max,
                splits[0].isdummy(),
                splits[1].isdummy(),
            ):
                # NOTE(review): the filtered ``dtypes`` is handed to the op,
                # while ``chunk_kw`` only carries the chunk's own metadata for
                # dummy axes; for a real column split ``build_map_chunk_kw``
                # recomputes dtypes from the filtered columns.
                if splits[1].isdummy():
                    dtypes = chunk.dtypes
                else:
                    dtypes = filter_dtypes(chunk.dtypes, column_min_max)
                chunk_kw = {
                    "index_value": chunk.index_value if splits[0].isdummy() else None,
                    "columns_value": chunk.columns_value
                    if splits[1].isdummy()
                    else None,
                    "dtypes": chunk.dtypes if splits[1].isdummy() else None,
                }
                align_op = DataFrameIndexAlign(
                    stage=OperandStage.map,
                    index_min_max=index_min_max,
                    column_min_max=column_min_max,
                    dtypes=dtypes,
                    sparse=chunk.issparse(),
                    output_types=[OutputType.dataframe],
                    **kw
                )
                params = align_op.build_map_chunk_kw(
                    [chunk], shape=(np.nan, np.nan), index=out_idx, **chunk_kw
                )
                out_chunk = align_op.new_chunk([chunk], **params)
            else:
                # input chunk already matches the requested ranges
                out_chunk = chunk
            out_chunks.append(out_chunk)
    elif splits.one_axis_can_split():
        # one axis needs shuffle
        shuffle_axis = 0 if splits[0] is None else 1
        align_axis = 1 - shuffle_axis

        for align_axis_idx in range(out_shape[align_axis]):
            # Map op settings: min-max bounds on the aligned axis, hash
            # shuffle into out_shape[shuffle_axis] pieces on the other axis.
            if align_axis == 0:
                kw = {
                    "index_min_max": splits.get_axis_split(
                        align_axis, left_or_right, align_axis_idx
                    ),
                    "index_shuffle_size": -1 if splits[0].isdummy() else None,
                    "column_shuffle_size": out_shape[shuffle_axis],
                }
                input_idx = splits.get_axis_idx(
                    align_axis, left_or_right, align_axis_idx
                )
            else:
                kw = {
                    "column_min_max": splits.get_axis_split(
                        align_axis, left_or_right, align_axis_idx
                    ),
                    "index_shuffle_size": out_shape[shuffle_axis],
                    "column_shuffle_size": -1 if splits[1].isdummy() else None,
                }
                input_idx = splits.get_axis_idx(
                    align_axis, left_or_right, align_axis_idx
                )
            # all input chunks in the row/column feeding this output position
            input_chunks = [c for c in df.chunks if c.index[align_axis] == input_idx]
            map_chunks = []
            for j, input_chunk in enumerate(input_chunks):
                chunk_kw = dict()
                if align_axis == 0:
                    chunk_kw["index_value"] = (
                        input_chunk.index_value if splits[0].isdummy() else None
                    )
                else:
                    chunk_kw["columns_value"] = (
                        input_chunk.columns_value if splits[1].isdummy() else None
                    )
                chunk_kw["dtypes"] = input_chunk.dtypes
                map_op = DataFrameIndexAlign(
                    stage=OperandStage.map,
                    sparse=input_chunk.issparse(),
                    output_types=[OutputType.dataframe],
                    **kw
                )
                idx = [None, None]
                idx[align_axis] = align_axis_idx
                idx[shuffle_axis] = j
                params = map_op.build_map_chunk_kw(
                    [input_chunk], shape=(np.nan, np.nan), index=tuple(idx), **chunk_kw
                )
                map_chunks.append(map_op.new_chunk([input_chunk], **params))
            proxy_chunk = DataFrameShuffleProxy(
                sparse=df.issparse(), output_types=[OutputType.dataframe]
            ).new_chunk(map_chunks, shape=())
            # one reducer per piece along the shuffled axis
            for j in range(out_shape[shuffle_axis]):
                chunk_kw = dict()
                if align_axis == 0:
                    chunk_kw["index_value"] = (
                        proxy_chunk.inputs[0].inputs[0].index_value
                        if splits[0].isdummy()
                        else None
                    )
                else:
                    chunk_kw["columns_value"] = (
                        proxy_chunk.inputs[0].inputs[0].columns_value
                        if splits[1].isdummy()
                        else None
                    )
                chunk_kw["dtypes"] = proxy_chunk.inputs[0].inputs[0].dtypes
                reduce_idx = (
                    (align_axis_idx, j) if align_axis == 0 else (j, align_axis_idx)
                )
                reduce_op = DataFrameIndexAlign(
                    stage=OperandStage.reduce,
                    n_reducers=out_shape[shuffle_axis],
                    i=j,
                    sparse=proxy_chunk.issparse(),
                    output_types=[OutputType.dataframe],
                )
                params = reduce_op.build_reduce_chunk_kw(
                    [proxy_chunk], shape=(np.nan, np.nan), index=reduce_idx, **chunk_kw
                )
                out_chunks.append(reduce_op.new_chunk([proxy_chunk], **params))
        # chunks were produced grouped by align axis; restore index order
        out_chunks.sort(key=lambda c: c.index)
    else:
        # all axes need shuffle
        assert splits.no_axis_can_split()

        # gen map chunks
        map_chunks = []
        for chunk in df.chunks:
            map_op = DataFrameIndexAlign(
                stage=OperandStage.map,
                sparse=chunk.issparse(),
                index_shuffle_size=out_shape[0],
                column_shuffle_size=out_shape[1],
                output_types=[OutputType.dataframe],
            )
            params = map_op.build_map_chunk_kw(
                [chunk], shape=(np.nan, np.nan), index=chunk.index
            )
            map_chunks.append(map_op.new_chunk([chunk], **params))

        proxy_chunk = DataFrameShuffleProxy(
            output_types=[OutputType.dataframe]
        ).new_chunk(map_chunks, shape=())

        # gen reduce chunks
        # one reducer per (row piece, column piece) pair, identified by its
        # 2-d output index
        out_indices = list(itertools.product(*(range(s) for s in out_shape)))
        for out_idx in out_indices:
            reduce_op = DataFrameIndexAlign(
                stage=OperandStage.reduce,
                n_reducers=len(out_indices),
                i=out_idx,
                sparse=proxy_chunk.issparse(),
                output_types=[OutputType.dataframe],
            )
            params = reduce_op.build_reduce_chunk_kw(
                [proxy_chunk], index=out_idx, shape=(np.nan, np.nan)
            )
            out_chunks.append(reduce_op.new_chunk([proxy_chunk], **params))

    return out_chunks


def align_dataframe_dataframe(left, right, axis=None):
    """
    Align the chunks of two DataFrames along the index, the columns, or both.

    Parameters
    ----------
    left, right : DataFrame
        Tiled inputs whose chunks are aligned against each other.
    axis : optional
        ``None`` aligns both axes. Otherwise the value is normalised with
        ``validate_axis`` and only that axis (0 = index, 1 = columns) is
        aligned; the other axis keeps each input's own chunking.

    Returns
    -------
    nsplits : list
        ``[index_nsplits, columns_nsplits]``. An entry is ``None`` for an axis
        that was not aligned. It is the left input's nsplits when the
        chunk-level values of both inputs are identical on that axis, and an
        all-NaN list otherwise.
    out_chunk_shapes : tuple
        ``(left_chunk_shape, right_chunk_shape)`` of the generated chunks.
    left_chunks, right_chunks : list
        The aligned chunks of each input.
    """

    def _n_out_chunks(axis_splits, axis_chunk_shape, split_attr, fallback):
        # Number of output chunks along one axis for one side.
        # ``_calc_axis_splits`` may return no min/max splits together with an
        # explicit chunk shape (presumably the shuffle case). That shape fixes
        # the output count, and shuffle reducers are sized from it, so it must
        # be honoured even though ``axis_splits`` is None. Previously this case
        # fell back to each input's own chunk count, which could give the two
        # sides different bucket counts when a single axis was aligned.
        # Only an axis that was not aligned at all (both values None) keeps
        # the input's own chunk count.
        if axis_chunk_shape:
            return len(axis_chunk_shape)
        if axis_splits is not None:
            return len(list(itertools.chain(*getattr(axis_splits, split_attr))))
        return fallback

    left_index_chunks = [c.index_value for c in left.cix[:, 0]]
    right_index_chunks = [c.index_value for c in right.cix[:, 0]]
    left_columns_chunks = [c.columns_value for c in left.cix[0, :]]
    right_columns_chunks = [c.columns_value for c in right.cix[0, :]]

    axis = validate_axis(axis) if axis is not None else None
    if axis is None or axis == 0:
        index_splits, index_chunk_shape = _calc_axis_splits(
            left.index_value, right.index_value, left_index_chunks, right_index_chunks
        )
    else:
        index_splits, index_chunk_shape = None, None

    if axis is None or axis == 1:
        columns_splits, column_chunk_shape = _calc_axis_splits(
            left.columns_value,
            right.columns_value,
            left_columns_chunks,
            right_columns_chunks,
        )
    else:
        columns_splits, column_chunk_shape = None, None

    splits = _MinMaxSplitInfo(index_splits, columns_splits)
    out_left_chunk_shape = (
        _n_out_chunks(
            index_splits, index_chunk_shape, "_left_split", left.chunk_shape[0]
        ),
        _n_out_chunks(
            columns_splits, column_chunk_shape, "_left_split", left.chunk_shape[1]
        ),
    )
    if axis is None:
        # Both axes are aligned, so both sides share one output layout.
        out_right_chunk_shape = out_left_chunk_shape
    else:
        out_right_chunk_shape = (
            _n_out_chunks(
                index_splits, index_chunk_shape, "_right_split", right.chunk_shape[0]
            ),
            _n_out_chunks(
                columns_splits,
                column_chunk_shape,
                "_right_split",
                right.chunk_shape[1],
            ),
        )
    left_chunks = _gen_dataframe_chunks(splits, out_left_chunk_shape, 0, left)
    right_chunks = _gen_dataframe_chunks(splits, out_right_chunk_shape, 1, right)

    index_nsplits = columns_nsplits = None
    if axis is None or axis == 0:
        if _is_index_identical(left_index_chunks, right_index_chunks):
            # identical chunk index values: chunk sizes stay known
            index_nsplits = left.nsplits[0]
        else:
            index_nsplits = [np.nan for _ in range(out_left_chunk_shape[0])]
    if axis is None or axis == 1:
        if _is_index_identical(left_columns_chunks, right_columns_chunks):
            columns_nsplits = left.nsplits[1]
        else:
            columns_nsplits = [np.nan for _ in range(out_left_chunk_shape[1])]

    nsplits = [index_nsplits, columns_nsplits]

    out_chunk_shapes = (out_left_chunk_shape, out_right_chunk_shape)
    return nsplits, out_chunk_shapes, left_chunks, right_chunks


def align_dataframe_series(left, right, axis="columns"):
    """
    Align a DataFrame with a Series on one DataFrame axis.

    The Series index is matched against the DataFrame columns when ``axis``
    normalises to 1, and against the DataFrame index otherwise. The other
    DataFrame axis keeps its existing chunking via a dummy split.

    Returns
    -------
    tuple
        ``(nsplits, out_chunk_shape, left_chunks, right_chunks)``.
        ``nsplits`` is a two-element list holding the untouched axis' nsplits
        and, on the aligned axis, the DataFrame's nsplits when the chunk-level
        values are identical or an all-NaN list otherwise.
    """
    axis = validate_axis(axis)
    # ``aligned`` is the DataFrame axis matched with the Series index;
    # ``other`` is the DataFrame axis left as it is.
    aligned = 1 if axis == 1 else 0
    other = 1 - aligned

    if aligned == 1:
        frame_axis_chunks = [c.columns_value for c in left.cix[0, :]]
        frame_axis_value = left.columns_value
    else:
        frame_axis_chunks = [c.index_value for c in left.cix[:, 0]]
        frame_axis_value = left.index_value
    series_index_chunks = [c.index_value for c in right.chunks]

    axis_splits, axis_chunk_shape = _calc_axis_splits(
        frame_axis_value,
        right.index_value,
        frame_axis_chunks,
        series_index_chunks,
    )

    dummy_splits = _build_dummy_axis_split(left.chunk_shape[other])
    dummy_nsplits = left.nsplits[other]

    n_aligned = len(
        axis_chunk_shape
        or list(itertools.chain.from_iterable(axis_splits._left_split))
    )
    shape_slots = [None, None]
    shape_slots[aligned] = n_aligned
    shape_slots[other] = len(dummy_nsplits)
    out_chunk_shape = tuple(shape_slots)

    split_slots = [None, None]
    split_slots[aligned] = axis_splits
    split_slots[other] = dummy_splits
    left_chunks = _gen_dataframe_chunks(
        _MinMaxSplitInfo(*split_slots), out_chunk_shape, 0, left
    )
    right_chunks = _gen_series_chunks(
        _MinMaxSplitInfo(axis_splits, None), (n_aligned,), 1, right
    )

    if _is_index_identical(frame_axis_chunks, series_index_chunks):
        aligned_nsplits = left.nsplits[aligned]
    else:
        aligned_nsplits = [np.nan for _ in range(n_aligned)]

    nsplits = [None, None]
    nsplits[aligned] = aligned_nsplits
    nsplits[other] = dummy_nsplits

    return nsplits, out_chunk_shape, left_chunks, right_chunks


def align_series_series(left, right):
    """
    Align the index chunks of two Series.

    When ``is_index_value_identical`` reports matching indexes, the inputs are
    returned untouched. Otherwise both sides are re-chunked from the index
    split computed by ``_calc_axis_splits``, and the resulting chunk sizes are
    reported as unknown (NaN).

    Returns
    -------
    tuple
        ``(nsplits, out_chunk_shape, left_chunks, right_chunks)``.
    """
    if is_index_value_identical(left, right):
        # index identical, skip align
        return left.nsplits, left.chunk_shape, left.chunks, right.chunks

    index_splits, index_chunk_shape = _calc_axis_splits(
        left.index_value,
        right.index_value,
        [c.index_value for c in left.chunks],
        [c.index_value for c in right.chunks],
    )

    n_out = len(
        index_chunk_shape
        or list(itertools.chain.from_iterable(index_splits._left_split))
    )
    out_chunk_shape = (n_out,)

    shared_splits = _MinMaxSplitInfo(index_splits, None)
    left_chunks = _gen_series_chunks(shared_splits, out_chunk_shape, 0, left)
    right_chunks = _gen_series_chunks(shared_splits, out_chunk_shape, 1, right)

    # sizes of the re-aligned chunks are not known until execution
    return [[np.nan] * n_out], out_chunk_shape, left_chunks, right_chunks
