"""
Tests for `load_data` in load_data.py
"""

import pickle

import pandas as pd
from load_data import load_data

import sys
import reprlib
from joblib.memory import MemorizedFunc


def get_original_code(func):
    """Return the code object that actually runs when ``func`` is called.

    If ``func`` is a joblib ``MemorizedFunc``, the code object is taken from
    the callable stored in its ``func`` attribute; otherwise it is taken from
    ``func`` itself.
    """
    target = func.func if isinstance(func, MemorizedFunc) else func
    return target.__code__


def debug_info_print(func):
    """Wrap ``func`` so that its local variables are printed when it returns.

    For the duration of each call a trace function is installed with
    ``sys.settrace``. When the frame executing ``func``'s own code object (as
    resolved by ``get_original_code``) emits its ``return`` event, every local
    variable of that frame is printed using a size-limited ``reprlib``
    representation.

    Args:
        func: The callable to wrap. ``get_original_code`` must be able to
            resolve its code object.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result unchanged.
    """
    limited_repr = reprlib.Repr()
    # Raise the character limit for objects that reprlib has no dedicated
    # formatter for, so larger values are not cut off too early.
    limited_repr.maxother = 300

    def wrapper(*args, **kwargs):
        original_code = get_original_code(func)

        def local_trace(frame, event, arg):
            # Returning None on another frame's "call" event switches off
            # local tracing for that frame, so only the target function pays
            # the per-line tracing cost.
            if frame.f_code != original_code:
                return None
            if event == "return":
                print("\n" + "=" * 20 + "Running data_load code, local variable values:" + "=" * 20)
                for name, value in frame.f_locals.items():
                    printed = limited_repr.repr(value)
                    print(f"{name}:\n {printed}")
                print("=" * 20 + "Local variable values end" + "=" * 20)
            return local_trace

        # Remember whatever tracer was active (debugger, coverage, ...) so it
        # can be reinstated afterwards instead of being silently removed.
        previous_trace = sys.gettrace()
        sys.settrace(local_trace)
        try:
            return func(*args, **kwargs)
        finally:
            sys.settrace(previous_trace)

    return wrapper

# Call load_data through the tracing wrapper so that its local variables are
# printed when it returns, then unpack the four values it hands back.
X, y, X_test, test_ids = debug_info_print(load_data)()


def get_length(data):
    """Return the size of ``data`` along its first dimension.

    Objects exposing a ``shape`` attribute report ``shape[0]``; anything else
    falls back to ``len(data)``.
    """
    if not hasattr(data, "shape"):
        return len(data)
    return data.shape[0]


def get_width(data):
    """Return the per-item dimensions of ``data``.

    For objects with a ``shape`` attribute this is ``shape[1:]``, i.e. every
    dimension after the first (a sliced sequence, not a single number). For
    anything without ``shape`` the width is the integer ``1``.
    """
    if hasattr(data, "shape"):
        trailing_dims = data.shape[1:]
        return trailing_dims
    return 1


def get_column_list(data):
    """Return the column labels of a pandas DataFrame as a list.

    Any input that is not a ``pd.DataFrame`` yields ``None``.
    """
    if not isinstance(data, pd.DataFrame):
        return None
    return data.columns.tolist()

# Every value returned by the loader must be present.
assert X is not None, "Training data (X) is None."
assert y is not None, "Training labels (y) are None."
assert X_test is not None, "Test data (X_test) is None."
assert test_ids is not None, "Test IDs (test_ids) are None."

# Test rows must pair up with test IDs, and training rows with labels.
assert get_length(X_test) == get_length(test_ids), (
    "Mismatch in length of test images and test IDs: "
    f"X_test ({get_length(X_test)}) and test_ids ({get_length(test_ids)})"
)
assert get_length(X) == get_length(y), (
    "Mismatch in length of training images and labels: "
    f"X ({get_length(X)}) and y ({get_length(y)})"
)

# None of the datasets may be empty.
assert get_length(X) != 0, "Training data is empty."
assert get_length(y) != 0, "Training labels are empty."
assert get_length(X_test) != 0, "Test data is empty."

# Training and test data must agree on every dimension after the first.
assert get_width(X) == get_width(
    X_test
), "Mismatch in width of training and test data. Width means the number of features."

# Column names (and their order) are only comparable when both are DataFrames.
if isinstance(X, pd.DataFrame) and isinstance(X_test, pd.DataFrame):
    assert get_column_list(X) == get_column_list(X_test), "Mismatch in column names of training and test data."

print("Data loader test passed successfully. Length of test images matches length of test IDs.")
