Source code for bciflow.modules.core.kfold

import inspect
import numbers
import warnings
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

from ..core.util import util

def _as_window_list(value: Any, arg_name: str) -> list:
    '''Normalise a window-start argument into a non-empty list of start times.

    Parameters
    ----------
    value : number or iterable of numbers
        A single start time (any real scalar: float, int, numpy number) or an
        iterable of start times.
    arg_name : str
        Name of the argument, used in error messages.

    Returns
    -------
    list
        The start times as a list.

    Raises
    ------
    ValueError
        If ``value`` is a bool, a string, not iterable, or empty.
    '''
    # Accept every real scalar; a bare ``type(x) is float`` check would reject
    # ints and numpy scalars, which then crash when iterated.
    if isinstance(value, numbers.Real) and not isinstance(value, bool):
        return [value]
    message = f"{arg_name} must be a number or a list of numbers, got {value!r}"
    if isinstance(value, (str, bytes)):
        raise ValueError(message)
    try:
        windows = list(value)
    except TypeError as err:
        raise ValueError(message) from err
    if not windows:
        raise ValueError(f"{arg_name} must not be empty")
    return windows


def _prepare_windows(target: Dict[str, Any], windows: list,
                     pre_folding: Dict[str, tuple],
                     window_size: float) -> Dict[Any, Dict[str, Any]]:
    '''Crop ``target`` at every start in ``windows`` and apply the ``pre_folding`` steps to each crop.'''
    prepared = {}
    for tmin_ in windows:
        prepared[tmin_] = util.crop(data=target, tmin=tmin_, window_size=window_size, inplace=False)
    for tmin_ in windows:
        for pre_func in pre_folding.values():
            # Plain functions are applied directly; other objects are assumed
            # to be already fitted and contribute their ``transform`` method.
            func = pre_func[0] if inspect.isfunction(pre_func[0]) else pre_func[0].transform
            prepared[tmin_] = util.apply_to_trials(data=prepared[tmin_], func=func,
                                                   func_param=pre_func[1], inplace=False)
    return prepared


def _apply_pos_folding(pos_folding: Dict[str, tuple], train: Dict[str, Any],
                       tests: Dict[Any, Dict[str, Any]]):
    '''Fit every non-``'clf'`` ``pos_folding`` step on ``train`` and apply it to each test window.

    Returns the transformed training set and the (in-place updated) ``tests`` dict.
    '''
    for name, pos_func in pos_folding.items():
        if name == 'clf':
            continue
        func, params = pos_func[0], pos_func[1]
        if inspect.isfunction(func):
            train = func(train, **params)
            for tmin_ in tests:
                tests[tmin_] = func(tests[tmin_], **params)
        else:
            # Fit on the training split only, then reuse the fitted state on
            # the test split to avoid leaking test information.
            train = func.fit_transform(train, **params)
            for tmin_ in tests:
                tests[tmin_] = func.transform(tests[tmin_])
    return train, tests


