"""
Tests for `feat_eng` in feature.py
"""


from copy import deepcopy
import sys
import numpy as np
import pandas as pd
from feature import feat_eng
from load_data import load_data
import reprlib
# Size-limited repr helper used for all debug output in this script; values
# without a dedicated reprlib formatter are truncated to 300 characters.
aRepr = reprlib.Repr()
aRepr.maxother = 300

# Load the data and echo a truncated view of every returned object.
X, y, X_test, test_ids = load_data()
print("X:", aRepr.repr(X))
print("y:", aRepr.repr(y))
print("X_test:", aRepr.repr(X_test))
print("test_ids", aRepr.repr(test_ids))

# Report the shape where the object exposes one, otherwise its length.
if hasattr(X, "shape"):
    print(f"X.shape: {X.shape}")
else:
    print(f"X length: {len(X)}")

if hasattr(y, "shape"):
    print(f"y.shape: {y.shape}")
else:
    print(f"y length: {len(y)}")

if hasattr(X_test, "shape"):
    print(f"X_test.shape: {X_test.shape}")
else:
    print(f"X_test length: {len(X_test)}")

print(f"test_ids length: {len(test_ids)}")

# Keep independent copies of the loaded data so later checks can compare the
# feature-engineered output against what the loader originally produced.
X_loaded, y_loaded, X_test_loaded = deepcopy(X), deepcopy(y), deepcopy(X_test)

import sys
import reprlib
from joblib.memory import MemorizedFunc


def get_original_code(func):
    """Return the code object of ``func``.

    When ``func`` is a joblib ``MemorizedFunc``, the code object of the
    function it wraps (``func.func``) is returned instead of the wrapper's.
    """
    target = func.func if isinstance(func, MemorizedFunc) else func
    return target.__code__


def debug_info_print(func):
    """Wrap ``func`` so that its local variables are printed when it returns.

    The returned wrapper installs a ``sys.settrace`` tracer for the duration
    of the call. When the frame executing the code object reported by
    ``get_original_code(func)`` emits a ``"return"`` event, every local
    variable of that frame is printed through ``aRepr``.

    Args:
        func: The callable to instrument (a plain function or a joblib
            ``MemorizedFunc``).

    Returns:
        A wrapper with signature ``(*args, **kwargs)`` that forwards to
        ``func`` and returns its result.
    """
    def wrapper(*args, **kwargs):
        original_code = get_original_code(func)

        def local_trace(frame, event, arg):
            # Only the frame running the target code needs a local tracer.
            # Returning None for any other frame avoids line-by-line tracing
            # of everything func calls; the global tracer is still consulted
            # on each new call, so the target frame is found however deeply
            # it is nested.
            if frame.f_code != original_code:
                return None
            if event == "return":
                print("\n" + "="*20 + "Running feat_eng code, local variable values:" + "="*20)
                for k, v in frame.f_locals.items():
                    printed = aRepr.repr(v)
                    print(f"{k}:\n {printed}")
                print("="*20 + "Local variable values end" + "="*20)
            return local_trace

        # Remember whatever tracer was active (debugger, coverage, ...) so it
        # can be reinstated afterwards instead of being silently dropped.
        previous_trace = sys.gettrace()
        sys.settrace(local_trace)
        try:
            return func(*args, **kwargs)
        finally:
            sys.settrace(previous_trace)
    return wrapper
# Run feature engineering through the tracing wrapper so its local variables
# are printed, then rebind the data names to the engineered outputs.
traced_feat_eng = debug_info_print(feat_eng)
X, y, X_test = traced_feat_eng(X, y, X_test)


def get_length(data):
    """Return the number of samples in ``data``.

    Objects exposing a ``shape`` attribute report ``shape[0]``; everything
    else falls back to ``len(data)``.
    """
    if not hasattr(data, "shape"):
        return len(data)
    return data.shape[0]


