"""
Tests for `ensemble_workflow` in ensemble.py

A qualified ensemble_workflow implementation should:
- Return predictions
- Have correct shapes for inputs and outputs
- Use validation data appropriately
- Generate a scores.csv file
"""

import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
import torch
import tensorflow as tf
from load_data import load_data
from feature import feat_eng
from ensemble import ensemble_workflow

def print_preds_info(model_name, data_type, preds):
    """Print a size summary and a short preview of one set of predictions.

    Args:
        model_name: Label of the model, used only in the printed lines.
        data_type: Label of the data split (e.g. "test"), used only in the
            printed lines.
        preds: The predictions, or ``None``. Lists, pandas DataFrame/Series
            and NumPy / torch / TensorFlow arrays get a preview; any other
            type is reported as unknown.
    """
    if preds is None:
        print(f"Model {model_name} {data_type} predictions: None")
        return

    # A plain list has no ``.shape``; report its length instead of raising
    # AttributeError before the preview below is ever reached.
    if isinstance(preds, list):
        print(f"Model {model_name} {data_type} predictions (list type) length: {len(preds)}")
    elif hasattr(preds, "shape"):
        print(f"Model {model_name} {data_type} predictions shape: {preds.shape}")

    print("Showing a preview of the predictions (first few entries only):")
    if isinstance(preds, list):
        print(pd.DataFrame(preds[:5]))
    elif isinstance(preds, (pd.DataFrame, pd.Series)):
        print(preds.head())
    elif isinstance(preds, (np.ndarray, torch.Tensor, tf.Tensor)):
        print(preds[:2])
    else:
        print(f"Unknown prediction type: {type(preds)}")

def get_length(data):
    """Return the number of samples in ``data``.

    Objects exposing a ``shape`` attribute report their first dimension;
    everything else falls back to ``len(data)``.
    """
    try:
        dims = data.shape
    except AttributeError:
        return len(data)
    return dims[0]

# Load the data; `test_ids` is unpacked here but not referenced again in this script.
X, y, test_X, test_ids = load_data()
# Pass the loaded data through `feat_eng`, which returns replacements for X, y and test_X.
X, y, test_X = feat_eng(X, y, test_X)
# Hold out 20% of the samples as a validation split; the fixed random_state keeps the split reproducible.
train_X, val_X, train_y, val_y = train_test_split(X, y, test_size=0.2, random_state=42)

# Print the types of train_y and val_y
print(f"train_y type: {type(train_y)}, val_y type: {type(val_y)}")

# Per-model predictions keyed by model name; filled in by the templated loop below.
test_preds_dict = {}
val_preds_dict = {}
# Template loop: the statements up to the matching endfor are rendered once per entry of model_names.
{% for mn in model_names %}
from {{mn}} import model_workflow as {{mn}}_workflow
# The workflow returns three values; the third one is discarded here.
val_preds_dict["{{mn}}"], test_preds_dict["{{mn}}"], _ = {{mn}}_workflow(
    X=train_X,
    y=train_y,
    val_X=val_X,
    val_y=val_y,
    test_X=test_X
)

# Only the test predictions are previewed at this point.
print_preds_info("{{mn}}", "test", test_preds_dict["{{mn}}"])
{% endfor %}

# For every model, report the size of its validation predictions and then
# of its test predictions (None, list length, or shape).
for key in val_preds_dict:
    for split_label, dict_label, preds_by_model in (
        ("validation", "val_preds_dict", val_preds_dict),
        ("test", "test_preds_dict", test_preds_dict),
    ):
        model_preds = preds_by_model[key]
        line_prefix = f"Model {key} {split_label} predictions ({dict_label}[key])"
        if model_preds is None:
            print(f"{line_prefix} is None.")
        elif isinstance(model_preds, list):
            print(f"{line_prefix} (list type) length: {len(model_preds)}")
        else:
            print(f"{line_prefix} shape: {model_preds.shape}")

# Report the size of the validation targets the same way.
if isinstance(val_y, list):
    print(f"val_y(list)'s length: {len(val_y)}")
