"""
Tests for `model_workflow` in model01.py
"""
import sys
import time

from feature import feat_eng
from load_data import load_data
from model01 import model_workflow
from sklearn.model_selection import train_test_split


def log_execution_results(start_time, val_pred, test_pred, hypers, execution_label):
    """Print a summary of one model execution.

    The summary lists the shape of the validation and test predictions
    (or ``None`` when a prediction is absent), the hyperparameters, and the
    wall-clock time elapsed since ``start_time``.

    Args:
        start_time: ``time.time()`` timestamp taken before the execution.
        val_pred: Validation predictions exposing ``.shape``, or ``None``.
        test_pred: Test predictions exposing ``.shape``, or ``None``.
        hypers: Hyperparameters returned by the execution, or ``None``.
        execution_label: Human-readable name of the execution being logged.
    """
    val_shape = 'None' if val_pred is None else val_pred.shape
    test_shape = 'None' if test_pred is None else test_pred.shape
    shown_hypers = 'None' if hypers is None else hypers

    # The elapsed time is computed last, inside the final f-string, so the
    # measurement point stays right before printing.
    report = "".join((
        f"{execution_label} end.\n",
        f"Validation predictions shape: {val_shape}\n",
        f"Test predictions shape: {test_shape}\n",
        f"Hyperparameters: {shown_hypers}\n",
        f"Execution time: {time.time() - start_time:.2f} seconds.\n",
    ))
    print(report)


import reprlib
# Shared repr helper used for the debug dumps below; `maxother` caps the
# repr length of objects that reprlib has no dedicated formatter for.
aRepr = reprlib.Repr()
aRepr.maxother = 300

# Load the raw data, then pass features, target and test features through
# feature engineering.
X, y, test_X, test_ids = load_data()
X, y, test_X = feat_eng(X, y, test_X)

# Report the size of each dataset, preferring `.shape` when it is available.
if hasattr(X, 'shape'):
    print(f"X.shape: {X.shape}")
else:
    print(f"X length: {len(X)}")

if hasattr(y, 'shape'):
    print(f"y.shape: {y.shape}")
else:
    print(f"y length: {len(y)}")

if hasattr(test_X, 'shape'):
    print(f"test_X.shape: {test_X.shape}")
else:
    print(f"test_X length: {len(test_X)}")

print(f"test_ids length: {len(test_ids)}")

# NOTE(review): test_size=0.8 sends 80% of the rows to the validation split
# and leaves only 20% for training — presumably to keep this test fast;
# confirm this ratio is intended.
train_X, val_X, train_y, val_y = train_test_split(
    X,
    y,
    test_size=0.8,
    random_state=42,
)


import sys
import reprlib
from joblib.memory import MemorizedFunc


def get_original_code(func):
    """Return the code object behind ``func``.

    A joblib ``MemorizedFunc`` keeps the callable it wraps in its ``func``
    attribute, so the code object is read from there; any other callable is
    expected to expose ``__code__`` directly.
    """
    target = func.func if isinstance(func, MemorizedFunc) else func
    return target.__code__

# Dump a truncated repr of each split produced by train_test_split.
for _label, _split in (
    ("train_X:", train_X),
    ("train_y:", train_y),
    ("val_X:", val_X),
    ("val_y:", val_y),
):
    print(_label, aRepr.repr(_split))

# Report the size of each split, preferring `.shape` when it is available.
if hasattr(train_X, 'shape'):
    print(f"train_X.shape: {train_X.shape}")
else:
    print(f"train_X length: {len(train_X)}")

if hasattr(train_y, 'shape'):
    print(f"train_y.shape: {train_y.shape}")
else:
    print(f"train_y length: {len(train_y)}")

if hasattr(val_X, 'shape'):
    print(f"val_X.shape: {val_X.shape}")
else:
    print(f"val_X length: {len(val_X)}")

if hasattr(val_y, 'shape'):
    print(f"val_y.shape: {val_y.shape}")
else:
    print(f"val_y length: {len(val_y)}")



def debug_info_print(func):
    """Wrap ``func`` so that its local variables are printed when it returns.

    While the wrapped function runs, a trace function is installed with
    ``sys.settrace``. When the frame executing ``func``'s own code object
    (see ``get_original_code``) emits a "return" event, every local variable
    of that frame is printed using the module-level ``aRepr`` truncating repr.

    Args:
        func: The callable to instrument (a plain function or a joblib
            ``MemorizedFunc``).

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result unchanged.
    """
    def wrapper(*args, **kwargs):
        original_code = get_original_code(func)

        def local_trace(frame, event, arg):
            # Only the frame running the wrapped function's code matters.
            # Returning None for any other frame's "call" event leaves that
            # frame untraced, so nested library calls do not pay the
            # per-line tracing overhead.
            if frame.f_code != original_code:
                return None
            if event == "return":
                print("\n" + "="*20 + "Running model training code, local variable values:" + "="*20)
                for k, v in frame.f_locals.items():
                    printed = aRepr.repr(v)
                    print(f"{k}:\n {printed}")
                print("="*20 + "Local variable values end" + "="*20)
            return local_trace

        # Remember whatever trace function was active (debugger, coverage
        # tool, ...) so it is reinstated afterwards instead of being dropped.
        previous_trace = sys.gettrace()
        sys.settrace(local_trace)
        try:
            return func(*args, **kwargs)
        finally:
            sys.settrace(previous_trace)
    return wrapper

# First execution: fit on the training split and predict on the validation
# split; no test data is passed.
first_run_kwargs = dict(
    X=train_X,
    y=train_y,
    val_X=val_X,
    val_y=val_y,
    test_X=None,
)
print("The first execution begins.\n")
start_time = time.time()
val_pred, test_pred, hypers = debug_info_print(model_workflow)(**first_run_kwargs)
log_execution_results(start_time, val_pred, test_pred, hypers, "The first execution")

# Second execution: no validation data; predict on the test set while
# passing in the hyperparameters returned by the first execution.
second_run_kwargs = dict(
    X=train_X,
    y=train_y,
    val_X=None,
    val_y=None,
    test_X=test_X,
    hyper_params=hypers,
)
print("The second execution begins.\n")
start_time = time.time()
val_pred, test_pred, final_hypers = debug_info_print(model_workflow)(**second_run_kwargs)
log_execution_results(start_time, val_pred, test_pred, final_hypers, "The second execution")

print("Model code test end.")