def get_width(data):
    """Return the per-sample dimensions ("width") of ``data``.

    Objects exposing a ``shape`` attribute report ``shape[1:]`` (a tuple of
    every dimension after the first). Anything without a ``shape`` — plain
    lists, but also tuples and other sequences — is treated as having a
    width of ``1``. This mirrors the ``hasattr(data, 'shape')`` dispatch used
    by ``get_length`` so both helpers accept the same kinds of input instead
    of raising ``AttributeError`` on non-list sequences.
    """
    if hasattr(data, "shape"):
        return data.shape[1:]
    return 1


def get_column_list(data):
    """Return the column labels of a DataFrame as a list, or ``None`` otherwise."""
    if isinstance(data, pd.DataFrame):
        return data.columns.tolist()
    return None


# --- Sanity checks on the outputs of the feature engineering step ---------
# The checks run in order; the first failing assertion determines the
# message that is reported.

# None of the returned objects may be None.
assert X is not None, "The feature engineering function returned None for X."
assert y is not None, "The feature engineering function returned None for y."
assert X_test is not None, "The feature engineering function returned None for X_test."

# Sample counts must still line up: test data vs. test IDs, and training
# data vs. training labels.
assert get_length(X_test) == get_length(
    test_ids
), f"Mismatch in length of test images and test IDs: X_test ({get_length(X_test)}) and test_ids ({get_length(test_ids)})"
assert get_length(X) == get_length(
    y
), f"Mismatch in length of training images and labels: X ({get_length(X)}) and y ({get_length(y)})"

# Nothing may be empty after feature engineering.
assert get_length(X) != 0, f"Training data is empty."
assert get_length(y) != 0, f"Training labels are empty."
assert get_length(X_test) != 0, f"Test data is empty."

# Training and test data must have the same per-sample dimensions.
assert get_width(X) == get_width(
    X_test
), "Mismatch in width of training and test data. Width means the number of features."

# For DataFrames, the column labels (including their order) must match too.
if isinstance(X, pd.DataFrame) and isinstance(X_test, pd.DataFrame):
    assert get_column_list(X) == get_column_list(X_test), "Mismatch in column names of training and test data."

# Verify that feature engineering did not introduce new kinds of dtypes.
# The comparison needs `.dtypes` on both the engineered and the originally
# loaded training data, so it only runs when both are DataFrames.
if isinstance(X, pd.DataFrame) and isinstance(X_loaded, pd.DataFrame):
    def normalize_dtype(dtype):
        """Collapse every numeric dtype to "numeric"; otherwise return ``str(dtype)``."""
        try:
            is_number = np.issubdtype(dtype, np.number)
        except TypeError:
            # np.issubdtype cannot interpret pandas extension dtypes (e.g.
            # category, nullable Int64). Fall back to pandas' own check,
            # excluding booleans so they are treated the same way as the
            # NumPy bool dtype (which is not a subtype of np.number).
            is_number = pd.api.types.is_numeric_dtype(dtype) and not pd.api.types.is_bool_dtype(dtype)
        return "numeric" if is_number else str(dtype)

    X_dtypes_unique_sorted = sorted(set(normalize_dtype(dt) for dt in X.dtypes.unique()))
    X_loaded_dtypes_unique_sorted = sorted(set(normalize_dtype(dt) for dt in X_loaded.dtypes.unique()))

    # Normalized dtypes present after feature engineering but absent from the
    # loaded data; "object" itself is never counted as new.
    X_dtypes_unique_sorted_new = [
        dt for dt in X_dtypes_unique_sorted if dt not in X_loaded_dtypes_unique_sorted and dt != "object"
    ]
    # The lists hold normalized dtype *strings*, so the presence of object
    # columns in the loaded data has to be tested with the string "object"
    # (comparing against the np.dtypes.ObjectDType class can never match).
    # New dtypes are tolerated only when the loader produced object columns.
    assert (
        "object" in X_loaded_dtypes_unique_sorted or len(X_dtypes_unique_sorted_new) == 0
    ), f"feature engineering has produced new data types which is not allowed, data loader data types are {X_loaded_dtypes_unique_sorted} and feature engineering data types are {X_dtypes_unique_sorted}"


# Reached only when every assertion above held; report overall success.
print(
    "Feature Engineering test passed successfully. All checks including length, width, and data types have been validated."
)
