Creating custom functions or classes for the kfold pipeline
Introduction
In EEG decoding pipelines, it is common to evaluate model performance using cross-validation techniques such as k-fold, which is the approach used here. To ensure proper modularity and prevent data leakage across folds, the pipeline is typically divided into two main transformation stages: pre-folding and post-folding.
Pre-folding refers to all operations applied to the data before the fold split. These transformations are shared across all folds and must be strictly independent of the training/testing process. Common examples include filtering, artifact rejection, or epoch extraction.
Post-folding transformations, defined through the pos_folding dictionary, are applied within each fold, only after the data has been split into training and testing sets. This guarantees that feature extraction, normalization, or classifier training is done independently for each fold, which is essential to avoid information leakage and to obtain reliable cross-validation results.
The separation between these two stages is crucial. Applying operations like scaling or feature extraction globally before folding would cause the test data to influence the learned transformations—this violates cross-validation assumptions and leads to overly optimistic performance estimates.
This tutorial focuses on creating custom pre_folding and pos_folding components, either functions or classes, that can be seamlessly integrated into the kfold pipeline. In the bciflow package there is one important distinction between the two stages: both accept custom functions, but only pos_folding accepts custom classes.
Basic Usage Pattern
An example usage within the kfold pipeline looks like this:
tf = function                 # used only in pre_folding
tf2 = function                # can be used in both stages
sf = Class() or function      # used only in pos_folding
fe = Class() or function
fs = Class() or function

pre_folding = {'tf': (tf, {})}
pos_folding = {
    'tf2': (tf2, {}),
    'sf': (sf, {}),
    'fe': (fe, {'flattening': True}),
    'fs': (fs, {}),
    'clf': (lda(), {})
}
results = kfold(
    target=dataset,
    start_window=dataset['events']['cue'][0] + 0.5,
    pre_folding=pre_folding,
    pos_folding=pos_folding
)
Each key in the pre_folding and pos_folding dictionaries must map to a tuple (object, kwargs), where:
- object is a function or a class instance
- kwargs is a dictionary of keyword arguments passed to its transform method or function call
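To make the (object, kwargs) contract concrete, the sketch below shows one plausible way a pipeline could consume such a tuple; the actual internals of bciflow's kfold may differ, and apply_step and scale are hypothetical names introduced only for this illustration.

```python
import numpy as np

# Hypothetical sketch of how a pipeline might consume a (object, kwargs) tuple.
def apply_step(step, eegdata):
    obj, kwargs = step
    if callable(obj) and not hasattr(obj, 'transform'):
        # plain function: called directly with the keyword arguments
        return obj(eegdata, **kwargs)
    # class instance: fit on the data, then transform it
    return obj.fit_transform(eegdata, **kwargs)

# toy step: a function that scales the signal by a given factor
def scale(eegdata, factor=1.0):
    eegdata['X'] = eegdata['X'] * factor
    return eegdata

eeg = {'X': np.ones((2, 1, 3, 4))}  # (trials, bands, electrodes, time)
eeg = apply_step((scale, {'factor': 2.0}), eeg)
print(eeg['X'][0, 0, 0, 0])  # 2.0
```

The kwargs dictionary is where per-step options such as {'flattening': True} end up.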
Component Naming Conventions
The dictionary keys typically follow standardized abbreviations to identify the type of transformation being applied:
- sf: Spatial Filter (e.g., CSP, xDAWN, or ICA). Can be applied in both pre_folding and pos_folding, depending on whether it requires supervision.
- tf: Temporal Filter (e.g., bandpass or notch filters). Usually appears in pre_folding, but can also be applied in pos_folding if it requires adaptation to training data.
- fs: Feature Selection. Selects relevant features (e.g., variance threshold, mutual information). Must appear only in pos_folding to avoid data leakage.
- fe: Feature Extraction. Transforms the data into a feature space (e.g., mean amplitude, power spectral density). Always performed in pos_folding.
- clf: Classifier. The final predictive model (e.g., LDA, SVM). Defined in pos_folding.
Pipeline Structure Considerations
Filters (sf, tf) may be safely applied in both pre_folding and pos_folding, depending on whether the transformation is unsupervised (e.g., FIR filters) or supervised (e.g., CSP). In contrast, operations such as fs (feature selection) and fe (feature extraction) must be strictly placed in the pos_folding stage to ensure that only training data is used for parameter estimation, thereby preserving the validity of the cross-validation protocol.
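The leakage argument above can be seen with a toy calculation: a statistic estimated on the full dataset differs from one estimated on the training split alone, so fitting a supervised or data-dependent step before the fold split lets test samples influence the transformation they are later evaluated with.

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(loc=5.0, scale=2.0, size=100)
train, test = data[:80], data[80:]

mu_train = train.mean()   # correct: estimated on the training split only
mu_all = data.mean()      # leaky: test samples included in the estimate

test_correct = test - mu_train
test_leaky = test - mu_all  # test data helped choose this offset

print(mu_train != mu_all)  # True: the two estimates differ
```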
Basic Requirements
1. If You Use a Function
The function must have the following signature:
def my_function(eegdata: dict, **kwargs):
    ...
    return eegdata_transformed
Requirements:
- Input: eegdata, plus optional keyword arguments
- Output: a modified version of eegdata
Example:
This function removes the mean of the EEG signal along the time dimension, effectively centering the signal for each trial, band, and electrode.
import numpy as np

def removeEEGSignalMean(eegdata):
    X = eegdata['X'].copy()
    # Compute the mean over the time axis
    mean = np.mean(X, axis=-1, keepdims=True)  # shape: (trials, bands, electrodes, 1)
    # Subtract the mean from the signal
    X_ = X - mean
    eegdata['X'] = X_  # shape: (trials, bands, electrodes, time)
    return eegdata
Usage:
pre_folding = {}
pos_folding = {
    'tf': (removeEEGSignalMean, {}),
    ...
    'clf': (lda(), {})
}
Or
pre_folding = {'tf': (removeEEGSignalMean, {})}
pos_folding = {
    ...
    'clf': (lda(), {})
}
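Before wiring the function into a pipeline, it can be sanity-checked in isolation on synthetic data (the definition is repeated in condensed form so the snippet runs on its own): after centering, the mean along the time axis should be numerically zero for every trial, band, and electrode.

```python
import numpy as np

# Condensed version of the function above, for a self-contained check
def removeEEGSignalMean(eegdata):
    eegdata['X'] = eegdata['X'] - eegdata['X'].mean(axis=-1, keepdims=True)
    return eegdata

rng = np.random.default_rng(42)
eeg = {'X': rng.normal(loc=3.0, scale=1.0, size=(5, 2, 8, 100))}  # (trials, bands, electrodes, time)

eeg = removeEEGSignalMean(eeg)
print(np.allclose(eeg['X'].mean(axis=-1), 0.0))  # True: signal is centered per trial/band/electrode
```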
2. If You Use a Class
Your class must implement the following methods:
class MyTransformer:
    def fit(self, eegdata: dict, **kwargs):
        ...
        return self

    def transform(self, eegdata: dict, **kwargs):
        ...
        return eegdata_transformed

    def fit_transform(self, eegdata: dict, **kwargs):
        ...
        return self.fit(eegdata).transform(eegdata)
Expected Return Types
All custom steps must comply with the return format expected by the pipeline:
- The fit() method of a class must return self.
- The transform() method of a class, and any standalone function, must return a dict-like object with the structure of eegdata.
The eegdata dictionary typically includes a key 'X', which contains the EEG data in a 4D array of shape (trials, bands, electrodes, time) or its flattened variant if flattening=True is passed.
Note: you must always return the updated eegdata dictionary, even if you perform operations in place, to ensure the pipeline remains functional and modular.
Example:
This class performs standardization (Z-score) across the EEG time domain, considering the shape (trials, bands, electrodes, time).
import numpy as np

class StandardScalerEEG:
    def __init__(self):
        pass

    def fit(self, eegdata: dict):
        X = eegdata['X']
        bands, electrodes = X.shape[1], X.shape[2]
        # Pool trials and time so statistics are computed per band and electrode
        X_reshaped = X.transpose(1, 2, 0, 3).reshape(bands, electrodes, -1)
        self.mean_ = np.mean(X_reshaped, axis=-1, keepdims=True)  # shape: (bands, electrodes, 1)
        self.std_ = np.std(X_reshaped, axis=-1, keepdims=True)    # shape: (bands, electrodes, 1)
        return self

    def transform(self, eegdata: dict):
        X = eegdata['X']  # (trials, bands, electrodes, time)
        X_trans = X.transpose(1, 2, 0, 3)  # (bands, electrodes, trials, time)
        X_scaled = (X_trans - self.mean_[..., None]) / self.std_[..., None]
        X_scaled = X_scaled.transpose(2, 0, 1, 3)  # back to (trials, bands, electrodes, time)
        eegdata['X'] = X_scaled
        return eegdata

    def fit_transform(self, eegdata: dict):
        return self.fit(eegdata).transform(eegdata)
Usage:
pre_folding = {}
pos_folding = {
    'sf': (StandardScalerEEG(), {}),
    ...
    'clf': (lda(), {})
}
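The per-fold behavior the pipeline presumably applies can also be exercised by hand: fit the scaler on the training split only, then reuse the same statistics on the test split. The snippet below is a condensed, self-contained sketch of that flow, using axis-wise reductions that are equivalent to the transpose/reshape version above.

```python
import numpy as np

# Condensed scaler: statistics per (band, electrode), pooled over trials and time
class StandardScalerEEG:
    def fit(self, eegdata):
        X = eegdata['X']  # (trials, bands, electrodes, time)
        self.mean_ = X.mean(axis=(0, 3), keepdims=True)
        self.std_ = X.std(axis=(0, 3), keepdims=True)
        return self

    def transform(self, eegdata):
        eegdata['X'] = (eegdata['X'] - self.mean_) / self.std_
        return eegdata

rng = np.random.default_rng(0)
train = {'X': rng.normal(2.0, 4.0, size=(20, 1, 4, 50))}
test = {'X': rng.normal(2.0, 4.0, size=(5, 1, 4, 50))}

scaler = StandardScalerEEG().fit(train)   # statistics from training data only
train = scaler.transform(train)
test = scaler.transform(test)             # same statistics reused on the test split

print(np.allclose(train['X'].mean(axis=(0, 3)), 0.0))  # True
```

Note that the test split is standardized but not exactly zero-mean, precisely because its own statistics were never used.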
These examples demonstrate how both object-oriented and functional styles can be effectively integrated into the pipeline.