else:
    print(f"val_y.shape: {val_y.shape}")

import sys
import reprlib
def debug_info_print(func):
    """Wrap ``func`` so that its local variables are printed when it returns.

    While the wrapped call runs, a trace function is installed with
    ``sys.settrace``. Each time a frame executing ``func``'s own code
    returns, every local variable of that frame is printed using a
    size-limited ``reprlib`` representation.

    Args:
        func: The callable to wrap. If it has no ``__code__`` attribute
            it is simply called and nothing extra is printed.

    Returns:
        A wrapper with the same call signature and return value as ``func``.
    """
    # Imported locally so this helper does not depend on a module-level import.
    import functools

    aRepr = reprlib.Repr()
    aRepr.maxother = 300
    # Resolved once up front; callables without a code object are not traced.
    target_code = getattr(func, "__code__", None)

    def local_trace(frame, event, arg):
        # Installed only on frames running ``func``'s code (see global_trace).
        if event == "return":
            print("\n" + "="*20 + "Running ensemble code, local variable values:" + "="*20)
            for k, v in frame.f_locals.items():
                printed = aRepr.repr(v)
                print(f"{k}:\n {printed}")
            print("="*20 + "Local variable values end" + "="*20)
        return local_trace

    def global_trace(frame, event, arg):
        # Returning None for every other frame keeps the code called by
        # ``func`` from being traced line by line.
        if target_code is not None and frame.f_code == target_code:
            return local_trace
        return None

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Remember whatever tracer was active (debugger, coverage, ...) so it
        # is restored afterwards instead of being dropped.
        previous_trace = sys.gettrace()
        sys.settrace(global_trace)
        try:
            return func(*args, **kwargs)
        finally:
            sys.settrace(previous_trace)
    return wrapper


# Run ensemble
# The workflow is wrapped with `debug_info_print`, so its local variables are
# printed when it returns.
final_pred = debug_info_print(ensemble_workflow)(test_preds_dict, val_preds_dict, val_y)

print_preds_info("ensemble", "test", final_pred)

# Check type
# The ensemble output must be an instance of the type of the first model's test predictions.
pred_type = type(next(iter(test_preds_dict.values())))
assert isinstance(final_pred, pred_type), (
    f"Type mismatch: 'final_pred' is of type {type(final_pred)}, but expected {pred_type} "
)

# Check shape
# pd.Series is listed as well, so a Series-valued prediction does not bypass
# the sample-count check.
if isinstance(final_pred, (list, np.ndarray, pd.DataFrame, pd.Series, torch.Tensor, tf.Tensor)):
    assert get_length(final_pred) == get_length(test_X), (
        f"Wrong output sample size: get_length(final_pred)={get_length(final_pred)} "
        f"vs. get_length(test_X)={get_length(test_X)}"
    )

# check scores.csv
scores_path = Path("scores.csv")
assert scores_path.exists(), "scores.csv is not generated"
score_df = pd.read_csv(scores_path, index_col=0)

# The index must hold exactly the templated model names plus "ensemble".
model_set_in_scores = set(score_df.index)
expected_model_set = set({{model_names}}) | {"ensemble"}
assert model_set_in_scores == expected_model_set, (
    f"The scores dataframe does not contain the correct model names as index.\ncorrect model names are: {{model_names}} + ['ensemble']\nscore_df is:\n{score_df}"
)
assert score_df.index.is_unique, "The scores dataframe has duplicate model names."

# The only column allowed is the templated metric name.
actual_columns = score_df.columns.tolist()
assert actual_columns == ["{{metric_name}}"], f"The column names of the scores dataframe should be ['{{metric_name}}'], but is '{actual_columns}'"

# Check for NaN values in score_df
rows_with_nan = score_df.isnull().any(axis=1)
assert not rows_with_nan.any(), (
    f"The scores dataframe contains NaN values at the following locations:\n{score_df[rows_with_nan]}"
)


print("Ensemble test end.")