def kfold(target: Dict[str, Any],
          start_window: Union[float, List[float]],
          start_test_window: Optional[Union[float, List[float]]] = None,
          pre_folding: Optional[Dict[str, tuple]] = None,
          pos_folding: Optional[Dict[str, tuple]] = None,
          window_size: float = 1.0,
          source: Optional[list] = None,
          n_splits: int = 5,
          random_state: Optional[int] = 42) -> pd.DataFrame:
    '''Run a stratified k-fold cross-validation on an eegdata dictionary.

    The data are cropped once per window start (training and test windows
    alike) and preprocessed with ``pre_folding``. The trials are then split
    with ``StratifiedKFold`` on ``target['y']``. Inside every fold the
    ``pos_folding`` steps and the classifier are fitted on the training trials
    and applied to the test trials of each test window.

    Parameters
    ----------
    target : dict
        Input EEG data as an eegdata dictionary. This function itself reads
        ``'y'`` (labels used for stratification) and ``'y_dict'``. The whole
        dictionary is also passed to ``util.crop``. Typical keys:

        - 'X': The EEG data as a numpy array.
        - 'y': The labels corresponding to the EEG data.
        - 'sfreq': The sampling frequency of the EEG data.
        - 'y_dict': A dictionary mapping the label names to the integer
          codes stored in 'y'.
        - 'events': A dictionary describing the event markers.
        - 'ch_names': A list of channel names.
        - 'tmin': The start time of the EEG data.
    start_window : float, int or list of numbers
        Start time(s) of the crop window(s) used to build the training set.
        Training trials from all listed windows are concatenated.
    start_test_window : float, int or list of numbers, optional
        Start time(s) of the crop window(s) evaluated as test sets. Defaults
        to ``start_window``. Each test window produces its own result rows.
    pre_folding : dict, optional
        Steps applied to every cropped window before the data are split.
        Keys are step names; values are ``(func, params)`` tuples. A plain
        function is applied through ``util.apply_to_trials``. Any other
        object contributes its ``transform`` method and is not fitted here.
    pos_folding : dict
        Steps fitted inside each fold, applied in dictionary order. Values
        are ``(func, params)`` tuples. A plain function is called as
        ``func(data, **params)`` on the training and test sets. Any other
        object is fitted with ``fit_transform(train, **params)`` and applied
        to the test sets with ``transform(test)``. The ``'clf'`` key is
        required and holds ``(classifier, fit_params)``. The classifier is
        trained with ``fit(X, y, **fit_params)`` and queried with
        ``predict_proba``.
    window_size : float
        Length of each crop window, passed to ``util.crop``.
    source : list, optional
        EEG data from other subjects, reserved for the Transfer Learning
        modules. Currently not used by this function.
    n_splits : int
        Number of folds (default 5).
    random_state : int or None
        Seed for the shuffled fold assignment (default 42).

    Returns
    -------
    results : pandas.DataFrame
        One row per (fold, test window, test trial). Columns are ``'fold'``
        (1-based), ``'tmin'``, ``'true_label'`` and one column per key of
        ``target['y_dict']``. Those last columns hold the predicted
        probabilities rounded to 4 decimals. Numeric columns keep numeric
        dtypes.

    Raises
    ------
    ValueError
        If a window argument is not a number or a non-empty iterable of
        numbers, or if ``pos_folding`` has no ``'clf'`` entry.

    Warns
    -----
    RuntimeWarning
        When ``predict_proba`` fails. Zero probabilities are recorded for
        that fold and window instead.

    Example
    -------
    >>> import numpy as np
    >>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
    >>> from bciflow.modules.core.kfold import kfold
    >>> target = {
    ...     'X': np.random.rand(100, 64, 768),
    ...     'y': np.random.randint(0, 2, size=100),
    ...     'sfreq': 256,
    ...     'y_dict': {'class_0': 0, 'class_1': 1},
    ...     'events': {'event_1': [0, 50], 'event_2': [51, 100]},
    ...     'ch_names': [f'ch_{i}' for i in range(64)],
    ...     'tmin': -0.5,
    ... }
    >>> def flatten(eegdata):
    ...     eegdata = dict(eegdata)
    ...     eegdata['X'] = eegdata['X'].reshape(len(eegdata['X']), -1)
    ...     return eegdata
    >>> pos_folding = {'flatten': (flatten, {}),
    ...                'clf': (LinearDiscriminantAnalysis(), {})}
    >>> results = kfold(target, start_window=0.0, start_test_window=[0.0, 0.5],
    ...                 pos_folding=pos_folding)  # doctest: +SKIP
    >>> results.head()  # doctest: +SKIP
    '''
    start_window = _as_window_list(start_window, 'start_window')
    if start_test_window is None:
        start_test_window = start_window
    else:
        start_test_window = _as_window_list(start_test_window, 'start_test_window')
    if pre_folding is None:
        pre_folding = {}
    if pos_folding is None:
        pos_folding = {}
    # Fail fast, before any cropping work, when no classifier is configured.
    if 'clf' not in pos_folding:
        raise ValueError("pos_folding must contain a 'clf' entry: (classifier, fit_params)")

    # Crop every window used for training OR testing. Cropping only the test
    # windows makes training fail whenever the two lists differ.
    all_windows = list(dict.fromkeys([*start_window, *start_test_window]))
    target_dict = _prepare_windows(target, all_windows, pre_folding, window_size)

    n_classes = len(target['y_dict'])
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    rows = []
    for fold_id, (train_index, test_index) in enumerate(skf.split(target['y'], target['y']), start=1):
        target_train = util.concatenate(
            [util.get_trial(data=target_dict[tmin_], ids=train_index) for tmin_ in start_window])
        target_test = {tmin_: util.get_trial(data=target_dict[tmin_], ids=test_index)
                       for tmin_ in start_test_window}
        target_train, target_test = _apply_pos_folding(pos_folding, target_train, target_test)

        clf, clf_param = pos_folding['clf']
        # Plain functions are not fitted; only estimator-like objects are.
        if not inspect.isfunction(clf):
            clf = clf.fit(target_train['X'], target_train['y'], **clf_param)

        for tmin_ in start_test_window:
            try:
                y_pred = clf.predict_proba(target_test[tmin_]['X'])
            except Exception as exc:
                # Best-effort: keep the run going, but make the failure visible.
                warnings.warn(
                    f"predict_proba failed for fold {fold_id}, tmin={tmin_} ({exc!r}); "
                    "recording zero probabilities instead.",
                    RuntimeWarning, stacklevel=2)
                y_pred = np.zeros((len(target_test[tmin_]['y']), n_classes))
            y_pred = np.round(y_pred, 4)
            # NOTE(review): assumes predict_proba's column order matches the
            # key order of target['y_dict'] — confirm for each classifier.
            for true_y, proba in zip(target_test[tmin_]['y'], y_pred):
                rows.append([fold_id, tmin_,
                             util.find_key_with_value(target['y_dict'], true_y),
                             *proba])

    # Build the frame from the row list directly. Routing the rows through
    # np.array would coerce every column (fold, tmin, probabilities) to strings.
    return pd.DataFrame(rows, columns=['fold', 'tmin', 'true_label', *target['y_dict'].keys()])