import json
import math
import re
from functools import reduce
from typing import Any, Dict, List

from mage_ai.shared.strings import camel_to_snake_case


def camel_case_keys_to_snake_case(d):
    """
    Return a copy of a dict whose keys are converted from camelCase to snake_case.

    Nested dict values are converted recursively. When a value is a list, each
    element is passed back through this function, so dict elements are converted
    and any other element is kept unchanged.

    Args:
        d: The object to convert. Anything that is not a dict is returned as-is.

    Returns:
        A new dict with converted keys, or ``d`` itself when it is not a dict.
    """
    if not isinstance(d, dict):
        return d

    def _convert_value(value):
        if isinstance(value, dict):
            return camel_case_keys_to_snake_case(value)
        if isinstance(value, list):
            return [camel_case_keys_to_snake_case(element) for element in value]
        return value

    # When two keys map to the same snake_case key, the later one wins.
    return {
        camel_to_snake_case(key): _convert_value(value) for key, value in d.items()
    }


def dig(obj_arg, arr_or_string):
    """
    Retrieve a nested value from dicts/lists by following a path.

    Args:
        obj_arg: The root object to navigate.
        arr_or_string (str or list): A dot-separated path string such as
            ``"a.b[0].c"``, or a list of path segments. Each segment is stripped
            of surrounding whitespace. A segment ``name[N]`` reads
            ``obj[name][N]``. A bare ``[N]`` segment reads ``obj[N]``.

    Returns:
        The value at the path. A plain-key segment is looked up with ``.get``,
        so a missing key yields None. A plain-key segment applied to a falsy
        value (None, {}, [], ...) returns that value unchanged.

    Raises:
        KeyError, IndexError, TypeError, or AttributeError when an indexed
        segment or an intermediate value does not match the data's structure.
        See ``safe_dig`` for a more defensive variant.
    """
    if isinstance(arr_or_string, str):
        arr_or_string = arr_or_string.split(".")
    arr = list(map(str.strip, arr_or_string))

    def _build(obj, key):
        # re.split with one capture group gives:
        #   "a[2]" -> ["a", "2", ""]
        #   "[2]"  -> ["", "2", ""]
        #   "a"    -> ["a"]
        tup = re.split(r"\[(\d+)\]$", key)
        if len(tup) >= 2:
            # Unpack positionally. Filtering out empty strings (the previous
            # approach) crashed on a bare "[N]" segment because its name is "".
            name, index = tup[0], int(tup[1])
            return obj[name][index] if name else obj[index]
        if obj:
            return obj.get(key)
        return obj

    return reduce(_build, arr, obj_arg)


def safe_dig(obj_arg, arr_or_string):
    """
    Safely retrieves nested values from a dictionary or list using a dot-separated path.

    Args:
        obj_arg: The object (dictionary or list) to navigate.
        arr_or_string (str or list): A dot-separated path string or a list of
            keys/indexes. Each segment is stripped of surrounding whitespace.
            A segment ``name[N]`` reads element N of the list stored under
            ``name``. A bare ``[N]`` segment reads element N of the current list.

    Returns:
        The value retrieved from the nested structure, or None if any
            intermediate key/index is missing, out of range, or applied to a
            value of the wrong type (including None).
    """
    if isinstance(arr_or_string, str):
        arr_or_string = arr_or_string.split(".")
    arr = list(map(str.strip, arr_or_string))

    def _index(seq, index):
        # Only lists are indexed here; anything else or an out-of-range index is None.
        if isinstance(seq, list) and index < len(seq):
            return seq[index]
        return None

    def _build(obj, key):
        # Anything other than a dict or a list (None included) ends the walk.
        if not isinstance(obj, (dict, list)):
            return None

        # "a[2]" -> ["a", "2", ""]; "[2]" -> ["", "2", ""]; "a" -> ["a"].
        tup = re.split(r"\[(\d+)\]$", key)
        if len(tup) >= 2:
            name, index = tup[0], int(tup[1])
            if not name:
                # Bare "[N]": index the current object directly.
                return _index(obj, index)
            # "name[N]" needs a dict that holds ``name``. A list here would
            # otherwise be membership-tested and then indexed with a str key.
            if not isinstance(obj, dict) or name not in obj:
                return None
            return _index(obj[name], index)
        if isinstance(obj, dict):
            return obj.get(key)
        return None

    return reduce(_build, arr, obj_arg)


