"""Pipeline-related functionality."""

import itertools
import os

import git
import typer

import calkit
from calkit.environments import get_env_lock_fpath
from calkit.models.pipeline import (
    InputsFromStageOutputs,
    PathOutput,
    Pipeline,
)


def _expand_matrix(input_dict: dict[str, list]) -> list[dict]:
    """Restructure a dictionary with list values into a list of dictionaries,
    where each dictionary represents a permutation of the input dictionary's
    values.
    """
    keys = list(input_dict.keys())
    values = list(input_dict.values())
    # Create all combinations of values using itertools.product
    combinations = itertools.product(*values)
    # Create a list of dictionaries
    list_of_dicts = []
    for combination in combinations:
        list_of_dicts.append(dict(zip(keys, combination)))
    return list_of_dicts


def to_dvc(
    ck_info: dict | None = None,
    wdir: str | None = None,
    write: bool = False,
    verbose: bool = False,
) -> dict:
    """Transpile a Calkit pipeline to a DVC pipeline."""
    if ck_info is None:
        ck_info = calkit.load_calkit_info(wdir=wdir)
    if "pipeline" not in ck_info:
        raise ValueError("No pipeline found in calkit.yaml")
    try:
        pipeline = Pipeline.model_validate(ck_info["pipeline"])
    except Exception as e:
        raise ValueError(f"Pipeline is not defined properly: {e}")
    dvc_stages = {}
    # First, create stages for checking/exporting all environments used in the
    # pipeline
    used_envs = set([stage.environment for stage in pipeline.stages.values()])
    env_lock_fpaths = {}
    for env_name, env in ck_info.get("environments", {}).items():
        if env_name not in used_envs:
            continue
        env_fpath = env.get("path")
        lock_fpath = get_env_lock_fpath(
            env=env, env_name=env_name, as_posix=True
        )
        cmd = f"calkit check environment --name {env_name}"
        if lock_fpath is None:
            continue
        deps = []
        outs = []
        if env_fpath is not None:
            deps.append(env_fpath)
        # Docker envs sometimes have deps, so add those too
        if env.get("deps", []):
            deps += env["deps"]
        outs.append({lock_fpath: dict(cache=False, persist=True)})
        stage = dict(cmd=cmd, deps=deps, outs=outs, always_changed=True)
        stage["desc"] = (
            "Automatically generated by Calkit. "
            "Changes made here will be overwritten."
        )
        dvc_stages[f"_check-env-{env_name}"] = stage
        env_lock_fpaths[env_name] = lock_fpath
    # Now convert Calkit stages into DVC stages
    for stage_name, stage in pipeline.stages.items():
        dvc_stage = stage.to_dvc()
        # Add environment lock file to deps
        env_lock_fpath = env_lock_fpaths.get(stage.environment)
        if (
            env_lock_fpath is not None
            and env_lock_fpath not in dvc_stage["deps"]
        ):
            dvc_stage["deps"].append(env_lock_fpath)
        # Check if this stage iterates, which means we should create a matrix
        # stage
        if stage.iterate_over is not None:
            # Process a list of iterations into a DVC matrix stage
            dvc_matrix = {}
            format_dict = {}
            for iteration in stage.iterate_over:
                arg_name = iteration.arg_name
                dvc_matrix[arg_name] = iteration.expand_values(
                    params=ck_info.get("parameters", {})
                )
                # Now replace arg name in cmd, deps, and outs with
                # ${item.{arg_name}}
                format_dict[arg_name] = f"${{item.{arg_name}}}"
            try:
                cmd = dvc_stage["cmd"]
                cmd = cmd.format(**format_dict)
                dvc_stage["cmd"] = cmd
            except Exception as e:
                raise ValueError(
                    (
                        f"Failed to format cmd '{cmd}': "
                        f"{e.__class__.__name__}: {e}"
                    )
                )
            formatted_deps = []
            formatted_outs = []
            for dep in dvc_stage.get("deps", []):
                try:
                    formatted_deps.append(dep.format(**format_dict))
                except Exception as e:
                    raise ValueError(
                        (
                            f"Failed to format dep '{dep}': "
                            f"{e.__class__.__name__}: {e}"
                        )
                    )
            for out in dvc_stage.get("outs", []):
                if isinstance(out, dict):
                    formatted_outs.append(
                        {
                            str(list(out.keys())[0]).format(
                                **format_dict
                            ): dict(list(out.values())[0])
                        }
                    )
                else:
                    formatted_outs.append(out.format(**format_dict))
            dvc_stage["deps"] = formatted_deps
            dvc_stage["outs"] = formatted_outs
            dvc_stage["matrix"] = dvc_matrix
        # Add a description to the DVC stage
        desc = (
            f"Automatically generated from the '{stage_name}' stage "
            "in calkit.yaml. Changes made here will be overwritten."
        )
        dvc_stage["desc"] = desc
        # If this is a Jupyter Notebook stage, we need to add a clean stage
        if stage.kind == "jupyter-notebook":
            clean_stage_name = f"_clean-nb-{stage_name}"
            dvc_stages[clean_stage_name] = stage.dvc_clean_stage
            dvc_stages[clean_stage_name]["desc"] = desc
        dvc_stages[stage_name] = dvc_stage
        # Check for any outputs that should be ignored
        if write:
            repo = git.Repo(wdir)
            # Ensure we catch any Jupyter Notebook outputs
            outputs = stage.outputs.copy()
            if stage.kind == "jupyter-notebook":
                outputs += stage.notebook_outputs
            for out in outputs:
                if (
                    isinstance(out, PathOutput)
                    and out.storage is None
                    and not repo.ignored(out.path)
                ):
                    gitignore_path = ".gitignore"
                    if wdir is not None:
                        gitignore_path = os.path.join(wdir, gitignore_path)
                    with open(gitignore_path, "a") as f:
                        f.write("\n" + out.path + "\n")
    # Now process any inputs from stage outputs
    for stage_name, stage in pipeline.stages.items():
        for i in stage.inputs:
            if isinstance(i, InputsFromStageOutputs):
                dvc_outs = dvc_stages[i.from_stage_outputs]["outs"]
                for out in dvc_outs:
                    if out not in dvc_stages[stage_name]["deps"]:
                        # Handle cases where outs are from a matrix,
                        # in which case this output could become a list of
                        # outputs
                        if isinstance(out, dict):
                            out = list(out.keys())[0]
                        if "${item." in out:
                            extra_outs = []
                            dvc_matrix = dvc_stages[i.from_stage_outputs][
                                "matrix"
                            ]
                            replacements = _expand_matrix(dvc_matrix)
                            for r in replacements:
                                out_i = out
                                for var_name, var_val in r.items():
                                    out_i = out_i.replace(
                                        f"${{item.{var_name}}}",
                                        str(var_val),
                                    )
                                extra_outs.append(out_i)
                            for out_i in extra_outs:
                                if out_i not in dvc_stages[stage_name]["deps"]:
                                    dvc_stages[stage_name]["deps"].append(
                                        out_i
                                    )
                        else:
                            dvc_stages[stage_name]["deps"].append(out)
    if write:
        if os.path.isfile("dvc.yaml"):
            with open("dvc.yaml") as f:
                dvc_yaml = calkit.ryaml.load(f)
        else:
            dvc_yaml = {}
        if dvc_yaml is None:
            dvc_yaml = {}
        existing_stages = dvc_yaml.get("stages", {})
        for stage_name, stage in existing_stages.items():
            # Skip private stages (ones whose names start with an underscore)
            # and stages that are automatically generated
            if (
                not stage_name.startswith("_")
                and stage_name not in dvc_stages
                and not stage.get("desc", "").startswith(
                    "Automatically generated"
                )
            ):
                dvc_stages[stage_name] = stage
        dvc_yaml["stages"] = dvc_stages
        with open("dvc.yaml", "w") as f:
            if verbose:
                typer.echo("Writing to dvc.yaml")
            calkit.ryaml.dump(dvc_yaml, f)
    return dvc_stages