def flatten(input_data):
    """
    Flatten up to three levels of nested dicts into a single-level dict.

    Keys from successive levels are joined with "_". Only values whose type is
    exactly ``dict`` are descended into, and at most two levels down. A value
    found at the third level is stored as-is, even if it is a dict.

    Args:
        input_data (dict): The dict to flatten.

    Returns:
        A new single-level dict.
    """
    flattened = {}
    for outer_key, outer_value in input_data.items():
        if type(outer_value) is not dict:
            flattened[outer_key] = outer_value
            continue
        for mid_key, mid_value in outer_value.items():
            if type(mid_value) is not dict:
                flattened[f"{outer_key}_{mid_key}"] = mid_value
                continue
            for inner_key, inner_value in mid_value.items():
                flattened[f"{outer_key}_{mid_key}_{inner_key}"] = inner_value
    return flattened


def flatten_dict(d, parent_key="", sep="."):
    """
    Recursively flatten nested dicts into one level, joining keys with ``sep``.

    Args:
        d (dict): The dict to flatten.
        parent_key: Prefix prepended (with ``sep``) to every key. A falsy prefix
            means keys are used unprefixed.
        sep (str): Separator placed between key segments.

    Returns:
        A new flat dict. Empty nested dicts contribute no entries.
    """
    flat = {}
    for key, value in d.items():
        full_key = key if not parent_key else f"{parent_key}{sep}{key}"
        if isinstance(value, dict):
            flat.update(flatten_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat


def unflatten_dict(d, sep="."):
    """
    Rebuild nested dicts from a flat dict whose keys are ``sep``-joined paths.

    Args:
        d (dict): Flat mapping of path strings to values.
        sep (str): Separator between path segments.

    Returns:
        A new nested dict. Missing intermediate levels are created as empty dicts.
    """
    result_dict = {}
    for flat_key, value in d.items():
        *parents, leaf = flat_key.split(sep)
        node = result_dict
        for part in parents:
            if part not in node:
                node[part] = {}
            node = node[part]
        node[leaf] = value
    return result_dict


def get_json_value(str_value, arr_or_string):
    """
    Parse a JSON string and return the value found at a path within it.

    Args:
        str_value: The JSON text to parse.
        arr_or_string (str or list): Path passed to ``dig``.

    Returns:
        ``dig(parsed, arr_or_string)`` on success. ``str_value`` itself when it
        is falsy or cannot be parsed.
    """
    if str_value:
        try:
            parsed = json.loads(str_value)
        except Exception:
            # Not valid JSON (or not str/bytes): hand the raw value back.
            return str_value
        return dig(parsed, arr_or_string)
    return str_value


def ignore_keys(d, keys):
    """
    Return a shallow copy of ``d`` without the given keys.

    Args:
        d (dict): Source dict. It is not modified.
        keys (iterable): Keys to drop. Keys absent from ``d`` and duplicates
            are ignored.

    Returns:
        A copy of ``d`` (same type, via ``d.copy()``) minus ``keys``.
    """
    d2 = d.copy()
    for key in keys:
        # pop with a default: the previous membership check ran against ``d``,
        # so a key listed twice was popped twice from the copy and raised KeyError.
        d2.pop(key, None)
    return d2


def ignore_keys_with_blank_values(d: Dict, include_values: List[Any] = None) -> Dict:
    """
    Return a shallow copy of ``d`` without entries whose value is falsy.

    Args:
        d: Source dict. It is not modified.
        include_values: Falsy values that should still be kept. Membership is
            tested with ``in``, i.e. equality, so ``0`` also matches ``False``.

    Returns:
        A copy of ``d`` with blank-valued keys removed.
    """
    pruned = d.copy()
    for key, value in d.items():
        if value:
            continue
        if include_values and value in include_values:
            continue
        pruned.pop(key)
    return pruned


def extract(d, keys, include_blank_values: bool = False):
    """
    Build a new dict holding only the requested keys of ``d``.

    Args:
        d (dict): Source mapping.
        keys (iterable): Keys to copy, in output order.
        include_blank_values: When True, keys that are missing or mapped to
            None are included with value None. Otherwise they are omitted.

    Returns:
        A new dict.
    """
    extracted = {}
    for key in keys:
        val = d[key] if key in d else None
        if val is not None or include_blank_values:
            extracted[key] = val
    return extracted


def extract_arrays(input_data):
    """
    Collect every value of ``input_data`` whose type is exactly ``list``.

    Args:
        input_data (dict): Mapping to scan.

    Returns:
        A list of the matching values, in iteration order. List subclasses
        are not included.
    """
    return [value for value in input_data.values() if type(value) is list]


def group_by(func, arr):
    """
    Group items by the key ``func(item)``.

    Args:
        func: Callable mapping each item to a hashable group key.
        arr (iterable): Items to group.

    Returns:
        A dict of key -> list of items, preserving input order within each group.
    """
    groups = {}
    for item in arr:
        groups.setdefault(func(item), []).append(item)
    return groups


def index_by(func, arr):
    """
    Map ``func(item)`` to ``item`` for every item. On key collisions the last item wins.

    Args:
        func: Callable producing a hashable key for each item.
        arr (iterable): Items to index.

    Returns:
        A dict of key -> item.
    """
    return {func(item): item for item in arr}


def merge_dict(a: Dict, b: Dict) -> Dict:
    """
    Shallow-merge two dicts into a new one; values from ``b`` win.

    Args:
        a: Base dict. It may be None or empty and is not modified.
        b: Overriding dict. It may be None or empty.

    Returns:
        A new dict: a copy of ``a`` (or ``{}``) updated with ``b``.
    """
    merged = a.copy() if a else {}
    if b:
        merged.update(b)
    return merged


def replace_dict_nan_value(d):
    """
    Return a copy of ``d`` where every float NaN value is replaced by None.

    Args:
        d (dict): Source dict. It is not modified.

    Returns:
        A new dict with the same keys.
    """
    result = {}
    for key, value in d.items():
        is_nan = isinstance(value, float) and math.isnan(value)
        result[key] = None if is_nan else value
    return result


def get_safe_value(data: Dict, key: str, default_value):
    """
    Return ``data[key]``, or ``default_value`` when the key is missing or ``data`` is falsy.
    """
    if not data:
        return default_value
    return data.get(key, default_value)


def set_value(obj: Dict, keys: List[str], value) -> Dict:
    """
    Set ``value`` at the nested path ``keys`` inside ``obj``, mutating ``obj``.

    Missing intermediate levels are created as empty dicts. An existing
    intermediate value that is falsy (None, 0, "", [], {}) is also replaced by a
    new empty dict.

    Args:
        obj: The dict to update in place.
        keys: Path of keys from the root of ``obj`` to the slot being set. Keys
            are used verbatim, with no dot/bracket parsing and no whitespace
            stripping, and need not be strings.
        value: The value to store.

    Returns:
        ``obj`` itself. When ``keys`` is empty, ``value`` is returned instead;
        this matches the previous implementation.

    Raises:
        AttributeError or TypeError if an existing intermediate value is truthy
        but is not a dict.
    """
    if not keys:
        # The previous exec-based version rebound the target itself here.
        return value

    # Walk the path directly instead of exec-ing generated source. The old
    # approach broke (or allowed code injection) when a key contained a quote.
    current = obj
    for key in keys[:-1]:
        # ``or {}`` keeps the old behavior of replacing falsy intermediates.
        child = current.get(key) or {}
        current[key] = child
        current = child
    current[keys[-1]] = value
    return obj


def combine_into(child: Dict, parent: Dict) -> None:
    """
    Deep-merge ``child`` into ``parent`` in place. Child values override parent values.

    Nested dicts are merged recursively into fresh or existing dicts inside
    ``parent``. A dict in ``child`` replaces any non-dict value stored under the
    same key in ``parent``. Non-dict values from ``child``, lists included, are
    assigned by reference and are not copied.

    Args:
        child: Source of overriding values. It is not modified.
        parent: Destination dict, mutated in place.
    """
    for k, v in child.items():
        if isinstance(v, dict):
            target = parent.get(k)
            if not isinstance(target, dict):
                # A scalar/None/list in parent cannot receive nested keys.
                # The child overrides it (setdefault used to crash here).
                target = {}
                parent[k] = target
            combine_into(v, target)
        else:
            parent[k] = v
