# AUTOGENERATED! DO NOT EDIT! File to edit: ../source_nbs/lib_nbs/utils_challenge.ipynb.

# %% auto 0
__all__ = ['majority_filter', 'unique_labelled', 'enforce_min_segment_length', 'label_filter', 'label_continuous_to_list',
           'label_list_to_continuous', 'array_to_df', 'df_to_array', 'file_nonOverlap_reOrg', 'get_VIP',
           'changepoint_assignment', 'changepoint_alpha_beta', 'jaccard_index', 'single_changepoint_error',
           'ensemble_changepoint_error', 'create_binary_segment', 'jaccard_between_segments', 'segment_assignment',
           'metric_anomalous_exponent', 'metric_diffusion_coefficient', 'metric_diffusive_state',
           'check_no_changepoints', 'segment_property_errors', 'extract_ensemble', 'multimode_dist',
           'distribution_distance', 'error_Ensemble_dataset', 'check_prediction_length', 'separate_prediction_values',
           'load_file_to_df', 'error_SingleTraj_dataset', 'when_error_single', 'run_single_task', 'run_ensemble_task',
           'listdir_nohidden', 'codalab_scoring', 'codalab_scoring_local', 'transform_ref_to_res']

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 2
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy import stats
import pandas
from tqdm.auto import tqdm
import warnings
from pathlib import Path
from .models_phenom import models_phenom

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 6
def majority_filter(data : np.ndarray, # array to apply majority filter
                    window_size : int # Size of the window in which the filter is applied.
                   )-> np.ndarray: # Filtered array
    '''
    Replaces every entry of a vector by the most frequent value inside a window centred on it.

    The window spans `window_size` samples and is truncated at both ends of the vector.
    A ValueError is raised when `window_size` is even, since no centred window exists then.
    '''
    if window_size % 2 == 0:
        raise ValueError("Window size must be odd")

    reach = window_size // 2
    total = len(data)
    smoothed = data.copy()

    for position in range(total):
        lo = max(position - reach, 0)
        hi = min(position + reach + 1, total)
        # Votes are always taken from the unfiltered input, never from already smoothed entries
        smoothed[position] = stats.mode(data[lo:hi], keepdims=True).mode[0]

    return smoothed

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 8
def unique_labelled(arr : list # List or array from which to create 
                   )-> tuple : # (uniques, inverse_indices), both float arrays for numeric input
    ''' 
    Transforms the values of an input array to their corresponding label given by the uniques in the array.

    Labels are handed out in order of first appearance (the first distinct value gets label 0,
    the second label 1, ...), unlike `np.unique`, which sorts the unique values.

    Returns
    -------
    tuple
        - uniques: array with the distinct values, in order of first appearance
        - inverse_indices: array with the same length as `arr` holding the label of each element
    '''
    # Maps every distinct value to its label; dicts keep insertion order, so the keys
    # double as the list of uniques.
    first_seen = {}
    labels = []

    for value in arr:
        if value not in first_seen:
            first_seen[value] = len(first_seen)
        labels.append(first_seen[value])

    # Arrays are built once at the end (growing them with np.append per element is quadratic).
    # Starting from an empty float array keeps the float dtype of the previous implementation.
    uniques = np.append(np.array([]), list(first_seen))
    inverse_indices = np.array(labels, dtype=float)

    return uniques, inverse_indices

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 10
def enforce_min_segment_length(data : np.ndarray, 
                               # The input array containing processed data. The array should be one-dimensional 
                               # and can contain numerical or categorical data.
                               min_seg : int = 3,
                               # The minimum length that any segment in the array should have. 
                               # Segments shorter than this length are merged with neighboring segments.
                              )-> np.ndarray : # A NumPy array with no segments shorter than the specified minimum length.
    
    """
    Ensures that all contiguous segments of the same value in a one-dimensional NumPy
    array are at least as long as the specified minimum segment length. If a segment
    is found to be shorter than this threshold, it is merged with an adjacent segment
    to comply with the minimum segment requirement.

    Merging policy: a short segment is absorbed by the segment on its left and takes
    that segment's value. Only a short first segment is absorbed by the segment on its
    right. If the whole array is one segment shorter than `min_seg`, it is returned
    unchanged. An empty input yields an empty copy.
    """
    
    n = len(data)
    if n == 0:
        # Nothing to segment; avoids indexing data[0] on an empty array
        return data.copy()

    # Identify all runs of equal values as (value, first_index, last_index)
    segments = []
    run_start = 0
    for i in range(1, n):
        if data[i] != data[run_start]:
            segments.append((data[run_start], run_start, i - 1))
            run_start = i
    segments.append((data[run_start], run_start, n - 1))

    # Merge small segments. `i` only advances when the current segment is accepted;
    # after a merge the same position is inspected again.
    i = 0
    while i < len(segments):
        _, seg_start, seg_end = segments[i]

        if seg_end - seg_start + 1 >= min_seg:
            i += 1
        elif i > 0:
            # Preferred: extend the previous segment over the short one
            prev_value, prev_start, _ = segments[i - 1]
            segments[i - 1] = (prev_value, prev_start, seg_end)
            del segments[i]
        elif i < len(segments) - 1:
            # Short first segment: it takes the value of the next segment
            next_value, _, next_end = segments[i + 1]
            segments[i] = (next_value, seg_start, next_end)
            del segments[i + 1]
        else:
            # Single short segment with no neighbour to merge into
            i += 1

    # Reconstruct the array from merged segments
    filtered_data = np.empty(n, dtype=data.dtype)
    for value, start, end in segments:
        filtered_data[start:end + 1] = value

    return filtered_data

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 12
def label_filter(label : np.ndarray, # Vector to filter by majority vote
                 window_size : int = 5,  # Size of the window in which the majority filter is applied.
                 min_seg : int = 3 # Minimum segment allowed in the output array
                )-> np.ndarray: # Filtered label vector
    '''
    Smooths a vector of changing labels with a majority filter and then removes every
    segment shorter than `min_seg` by merging it with a neighbouring one.

    A vector without any changepoint is returned untouched (the very same object).
    Otherwise a new float array is returned.
    '''

    # Nothing to smooth when consecutive entries never differ
    if np.count_nonzero(label[1:] != label[:-1]) == 0:
        return label

    # Work on integer-like tags (0, 1, 2, ... in order of appearance) instead of the raw values,
    # so the filtering behaves the same whatever the actual label values are.
    values, tags = unique_labelled(label)

    smoothed_tags = enforce_min_segment_length(majority_filter(tags, window_size), min_seg)

    # Translate the tags back into the original label values
    restored = np.zeros(smoothed_tags.shape, dtype=float)
    for tag, original_value in enumerate(values):
        restored[smoothed_tags == tag] = original_value

    return restored

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 24
def label_continuous_to_list(labs):
    ''' 
    Converts per-timestep labels into a list of changepoints plus the properties of
    each resulting segment.

    A changepoint is placed at every timestep whose row differs from the previous row
    in at least one column. The trajectory length T is always appended as the last
    changepoint.
    
    Parameters
    ----------
    labs : array
        T x 2 labels (anomalous exponent, diffusion coefficient) or T x 3 labels
        (the third column being the diffusive state).
        
    Returns
    -------
    tuple
        - First element is the array of change points (the last one equals T)
        - The rest are the corresponding segment properties (order: alpha, Ds and,
          only for T x 3 input, states)
    '''
    
    with_states = labs.shape[1] == 3

    # Timesteps at which any column changes with respect to the previous timestep
    row_changed = (labs[1:, :] != labs[:-1, :]).sum(1) != 0
    CP = np.append(np.flatnonzero(row_changed) + 1, labs.shape[0])

    # The row right before each changepoint is the last row of the segment it closes
    closing_rows = CP - 1

    alphas = np.zeros(len(CP))
    alphas[:] = labs[closing_rows, 0]

    Ds = np.zeros(len(CP))
    Ds[:] = labs[closing_rows, 1]

    if not with_states:
        return CP, alphas, Ds

    states = np.zeros(len(CP))
    states[:] = labs[closing_rows, 2]
    return CP, alphas, Ds, states

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 28
def label_list_to_continuous(CP, label):
    '''
    Expands a list of change points and per-segment properties into one value per
    timestep. The last change point gives the length of the output.
    
    Parameters
    ----------
    CP : array, list
        list of change points. Last change point indicates label length.
    label : array, list
        list of segment properties, one per segment
        
    Returns
    -------
    array
        Continuous label created from the given change points and segment properties
    '''    
    
    segment_values = np.array(label) if isinstance(label, list) else label

    # One indicator row per segment; the final changepoint is the total length
    indicators = create_binary_segment(CP[:-1], CP[-1])

    # Weight every indicator by its segment value and collapse over segments
    weighted = indicators.transpose() * segment_values
    return weighted.sum(1)

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 32
from .utils_trajectories import segs_inside_fov


def array_to_df(trajs, 
               labels,
               min_length = 10,
               fov_origin = [0,0], fov_length= 100.0, cutoff_length = 10):
    '''
    Builds two dataframes, one with positions and one with labels, out of arrays of
    trajectories and their per-timestep labels, keeping only the pieces of each
    trajectory that `segs_inside_fov` reports as being inside the field of view (FOV).
    If you don't want a field of view, choose a FOV length bigger (smaller) than your
    maximum (minimum) trajectory position.

    Every piece inside the FOV becomes an independent trajectory with its own index.
    Its alpha, D and state labels are smoothed with `label_filter` before being
    converted to changepoints and segment properties.
   
    Parameters
    ----------
    trajs : array 
        Trajectories to store in the df. They are iterated along the first axis;
        columns 0 and 1 of each item are used as x and y.
    labels : array
        Labels to store in the df, iterated along the first axis together with `trajs`.
        Index 0, 1 and 2 of the last axis are read as alpha, D and state.
    min_length : int
        Not used by this function.
    fov_origin : tuple
        Bottom left point of the square defining the FOV.
    fov_length : float
        Size of the box defining the FOV.
    cutoff_length : int
        Minimum length of a trajectory inside the FOV to be considered in the output dataset.
    
    
    Returns
    -------
    tuple
        - df_in (dataframe): dataframe with trajectories (columns traj_idx, x, y)
        - df_out (dataframe): dataframe with labels (columns traj_idx, Ds, alphas, states, changepoints)
    '''
    
    coords_x, coords_y, owner_idx = [], [], []

    df_out = pandas.DataFrame(columns = ['traj_idx', 'Ds', 'alphas', 'states', 'changepoints']) 

    # Running index of the pieces kept so far; it is the traj_idx of both dataframes
    next_idx = 0
    per_trajectory = zip(tqdm(trajs), labels[:, :, 0], labels[:, :, 1], labels[:, :, 2])

    for traj, l_alpha, l_D, l_s in per_trajectory:

        inside = segs_inside_fov(traj, fov_origin, fov_length, cutoff_length)
        if inside is None:
            continue

        for bounds in inside:
            window = slice(bounds[0], bounds[1])

            seg_x = traj[window, 0]
            seg_y = traj[window, 1]

            # Smoothed labels of this piece
            seg_alpha = label_filter(l_alpha[window])
            seg_D = label_filter(l_D[window])
            seg_state = label_filter(l_s[window])

            # Positions go to the input dataframe, tagged with the piece index (stored as float)
            coords_x.extend(seg_x.tolist())
            coords_y.extend(seg_y.tolist())
            owner_idx.extend(np.full(len(seg_x), next_idx, dtype = float).tolist())

            # Per-timestep labels -> changepoints and per-segment properties
            stacked = np.column_stack((seg_alpha, seg_D, seg_state))
            CP, alphas, Ds, states = label_continuous_to_list(stacked)

            # One row per piece in the output dataframe
            df_out.loc[df_out.shape[0]] = [next_idx, Ds, alphas, states, CP]

            next_idx += 1

    # Positions of all pieces as rows of (traj_idx, x, y)
    rows = np.vstack((owner_idx, coords_x, coords_y)).transpose()
    df_in = pandas.DataFrame(rows, columns = ['traj_idx', 'x', 'y'])  
    
    return df_in, df_out

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 37
def df_to_array(df, pad = -1):
    '''
    Transform a dataframe as the ones given in the ANDI 2 challenge (i.e. 4 columns:
    traj_idx, frame, x, y) into a numpy array. To deal with irregular temporal supports,
    we pad the array whenever the trajectory is not present.
    The output array has the typical shape of ANDI datasets: TxNx2

    The number of particles N is `max(traj_idx) + 1` and the number of frames T is
    `max(frame) + 1`, so the rows of the dataframe may come in any order.
    
    Parameters
    ----------
    df : dataframe
        Dataframe with four columns 'traj_idx': the trajectory index, 'frame' the time frame and 
        'x' and 'y' the positions of the particle.
    pad : int
        Number to use as padding.
    
    Returns
    -------
    array
        Float array containing the trajectories from the dataframe, with usual ANDI
        shape (TxNx2). An empty dataframe gives an array of shape (0, 0, 2).
    '''

    if len(df) == 0:
        # max() of an empty column is NaN and cannot be turned into an array size
        return np.full((0, 0, 2), pad, dtype = float)

    frames = df.frame.values.astype(int)
    particles = df.traj_idx.values.astype(int)

    # Size the array from the largest index present, not from the last row of the
    # dataframe, which is only the largest one when the dataframe is sorted.
    array_trajs = np.full((frames.max() + 1, particles.max() + 1, 2), pad, dtype = float)

    # Every row of the dataframe lands at its (frame, particle) cell
    array_trajs[frames, particles, 0] = df.x.values
    array_trajs[frames, particles, 1] = df.y.values
        
    return array_trajs

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 39
from pathlib import Path
import shutil

def file_nonOverlap_reOrg(# Original folder with data produced by datasets_challenge.challenge_phenom_dataset
                          raw_folder: Path,
                          # Folder where to put reorganized files
                          target_folder: Path, 
                          # Number of experiments
                          experiments: int, 
                          # Number of FOVS
                          num_fovs: int,                          
                          # Track to consider
                          tracks = [1,2],
                          # If True, moves all data (also labels,.. etc). Do True only if saving reference / groundtruth data.
                          # Moreover, if True also save the trajectories for the video track
                          save_labels = False, 
                          # Which task to consider
                          task = ['single', 'ensemble'],
                          # If True prints, the percentage of states for each experiment
                          print_percentage = True):
    ''' 
    This considers that you have n_fovs*n_experiments 'fake' experiments 
    and organize them based on the challenge instructions

    The raw folder is expected to contain files named `<name>exp_{k}_fov_0<ext>` with
    k running over all num_fovs*n_experiments fake experiments. Consecutive groups of
    `num_fovs` of them are copied into `target_folder/track_{track}/exp_{exp}/` as
    `<name>fov_{k % num_fovs}<ext>`.

    For each experiment, the last row of the ensemble labels of its FOVs is summed and
    normalised; it replaces the last row of the ensemble labels of the experiment's
    last FOV. The result is printed if `print_percentage` and written to
    `ensemble_labels.txt` in every track folder if `save_labels`.

    `experiments` may be given either as the number of experiments or as a sized
    collection (only its length is used). `task` is not used by this function.
    '''

    raw_folder, target_folder = Path(raw_folder), Path(target_folder)

    # Accept the documented integer as well as the sized collection required previously
    if isinstance(experiments, (int, np.integer)):
        num_experiments = int(experiments)
    else:
        num_experiments = len(experiments)

    total_fovs = num_fovs*num_experiments
    if total_fovs == 0:
        return
    
    if save_labels:
        names_files = ['traj_labs_', 'trajs_', 'videos_', 'ens_labs_', 'vip_idx_']
        extensions = ['.txt', '.csv', '.tiff', '.txt', '.txt']
    else:
        names_files = ['trajs_', 'videos_']
        extensions = ['.csv', '.tiff']

    exp = 0
    # Get model and num_states
    model_exp, num_states = _read_ensemble_header(raw_folder/f'ens_labs_exp_0_fov_0.txt')
    percentage_exp = np.zeros((num_fovs, num_states))

    for k in range(total_fovs):        
        
        # ----- Check when we are done with one experiment and go to next -----
        if (k % num_fovs == 0 and k != 0):
            
            # First save the ensemble information of the current experiment.
            # `ensemble_fov` still holds the labels of the last FOV of that experiment.
            _finalize_experiment_ensemble(ensemble_fov, percentage_exp, num_states, model_exp, exp,
                                          target_folder, tracks, save_labels, print_percentage)

            # Then restart for next experiment
            exp += 1
            model_exp, num_states = _read_ensemble_header(raw_folder/f'ens_labs_exp_{k}_fov_0.txt')
            percentage_exp = np.zeros((num_fovs, num_states))   
            
        
        # ----- Move the folders -----
        for track in tracks:
            (target_folder/(f'track_{track}/'+f'exp_{exp}')).mkdir(parents=True, exist_ok=True)

            # Move single trajectory information
            for name, ext in zip(names_files, extensions):            
                # Track 1 only receives the trajectories when labels are being saved
                if track == 1 and name == 'trajs_' and save_labels == False: continue
                # Track 2 never receives videos nor VIP indices
                if track == 2 and (name == 'videos_' or name == 'vip_idx_'): continue

                shutil.copyfile(src = raw_folder/(name + f'exp_{k}_fov_0'+ext), 
                                dst = target_folder/(f'track_{track}/exp_{exp}/' + name + f'fov_{k%num_fovs}' + ext))
        

        ### ----- Collect ensemble information -----
        ensemble_fov = np.loadtxt(raw_folder/f'ens_labs_exp_{k}_fov_0.txt', 
                                  skiprows = 1, delimiter = ';')
        if num_states > 1:
            percentage_exp[k%num_fovs] = ensemble_fov[-1, :].copy()
            
    # Save the ensemble information of the LAST experiment
    _finalize_experiment_ensemble(ensemble_fov, percentage_exp, num_states, model_exp, exp,
                                  target_folder, tracks, save_labels, print_percentage)


def _read_ensemble_header(path):
    ''' Reads the first line of an ensemble labels file and returns (model name, number of states). '''
    info_exp = np.loadtxt(path, max_rows=1, dtype = str)
    # The second token carries a trailing character (dropped here); the last token is the number of states
    return info_exp[1][:-1], info_exp[-1].astype(int)


def _finalize_experiment_ensemble(ensemble_fov, percentage_exp, num_states, model_exp, exp,
                                  target_folder, tracks, save_labels, print_percentage):
    ''' Overwrites the last row of `ensemble_fov` (in place) with the experiment-wide values and reports/saves it. '''
    if num_states > 1:
        summed = np.sum(percentage_exp, axis = 0)
        summed /= summed.sum()
        ensemble_fov[-1,:] = summed
    if num_states == 1:
        ensemble_fov[-1] = 1
    if print_percentage:
        print(f'Experiment {exp}: {np.round(ensemble_fov[-1], 2)}')

    if save_labels:
        for track in tracks:
            # Mode 'w' already truncates an existing file
            with open(target_folder/f'track_{track}/exp_{exp}/ensemble_labels.txt', 'w') as f:
                f.write(f'model: {model_exp}; num_state: {num_states} \n')
                np.savetxt(f, ensemble_fov, delimiter = ';')

                    

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 41
from scipy.spatial import distance


def get_VIP(array_trajs, num_vip = 5, min_distance_part = 2, pad = -1, 
            boundary = False, boundary_origin = (0,0), min_distance_bound = 0,
            sort_length = True):
    '''
    Given an array of trajectories, finds the VIP particles that participants will
    need to characterize in the video track.
    
    The function first finds the particles that exist at frame 0 (i.e. that their first value 
    is different from pad). Then, iterates over these particles to find num_vip that are at 
    distance > than min_distance_part from each other in the first frame.

    The first attempt starts from the first candidate when `sort_length` is True; any
    other attempt starts from a randomly drawn candidate. After more than 100 attempts
    a ValueError is raised.
    
    Parameters
    ----------
    array_trajs : array
        Position of the trajectories that will be considered for the VIP search
        (indexed as frame, particle, coordinate).
    num_vip : int
        Number of VIP particles to flag.
    min_distance_part : float
        Minimum distance between two VIP particles.
    pad : int
        Number used to indicate in the temporal support that the particle is outside of the FOV.
    boundary : bool, float
        If float, defines the length of the box acting as boundary
    boundary_origin : tuple
        X and Y coords of the boundary
    min_distance_bound : float
        Minimum distance a particle has to be from the boundary in order to be considered a VIP particle
    sort_length : bool
        If True, candidates for VIP particles are chosen in descending trajectory length. This ensures
        that the longest ones are chosen. The length of a candidate is the number of frames
        before its first padded frame.
        
    Returns
    -------
    list
        List of indices of the chosen VIP particles

    Raises
    ------
    ValueError
        If there are fewer candidates than `num_vip`, or no valid set was found.
    
    '''
    present_at_start = array_trajs[0,:,0] != pad

    if not boundary:
        candidates_vip = np.argwhere(present_at_start).flatten()
    else:
        # Keep only particles starting at least min_distance_bound away from every wall
        boundary_x0 = array_trajs[0,:,0] > (boundary_origin[0] + min_distance_bound)
        boundary_xL = array_trajs[0,:,0] < (boundary_origin[0] + boundary - min_distance_bound)
        boundary_y0 = array_trajs[0,:,1] > (boundary_origin[1] + min_distance_bound)
        boundary_yL = array_trajs[0,:,1] < (boundary_origin[1] + boundary - min_distance_bound)
        
        candidates_vip = np.argwhere(boundary_x0 & boundary_xL & boundary_y0 & boundary_yL & present_at_start).flatten()        
        
    if len(candidates_vip) < num_vip:
        raise ValueError('Number of VIP demanded is bigger than available particles.')

    if sort_length:
        is_pad = array_trajs[:, candidates_vip, 0] == pad
        # Length = index of the FIRST padded frame (or the full duration when never padded).
        # Assigning through repeated fancy indices would keep the LAST padded frame instead,
        # giving the same length to every trajectory that ends before the last frame.
        lengths = np.where(is_pad.any(axis = 0), is_pad.argmax(axis = 0), is_pad.shape[0])
        # We sort the particles by their lengths (note the minus for descending order);
        # the stable sort keeps the original order among equally long candidates.
        candidates_vip = candidates_vip[np.argsort(-lengths, kind = 'stable')]

    elected = []
    count_while = 0    
        
    while len(elected) < num_vip:
        
        if sort_length and count_while == 0: 
            # First attempt: start from the longest candidate
            elected = [candidates_vip[0]]
        else:
            # Any later attempt starts with a random candidate, even when sorting
            elected = [np.random.choice(candidates_vip)]

        for c_idx in candidates_vip:
            # Checked before adding anything so that num_vip = 1 stops at the seed particle
            if len(elected) >= num_vip:
                break
            if c_idx == elected[0]:
                continue

            # Distances at frame 0 between the candidate and every particle already elected
            dist = distance.cdist(array_trajs[0, [c_idx], :], array_trajs[0, elected, :], metric='euclidean')

            if dist.min() > min_distance_part:
                elected.append(c_idx)

        count_while += 1
        if count_while > 100: 
            raise ValueError('Could not find suitable VIP particles. This is due to either having to few particles or them being too close')
            
    return elected


# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 46
def _get_error_bounds():
    '''
    Returns the maximum errors currently allowed for the different diffusive properties.

    The output is a 6-tuple, in this order: single-trajectory bounds for alpha, D, state
    and changepoints, followed by the ensemble bounds for alpha and D. The ensemble
    bounds are the widths of the `bound_alpha` and `bound_D` intervals of `models_phenom`.
    '''
    
    # Single trajectory: (alpha, D, state, changepoint)
    single_trajectory_bounds = (2, 1e5, 0, 10)
    
    # Ensemble: related to the Wasserstein distance. Check test below distribution_distance function 
    alpha_interval_width = np.abs(models_phenom().bound_alpha[0]-models_phenom().bound_alpha[1])
    D_interval_width = np.abs(models_phenom().bound_D[0]-models_phenom().bound_D[1])
    
    return (*single_trajectory_bounds, alpha_interval_width, D_interval_width)

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 48
def changepoint_assignment(GT, preds):
    ''' 
    Pairs groundtruth and predicted changepoints by solving the linear assignment problem
    (Munkres / Hungarian algorithm) on their pairwise distances.
    
    The distance between two change points is the absolute value of their difference.
    
    Parameters
    ----------
    GT : list
        List of groundtruth change points.
    preds : list
        List of predicted change points.
    
    Returns
    -------
    tuple
        - tuple of two arrays with the indices of the paired GT and pred changepoints
        - Cost matrix, of shape (len(GT), len(preds))
    
    '''
    
    # Entry (i, j) is |GT[i] - preds[j]|; written into a float matrix of the final shape
    cost_matrix = np.zeros((len(GT), len(preds)))
    cost_matrix[...] = np.abs(np.subtract.outer(GT, preds))

    pairing = linear_sum_assignment(cost_matrix)
    return pairing, cost_matrix

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 50
def changepoint_alpha_beta(GT, preds, threshold = 10):
    '''
    Calculate the alpha and beta measure of paired changepoints.
    Inspired from Supplemantary Note 3 in https://www.nature.com/articles/nmeth.2808

    Changepoints are first paired with `changepoint_assignment`. The distance of each
    pair is capped at `threshold` and the capped distances are summed into `d`. Then

        alpha = 1 - d / (threshold * len(GT))
        beta  = (threshold * len(GT) - d) / (threshold * len(GT) + threshold * max(0, len(preds) - len(GT)))

    An empty `GT` is not handled: the normalisation is zero in that case.
    
    Parameters
    ----------
    GT : list
        List of groundtruth change points.
    preds : list
        List of predicted change points.
    threshold : float
        Distance from which predictions are considered to have failed. They are then assigned this number.
    
    Returns
    -------
    tuple
        alpha, beta
        
    '''

    # Arrays are needed to index with the assignment; this also lets plain lists through
    GT, preds = np.asarray(GT), np.asarray(preds)

    (gt_idx, pred_idx), _ = changepoint_assignment(GT, preds)

    # Summed distance of the pairs, each one capped at the caller's threshold
    # (the threshold argument is honoured instead of being overwritten with 10)
    pair_distance = np.abs(GT[gt_idx] - preds[pred_idx])
    total_distance = np.sum(np.minimum(pair_distance, threshold))

    d_x_phi = threshold*len(GT)
    d_ybar_phi = max([0, (len(preds)-len(GT))*threshold])

    alpha = 1-total_distance/d_x_phi
    beta = (d_x_phi-total_distance)/(d_x_phi+d_ybar_phi)

    return alpha, beta

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 52
def jaccard_index(TP: int, # true positive
                  FP: int, # false positive
                  FN: int # false negative
                 )-> float: # Jaccard Index
    '''
    Computes the Jaccard Index TP / (TP + FP + FN) from the number of true positives,
    false positives and false negatives. The three counts must not all be zero.
    '''
    all_events = TP + FP + FN
    return TP / all_events

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 53
def single_changepoint_error(GT, preds, threshold = 5):
    '''
    Given the groundtruth and predicted changepoints for a single trajectory, first solves the assignment problem between changepoints,
    then calculates the RMSE of the true positive pairs and the Jaccard index.

    A pair counts as a true positive when its two changepoints are closer than `threshold`;
    otherwise it counts as both a false positive and a false negative. Unpaired predictions
    are false positives and unpaired groundtruth changepoints are false negatives.
    
    Parameters
    ----------
    GT : list
        List of groundtruth change points.
    preds : list
        List of predicted change points.
    threshold : float
        Maximum distance (exclusive) between paired changepoints for the pair to be a true positive.
    
    Returns
    -------
    tuple
        - TP_rmse: root mean square error of the true positive change points. It is nan
          when there is no true positive.
        - Jaccard Index of the predictions
        If both GT and preds are empty, (0, 1) is returned, as done by
        ensemble_changepoint_error.
        
    '''
    
    (gt_idx, pred_idx), _ = changepoint_assignment(GT, preds)
    
    TP, FP, FN = 0, 0, 0
    squared_errors = []
    for g, p in zip(gt_idx, pred_idx):
        gap = GT[g] - preds[p]
        if np.abs(gap) < threshold:
            TP += 1
            squared_errors.append(gap**2)
        else:
            # A pair that is too far apart is a spurious prediction AND a missed changepoint
            FP += 1
            FN += 1

    # Checking false positive and missed events left out of the pairing
    surplus = len(preds) - len(GT)
    if surplus > 0:
        FP += surplus
    elif surplus < 0:
        FN += -surplus

    # Nothing to detect and nothing predicted: perfect score instead of dividing by zero
    if TP + FP + FN == 0:
        return 0, 1

    # Calculating RMSE (kept as nan without true positives, but without averaging an empty list)
    TP_rmse = np.sqrt(np.mean(squared_errors)) if squared_errors else np.nan
    
    return TP_rmse, jaccard_index(TP, FP, FN)

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 54
def ensemble_changepoint_error(GT_ensemble, pred_ensemble, threshold = 5):    
    ''' 
    Scores the changepoints of a whole set of trajectories. For each trajectory the
    groundtruth and predicted changepoints are paired by solving the assignment problem.
    The RMSE of the true positive pairs and the Jaccard index are then computed over
    the pooled changepoints of all trajectories (i.e. not as the mean over trajectories).

    A pair is a true positive when its changepoints are closer than `threshold`.
    Trajectories with neither groundtruth nor predicted changepoints add one true
    positive each to the Jaccard index.
    
    Parameters
    ----------
    GT_ensemble : list, array
        Ensemble of groundtruth change points.
    pred_ensemble : list
        Ensemble of predicted change points.
    threshold : float
        Maximum distance (exclusive) between paired changepoints for the pair to be a true positive.
    
    Returns
    -------
    tuple
        - TP_rmse: root mean square error of the true positive change points
          (0 when there is no true positive).
        - Jaccard Index of the ensemble predictions
        (0, 1) is returned when no changepoint was counted at all and the groundtruth
        has none.
    
    '''
    
    TP = FP = FN = 0
    trajs_without_cp = 0
    squared_errors = []
    total_gt_cps = 0

    for gt_traj, pred_traj in zip(GT_ensemble, pred_ensemble):
        total_gt_cps += len(gt_traj)
        
        (gt_idx, pred_idx), _ = changepoint_assignment(gt_traj, pred_traj)
        
        for g, p in zip(gt_idx, pred_idx):
            gap = gt_traj[g] - pred_traj[p]
            if np.abs(gap) < threshold:
                TP += 1
                squared_errors.append(gap**2)
            else:
                # Too far apart: one spurious prediction and one missed changepoint
                FP += 1
                FN += 1    
                
        # Changepoints left out of the pairing
        surplus = len(pred_traj) - len(gt_traj)
        if surplus > 0:
            FP += surplus
        elif surplus < 0:
            FN += -surplus

        # Neither groundtruth nor predicted changepoints in this trajectory
        if len(gt_idx) == 0 and surplus == 0:
            trajs_without_cp += 1
                
    if TP+FP+FN == 0:
        if total_gt_cps == 0: # there were no CP both in GT and Pred
            return 0, 1
        warnings.warn('No segments found in your predictions dataset.')
        return threshold, 0

    # Without any true positive there is no error to average, hence the RMSE is zero
    if len(squared_errors) > 0:
        rmse = np.sqrt(np.mean(squared_errors))
    else:
        rmse = 0
        
    return rmse, jaccard_index(TP+trajs_without_cp, FP, FN)

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 58
def create_binary_segment(CP: list, # list of changepoints
                          T: int # length of the trajectory
                         )-> np.ndarray: # array of shape (len(CP)+1, T), one row per segment, equal to 1 in the segment's temporal support
    '''
    Given a set of changepoints and the length of the trajectory, create segments which are equal to one
    if the segment takes place at that position and zero otherwise.

    The first segment covers positions 0 to CP[0] (both included), segment i covers
    CP[i-1]+1 to CP[i], and the last one covers everything after the last changepoint.
    With no changepoints a single segment covering the whole trajectory is returned.
    '''
    segments = np.zeros((len(CP)+1, T))

    # Cast to int: np.append(0, []) with an empty Python list is a float array, which
    # cannot be used in slices
    edges = np.append(0, CP).astype(int)

    for idx, (cp1, cp2) in enumerate(zip(edges[:-1], edges[1:])):
        segments[idx, cp1+1:cp2+1] = 1
    segments[-1, edges[-1]+1:] = 1
    # Position 0 belongs to the first segment (the slices above start at edge+1)
    segments[0, 0] = 1
    return segments

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 60
def jaccard_between_segments(gt, pred):
    '''
    Jaccard index between two binary segments. Positions where both are one count as
    TP, positions only present in the groundtruth as FN (missed events) and positions
    only present in the prediction as FP (leftover predictions).

    The shorter segment is extended with zeros up to the length of the longer one.
    
    Parameters
    ----------
    gt : array
        groundtruth segment, equal to one in the temporal support of the given segment, zero otherwise.
    pred : array
        predicted segment, equal to one in the temporal support of the given segment, zero otherwise.
    
    Returns
    -------
    float
        Jaccard index between the given segments (0 when both segments are empty).
    '''
    
    missing = len(gt) - len(pred)
    if missing > 0:
        pred = np.append(pred, np.zeros(missing))
    elif missing < 0:
        gt = np.append(gt, np.zeros(-missing))

    in_gt, in_pred = gt == 1, pred == 1
    out_gt, out_pred = gt == 0, pred == 0

    tp = np.sum(np.logical_and(in_pred, in_gt))
    fp = np.sum(np.logical_and(in_pred, out_gt))
    fn = np.sum(np.logical_and(out_pred, in_gt))
    
    # special case for absence of changepoint
    if tp+fp+fn == 0:
        return 0

    return jaccard_index(tp, fp, fn)

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 61
def segment_assignment(GT, preds, T:int = None):
    ''' 
    Pairs groundtruth and predicted segments. Both lists of changepoints are converted into binary
    segments, a cost matrix is filled with one minus the Jaccard index of every (groundtruth, predicted)
    pair of segments, and the resulting assignment problem is solved with the Munkres (Hungarian)
    algorithm.
    
    If T = None, GT and preds may describe trajectories of different lengths. In that case the last
    entry of each list is taken as the length of the corresponding trajectory.
    
    Parameters
    ----------
    GT : list
        List of groundtruth change points.
    preds : list
        List of predicted change points.
    T : int, None
        Length of the trajectory. If None, considers different GT and preds length.
    
    Returns
    -------
    tuple
        - tuple of two arrays with the indices of the paired groundtruth and predicted segments
        - Cost matrix calculated via JI of segments   
    
    '''
    if T is None:
        # The last entry is the trajectory length; it is only dropped from the
        # changepoints when other entries remain.
        T_gt = GT[-1]
        if len(GT) > 1:
            GT = GT[:-1]

        T_pred = preds[-1]
        if len(preds) > 1:
            preds = preds[:-1]
    else:
        T_gt = T_pred = T
        # A single integer is wrapped in a list; an empty input is replaced by [T-1]
        if isinstance(GT, int):
            GT = [GT]
        elif len(GT) == 0:
            GT = [T-1]

        if isinstance(preds, int):
            preds = [preds]
        elif len(preds) == 0:
            preds = [T-1]

    seg_GT = create_binary_segment(GT, T_gt)
    seg_preds = create_binary_segment(preds, T_pred)

    # Rows: groundtruth segments, columns: predicted segments
    cost_matrix = np.array([[1 - jaccard_between_segments(gt, pred) for pred in seg_preds]
                            for gt in seg_GT], dtype = float)

    return linear_sum_assignment(cost_matrix), cost_matrix

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 71
from sklearn.metrics import mean_squared_log_error as msle, f1_score

def metric_anomalous_exponent(gt = None,
                              pred = None,
                              max_error = np.abs(models_phenom().bound_alpha[0]-models_phenom().bound_alpha[1])):
    ''' 
    Mean absolute error (mae) between groundtruth and predicted anomalous exponents, capped at `max_error`.
    By default `max_error` is the width of the anomalous exponent bounds defined in models_phenom.
    ''' 
    mae = np.mean(np.abs(gt-pred))
    # Errors above the cap are reported as the cap itself
    return max_error if mae > max_error else mae

def metric_diffusion_coefficient(gt = None, pred = None, 
                                 threshold_min = models_phenom().bound_D[0],                               
                                 max_error = msle([models_phenom().bound_D[0]],
                                                  [models_phenom().bound_D[1]])):
    ''' 
    Mean squared log error (msle) between groundtruth and predicted diffusion coefficients, capped at
    `max_error`. By default the cap is the msle between the two diffusion bounds of models_phenom.
    Values lower or equal than `threshold_min` are replaced by `threshold_min` before computing the error.
    ''' 
    # Work on copies so that the inputs are left untouched
    gt_floor = np.array(gt, copy = True)
    pred_floor = np.array(pred, copy = True)

    # Zeros and negatives are replaced by the minimum threshold
    pred_floor[pred_floor <= threshold_min] = threshold_min
    gt_floor[gt_floor <= threshold_min] = threshold_min

    error = msle(gt_floor, pred_floor)
    return max_error if error > max_error else error

def metric_diffusive_state(gt = None, pred = None):
    ''' 
    Micro-averaged F1 score between groundtruth and predicted diffusive states.
    Both inputs are cast to integer labels before the comparison.
    ''' 
    gt_labels = gt.astype(int)
    pred_labels = pred.astype(int)
    return f1_score(gt_labels, pred_labels, average = 'micro')

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 75
def check_no_changepoints(GT_cp, GT_alpha, GT_D, GT_s,
                          preds_cp, preds_alpha, preds_D, preds_s,
                          T:bool|int = None):
    '''
    Given predicionts over changepoints and variables, checks if in both GT and preds there is an 
    absence of change point. If so, takes that into account to pair variables.
    
    Parameters
    ----------
    GT_cp : list, int, float
        Groundtruth change points
    GT_alpha : list, float
        Groundtruth anomalous exponent
    GT_D : list, float
        Groundtruth diffusion coefficient
    GT_s : list, float
        Groundtruth diffusive state
    preds_cp : list, int, float
        Predicted change points
    preds_alpha : list, float
        Predicted anomalous exponent
    preds_D : list, float
        Predicted diffusion coefficient
    preds_s : list, float
        Predicted diffusive state
    T : bool,int
        (optional) Length of the trajectories. If none, last change point is length.
    
    Returns
    -------
    tuple
        - False if there are change points. True if there were missing change points.
        - Next three are either all Nones if change points were detected, or paired exponents, 
        coefficient and states if some change points were missing.
    
    '''


    if isinstance(GT_cp, int) or isinstance(GT_cp, float):
        GT_cp = [GT_cp]
    if isinstance(preds_cp, int) or isinstance(preds_cp, float):
        preds_cp = [preds_cp]
        
    no_GT_cp = False; no_preds_cp = False
    # CP always contain the final point of the trajectory, hence minimal length is one
    if len(GT_cp) == 1: no_GT_cp = True
    if len(preds_cp) == 1: no_preds_cp = True       
        

    if no_GT_cp + no_preds_cp == 0:
        return False, None, None, None
    
    else:

        [row_ind, col_ind], _ = segment_assignment(GT_cp, preds_cp, T)   

        if no_GT_cp and not no_preds_cp:
            paired_alpha = np.array([[GT_alpha[0], preds_alpha[col_ind[0]]]])
            paired_D = np.array([[GT_D[0], preds_D[col_ind[0]]]])
            paired_s = np.array([[GT_s[0], preds_s[col_ind[0]]]])

        if no_preds_cp and not no_GT_cp:
            row_position = np.argwhere(col_ind == 0).flatten()[0]            
            paired_alpha = np.array([[GT_alpha[row_position], preds_alpha[col_ind[row_position]]]])
            paired_D = np.array([[GT_D[row_position], preds_D[col_ind[row_position]]]])
            paired_s = np.array([[GT_s[row_position], preds_s[col_ind[row_position]]]])
            
        if no_preds_cp and no_GT_cp: 
            paired_alpha = np.array([[GT_alpha[0], preds_alpha[0]]])
            paired_D = np.array([[GT_D[0], preds_D[0]]])
            paired_s = np.array([[GT_s[0], preds_s[0]]])
            

        return True, paired_alpha, paired_D, paired_s

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 76
def segment_property_errors(GT_cp, GT_alpha, GT_D, GT_s,
                            preds_cp, preds_alpha, preds_D, preds_s,
                            return_pairs = False,
                            T = None):
    '''
    Given predictions over change points and the value of diffusion parameters in the generated
    segments, pairs groundtruth and predicted segments and computes the defined metrics.
    
    Parameters
    ----------
    GT_cp : list, int, float
        Groundtruth change points
    GT_alpha : list, float
        Groundtruth anomalous exponent
    GT_D : list, float
        Groundtruth diffusion coefficient
    GT_s : list, float
        Groundtruth diffusive state
    preds_cp : list, int, float
        Predicted change points
    preds_alpha : list, float
        Predicted anomalous exponent
    preds_D : list, float
        Predicted diffusion coefficient
    preds_s : list, float
        Predicted diffusive state
    return_pairs : bool
        If True, returns the assignment pairs for each diffusive property.
    T : bool,int
        (optional) Length of the trajectories. If none, last change point is length.
    
    Returns
    -------
    tuple
        - if return_pairs = True, returns the assigned pairs of diffusive properties
          (arrays with columns groundtruth, prediction)
        - if return_pairs = False, returns the errors for each diffusive property
    '''
    
    # Check cases in which changepoints were not detected or there were none in groundtruth
    no_change_point_case, paired_alpha, paired_D, paired_s = check_no_changepoints(GT_cp, GT_alpha, GT_D, GT_s,
                                                                                   preds_cp, preds_alpha, preds_D, preds_s, T)
   
    if not no_change_point_case:
        # Solve the assignment problem
        [row_ind, col_ind], _ = segment_assignment(GT_cp, preds_cp, T)
   
        # iterate over the groundtruth segments
        paired_alpha, paired_D, paired_s = [], [], []
        for idx_seg, (gt_a_seg, gt_d_seg) in enumerate(zip(GT_alpha, GT_D)):

            # Position of this groundtruth segment in the assignment (empty if it was left unpaired)
            match = np.flatnonzero(row_ind == idx_seg)
            if match.size == 0:
                continue

            # Explicit indexing instead of int(array): converting an array with
            # ndim > 0 to a scalar is deprecated in recent NumPy versions.
            pred_seg = col_ind[match[0]]

            paired_alpha.append([gt_a_seg, preds_alpha[pred_seg]])
            paired_D.append([gt_d_seg, preds_D[pred_seg]])
            paired_s.append([GT_s[idx_seg], preds_s[pred_seg]])

        paired_alpha, paired_D, paired_s = np.array(paired_alpha), np.array(paired_D), np.array(paired_s) 
    
    if return_pairs:
        return paired_alpha, paired_D, paired_s 
    else:
        error_alpha = metric_anomalous_exponent(paired_alpha[:,0], paired_alpha[:,1])
        error_D = metric_diffusion_coefficient(paired_D[:,0], paired_D[:,1])
        error_s = metric_diffusive_state(paired_s[:,0], paired_s[:,1])
        return error_alpha, error_D, error_s

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 85
def _visualize_ensemble(ens):
    '''
    Builds a dataframe to inspect the ensemble parameters stored in `ens`, a matrix of the form:
    
    |mu_alpha1      mu_alpha2     ... |
    |sigma_alpha1   sigma_alpha2  ... |
    |mu_D1          mu_D1         ... | 
    |sigma_D1       sigma_D2      ... |
    |counts_state1  counts_state2 ... |
    
    The matrix is transposed, so the dataframe has one row per column of `ens` and one
    named column per row of `ens`.
    '''  
    column_names = [r'mean $\alpha$', r'var $\alpha$', r'mean $D$', r'var $D$', '% residence time']
    per_state = ens.transpose()
    return pandas.DataFrame(data = per_state, columns = column_names)

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 86
from .models_phenom import models_phenom
def extract_ensemble(state_label, dic):
    ''' 
    Given an array of the diffusive state and a dictionary with the diffusion information,
    returns a summary of the ensemble properties for the current dataset.

    Parameters
    ----------
    state_label : array
        Array containing the diffusive state of the particles in the dataset.
        For multi-state and dimerization, this must be the number associated to the
        state (for dimerization, 0 is free, 1 is dimerized). For the rest, we follow
        the numeration of models_phenom().lab_state.
    dic : dict 
        Dictionary containing the information of the input dataset. The keys 'model',
        'alphas' and 'Ds' are read.

    Returns
    -------
    array
        Matrix containing the ensemble information of the input dataset. It has the following shape:
         |mu_alpha1      mu_alpha2     ... |
         |sigma_alpha1   sigma_alpha2  ... |
         |mu_D1          mu_D1         ... | 
         |sigma_D1       sigma_D2      ... |
         |counts_state1  counts_state2 ... |

    Raises
    ------
    ValueError
        If dic['model'] is not one of the supported models.
    '''
    model = dic['model']

    # Single state: one column, every label belongs to the same state
    if model == 'single_state':
        return np.vstack((dic['alphas'][0],
                          dic['alphas'][1],
                          dic['Ds'][0],
                          dic['Ds'][1],
                          len(state_label)
                          ))

    # Multi-state
    if model == 'multi_state':
        states, counts = np.unique(state_label, return_counts=True)
        n_states = dic['alphas'].shape[0]
        # If the number of visited states is not equal to the expected number of states,
        # counts are placed at the index given by the state label. States that were never
        # visited keep the initial value of one.
        # NOTE(review): confirm that one (and not zero) is the intended count for unvisited states.
        if len(states) != n_states:
            counts_corrected = np.ones(n_states)
            for s, c in zip(states, counts):
                counts_corrected[int(s)] = c
        else:
            counts_corrected = counts

        return np.vstack((dic['alphas'][:, 0],
                          dic['alphas'][:, 1],
                          dic['Ds'][:, 0],
                          dic['Ds'][:, 1],
                          counts_corrected
                          ))

    # Immobile traps: the first column is filled with zeros, the second one with the values from dic
    if model == 'immobile_traps':
        lab_state = models_phenom().lab_state
        counts = [len(state_label[state_label == lab_state.index('i')]),
                  len(state_label[state_label == lab_state.index('f')])]
        return np.vstack(([0, dic['alphas'][0]],
                          [0, dic['alphas'][1]],
                          [0, dic['Ds'][0]],
                          [0, dic['Ds'][1]],
                          counts
                          ))

    # Dimerization: labels 0 and 1
    if model == 'dimerization':
        counts = [len(state_label[state_label == 0]),
                  len(state_label[state_label == 1])]
        return np.vstack((dic['alphas'][:, 0],
                          dic['alphas'][:, 1],
                          dic['Ds'][:, 0],
                          dic['Ds'][:, 1],
                          counts
                          ))

    # Confinement: labels 'f' and 'c' of models_phenom().lab_state
    if model == 'confinement':
        lab_state = models_phenom().lab_state
        counts = [len(state_label[state_label == lab_state.index('f')]),
                  len(state_label[state_label == lab_state.index('c')])]
        return np.vstack((dic['alphas'][:, 0],
                          dic['alphas'][:, 1],
                          dic['Ds'][:, 0],
                          dic['Ds'][:, 1],
                          counts
                          ))

    # Fail with an explicit message instead of an UnboundLocalError on the return value
    raise ValueError(f"Unknown model '{model}'. Expected one of: 'single_state', 'multi_state', "
                     "'immobile_traps', 'dimerization', 'confinement'.")

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 88
import scipy.stats
def multimode_dist(params, weights, bound, x, normalized = False, min_var = 1e-9):
    '''
    Evaluates a weighted sum of truncated normal distributions over a given support.
    A single mode is also accepted, in which case the weight is a float or an int.
    
    Parameters
    ----------
    params : list
        Mean and variance of every mode.
    weights : list, float
        Weight of every mode. If float or int, we consider a single mode.
    bound : tuple
        Bounds (min, max) at which every mode is truncated.
    x : array
        Support upon which the distribution is evaluated.
    normalized : bool
        If True, the output is divided by its sum.
    min_var : float
        Variances below this value are replaced by it, to avoid nan when var = 0.
        
    Returns
    -------
    array
        Value of the distribution in each point of the given support
        
    '''
    lower, upper = bound

    # A single mode is wrapped in lists so that the loop below handles both cases
    if isinstance(weights, (float, int)):
        params, weights = [params], [weights]

    dist = np.zeros_like(x)
    for (mean, var), w in zip(params, weights):
        if var < min_var:
            var = min_var
        sigma = np.sqrt(var)
        # truncnorm expects the truncation limits in units of standard deviations around the mean
        dist += w*scipy.stats.truncnorm.pdf(x,
                                            (lower-mean)/sigma,
                                            (upper-mean)/sigma,
                                            loc = mean,
                                            scale = sigma)

    if normalized:
        dist /= np.sum(dist)
    return dist

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 90
from scipy.stats import wasserstein_distance

def distribution_distance(p:np.array, # distribution 1
                          q:np.array, # distribution 2
                          x:np.array = None, # support of the distributions (not needed for MAE)
                          metric = 'wasserstein' # distance metric (either 'wasserstein' or 'mae')
                         )-> float:  # distance between distributions
    ''' 
    Distance between two distributions: the mean absolute difference for 'mae', or the Wasserstein
    distance over the support `x` (with `p` and `q` as weights) for 'wasserstein'.
    Any other metric returns None.
    '''
    if metric == 'wasserstein':
        return wasserstein_distance(x, x, p, q)
    if metric == 'mae':
        return np.abs(p-q).mean()
    return None

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 110
from .models_phenom import models_phenom

def error_Ensemble_dataset(true_data, pred_data,
                           size_support = int(1e6),
                           metric = 'wasserstein',
                           return_distributions = False):
    r''' 
    Calculates the ensemble metrics for the ANDI 2 challenge. The input are matrices of shape:
    
    | col1 (state 1) | col2 (state 2) | col3 (state 3) | ... |
    |:--------------:|:--------------:|:--------------:|:---:|
    | $\mu_a^1$      | $\mu_a^2$      | $\mu_a^3$      | ... |
    | $\sigma_a^1$   | $\sigma_a^2$   | $\sigma_a^3$   | ... |
    | $\mu_D^1$      | $\mu_D^2$      | $\mu_D^3$      | ... |        
    | $\sigma_D^1$   | $\sigma_D^2$   | $\sigma_D^3$   | ... |
    | $N_1$          | $N_2$          | $N_3$          | ... |
    
    The input matrices are not modified.
    
    Parameters
    ----------
    true_data : array
        Matrix containing the groundtruth data.
    pred_data : array
        Matrix containing the predicted data.
    size_support : int
        size of the support of the distributions 
    metric : str
        metric used to calculate distance between distributions
    return_distributions : bool
        If True, the function also outputs the generated distributions.
    
    Returns
    -------
    tuple
        - distance_alpha: distance between anomalous exponents
        - distance_D: distance between diffusion coefficients
        - dists (if asked): distributions of both groundtruth and predicted data. Order: true_a, true_D, pred_a, pred_D        
    
    '''
    # Read the bounds once instead of instantiating models_phenom for every access
    phenom = models_phenom()
    bound_alpha, bound_D = phenom.bound_alpha, phenom.bound_D

    # Define the support for the distributions (linear for alpha, logarithmic for D)
    x_alpha = np.linspace(bound_alpha[0], bound_alpha[1], size_support)
    x_D = np.logspace(np.log10(bound_D[0]), np.log10(bound_D[1]), size_support)

    dists = []
    for data in [true_data, pred_data]:

        if len(data.shape) > 1: # If we have more than one state
            alpha_info = np.delete(data, [2,3, -1], 0)
            d_info = data[2:-1,:]
            # data[-1,:] is a view of the input: take a float copy so that the
            # normalization below does not overwrite the caller's matrix
            weights = data[-1,:].astype(float)
            if weights.sum() > 1:
                weights = weights / weights.sum()
        else: # If single state
            alpha_info = data[:2]
            d_info = data[2:-1]
            weights = 1

        # Order inside dists: alpha distribution first, then D distribution
        dists.append(multimode_dist(alpha_info.T, weights, bound = bound_alpha, x = x_alpha))
        dists.append(multimode_dist(d_info.T, weights, bound = bound_D, x = x_D))

    # Distance between the groundtruth (first two) and predicted (last two) distributions
    distance_alpha = distribution_distance(p = dists[0], q = dists[2],
                                           x = x_alpha, metric = metric)
    distance_D = distribution_distance(p = dists[1], q = dists[3],
                                       x = x_D, metric = metric)

    if return_distributions:
        return distance_alpha, distance_D, dists
    else:
        return distance_alpha, distance_D

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 113
def check_prediction_length(pred):
    '''
    Checks that a trajectory prediction has C changepoints and C+1 sets of segment properties.
    Together with the index of the trajectory, this makes the length a multiple of 4. When the final
    point of the trajectory is also predicted, the length leaves a remainder of 1 instead.
    Returns True in those two cases and False otherwise.
    '''
    return len(pred) % 4 in (0, 1)

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 114
def separate_prediction_values(pred):
    '''
    Splits a flat trajectory prediction into its components. After the first entry, values are
    read in groups of four: diffusion coefficient, anomalous exponent, state and changepoint.
    Returns them in the order Ds, alphas, states, changepoints.
    '''        
    # Offset of each property inside every group of four (position 0 is skipped)
    Ds, alphas, states, cp = (pred[offset::4] for offset in (1, 2, 3, 4))
    return Ds, alphas, states, cp

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 115
def load_file_to_df(path_file, 
                    columns = ['traj_idx', 'Ds', 'alphas', 'states', 'changepoints']):
    '''
    Given the path of a .txt file, extract the segmentation predictions based on 
    the rules of the ANDI 2 challenge.
    
    Every line of the file is a comma-separated list of numbers. The first value is stored
    in the first column; the remaining values are split by `separate_prediction_values` into
    diffusion coefficients, anomalous exponents, states and changepoints.
    
    Parameters
    ----------
    path_file : str
        Path of the file to read.
    columns : list
        Names of the five columns of the output dataframe, in the order: trajectory index,
        diffusion coefficients, anomalous exponents, states, changepoints.
    
    Returns
    -------
    dataframe
        One row per line of the file that passes `check_prediction_length`. The first column
        holds a float, the others hold lists of floats. Lines failing the check are skipped.
    
    Raises
    ------
    ValueError
        If a field of a line cannot be converted to float (this includes empty lines).
    '''

    with open(path_file) as f:
        lines_pred = f.read().splitlines()

    df = pandas.DataFrame(columns = columns)

    for line in lines_pred:
        # Extract values with comma separator and transform to float
        pred_traj = line.split(',')
        pred = [float(i) for i in pred_traj]
        
        # Check that prediction has the correct shape
        pred_correct = check_prediction_length(pred)
        
        # If correct size, then extract parameters and add it to dataframe
        if pred_correct:
            preds_D, preds_a, preds_s, preds_cp = separate_prediction_values(pred)

            # New row appended at the end; the label equals the current number of rows
            current_row = df.shape[0]
            # Cells are filled one by one: the first assignment creates the row, the
            # following ones store a whole list inside a single cell
            for param, pred_param in zip(columns, [pred[0], preds_D, preds_a, preds_s, preds_cp]):
                df.loc[current_row, param] = pred_param
                
    return df

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 120
def error_SingleTraj_dataset(df_pred, df_true, 
                              threshold_error_alpha = None, max_val_alpha = 2, min_val_alpha = 0, 
                              threshold_error_D = None, max_val_D = 1e6, min_val_D = 1e-6, 
                              threshold_error_s = None,
                              threshold_cp = None,
                              prints = True, disable_tqdm = False
                             ):
    '''
    Given two dataframes, corresponding to the predictions and true labels of a set
    of trajectories from the ANDI 2 challenge, calculates the corresponding metrics.
    Columns must be for both (no order needed):
    traj_idx | alphas | Ds | changepoints | states
    
    Trajectories of df_true without a prediction (or with an empty list of predicted changepoints)
    are counted as missing and replaced by placeholder values built from the thresholds.
    
    Parameters
    ----------
    df_pred : dataframe
        Predictions
    df_true : dataframe
        Groundtruth
    threshold_error_alpha : float
        (same for D, s, cp) Maximum possible error allowed. If None, the value given by
        _get_error_bounds() is used.
    max_val_alpha : float
        Maximum valid value of a predicted anomalous exponent. Predictions above it are
        replaced by the groundtruth value plus threshold_error_alpha.
    min_val_alpha : float
        Minimum valid value of a predicted anomalous exponent. Predictions below it are
        replaced by the groundtruth value plus threshold_error_alpha.
    max_val_D : float
        Currently not used.
    min_val_D : float
        Currently not used.
    prints : bool
        If True, prints the results.
    disable_tqdm : bool
        If True, disables the progress bar.
    
    Returns
    -------
    tuple
        - rmse_CP: root mean squared error change points
        - JI: Jaccard index change points
        - error_alpha: mean absolute error anomalous exponents
        - error_D: mean square log error diffusion coefficients
        - error_s: F1 score diffusive states
    
    Raises
    ------
    ValueError
        If df_true contains no trajectory.
    
    '''
    # Check error bounds
    andi_bounds = _get_error_bounds()
    if threshold_error_alpha is None: threshold_error_alpha = andi_bounds[0]
    if threshold_error_D is None: threshold_error_D = andi_bounds[1]
    if threshold_error_s is None: threshold_error_s = andi_bounds[2]
    if threshold_cp is None: threshold_cp = andi_bounds[3]
    
    # Number of trajectories without (valid) prediction
    missing_traj = 0

    # For every trajectory we collect the paired segment properties and the changepoints info.
    # Pairs are gathered in lists and stacked once at the end.
    pairs_alpha, pairs_D, pairs_s = [], [], []
    ensemble_pred_cp, ensemble_true_cp = [], []
    for t_idx in tqdm(df_true['traj_idx'].values, disable = disable_tqdm):
        
        traj_trues = df_true.loc[df_true.traj_idx == t_idx]

        traj_preds = df_pred.loc[df_pred.traj_idx == t_idx]    
        if traj_preds.shape[0] == 0 or len(traj_preds.changepoints.to_list()[0]) == 0:
            # If there is no trajectory, we give maximum error. To do so, we redefine predictions
            # and trues so that they give maximum error
            missing_traj += 1                       
            
            preds_cp, preds_alpha, preds_D, preds_s = [[10],
                                                       [0],
                                                       [1],
                                                       [0]]

            trues_cp, trues_alpha, trues_D, trues_s = [[10+threshold_cp],
                                                       [threshold_error_alpha],
                                                       [1+threshold_error_D],
                                                       [10]]

            # Collecting changepoints for metric
            ensemble_pred_cp.append(preds_cp)
            ensemble_true_cp.append(trues_cp) 
        
        else:      

            preds_cp, preds_alpha, preds_D, preds_s = [np.array(traj_preds.changepoints.values[0]).astype(int),
                                                       traj_preds.alphas.values[0],
                                                       traj_preds.Ds.values[0],
                                                       traj_preds.states.values[0]]

            trues_cp, trues_alpha, trues_D, trues_s = [np.array(traj_trues.changepoints.values[0]).astype(int),
                                                       traj_trues.alphas.values[0],
                                                       traj_trues.Ds.values[0],
                                                       traj_trues.states.values[0]]
        
            # Collecting changepoints for metric
            # In this "else", the changepoints also contain as final point the trajectory length.
            # We get rid of it for the CP metrics.
            ensemble_pred_cp.append(preds_cp[:-1])
            ensemble_true_cp.append(trues_cp[:-1])   
        
        # collecting segment properties after segment assignment
        pair_a, pair_d, pair_s = segment_property_errors(trues_cp, trues_alpha, trues_D, trues_s, 
                                                         preds_cp, preds_alpha, preds_D, preds_s,
                                                         return_pairs = True)
        
        # Force the (n_pairs, 2) layout so that a trajectory without pairs cannot break the stacking
        pairs_alpha.append(np.reshape(pair_a, (-1, 2)))
        pairs_D.append(np.reshape(pair_d, (-1, 2)))
        pairs_s.append(np.reshape(pair_s, (-1, 2)))

    if len(pairs_alpha) == 0:
        raise ValueError('df_true contains no trajectories: metrics cannot be computed.')

    paired_alpha = np.vstack(pairs_alpha)
    paired_D = np.vstack(pairs_D)
    paired_s = np.vstack(pairs_s)
               
    #### Calculate metrics from assembled properties   

    # checking for nans and out-of-range values in predictions
    wrong_alphas = np.argwhere(np.isnan(paired_alpha[:, 1]) 
                               | (paired_alpha[:, 1] > max_val_alpha) 
                               | (paired_alpha[:, 1] < min_val_alpha)).flatten()
    paired_alpha[wrong_alphas, 1] = paired_alpha[wrong_alphas, 0] + threshold_error_alpha

    # nan diffusion coefficients are located before taking absolute values
    wrong_ds = np.argwhere(np.isnan(paired_D[:, 1])).flatten()
    paired_D = np.abs(paired_D)
    paired_D[wrong_ds, 1] = paired_D[wrong_ds, 0] + threshold_error_D
    
    wrong_s = np.argwhere((paired_s[:, 1] > 4) | (paired_s[:, 1]<0)).flatten()
    paired_s[wrong_s, 1] = threshold_error_s    
    
    # Changepoints
    rmse_CP, JI = ensemble_changepoint_error(ensemble_true_cp, ensemble_pred_cp, threshold = threshold_cp)
    
    # Segment properties
    error_alpha = metric_anomalous_exponent(paired_alpha[:,0], paired_alpha[:,1])
    error_D = metric_diffusion_coefficient(paired_D[:,0], paired_D[:,1])
    error_s = metric_diffusive_state(paired_s[:,0], paired_s[:,1])
    
    if prints:        
        print(f'Summary of metrics assesments:')
        if missing_traj > 0:            
            print(f'\n{missing_traj} missing trajectory/ies. ')
            
        print(f'\nChangepoint Metrics \nRMSE: {round(rmse_CP, 3)} \nJaccard Index: {round(JI, 3)}',
              f'\n\nDiffusion property metrics \nMetric anomalous exponent: {error_alpha} \nMetric diffusion coefficient: {error_D} \nMetric diffusive state: {error_s}')

    return rmse_CP, JI, error_alpha, error_D, error_s

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 138
import re
import sys
import os

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 140
def when_error_single(wrn_str):
    '''
    Emits `wrn_str` as a warning and returns the fallback results of the single trajectory task:
    a tuple with the maximum error of every variable and a one-row dataframe filled with None.
    The tuple is ordered as 'cp','JI','alpha','D','state' (as in the dataframe). Note that this
    is not the same order as _get_error_bounds.
    '''
    warnings.warn(wrn_str)   

    bounds = _get_error_bounds()
    max_error_alpha, max_error_D = bounds[0], bounds[1]
    max_error_s, max_error_cp = bounds[2], bounds[3]
    max_error_JI = 0

    max_errors = (max_error_cp, max_error_JI, max_error_alpha, max_error_D, max_error_s)
    empty_metrics = pandas.DataFrame(data = np.array([[None]*7]), 
                                     columns = ['Exp', 'num_trajs', 'RMSE CP', 
                                                'JSC CP', 'alpha', 'K', 'state'])
    return max_errors, empty_metrics


def run_single_task(exp_nums, track, submit_dir, truth_dir):
    '''
    Computes the single trajectory metrics of every experiment of a track.
    
    For each experiment, the predictions and groundtruth of all its FOVs are loaded and merged
    into one dataframe each, and then scored with error_SingleTraj_dataset. If a submission
    file is missing, or (track 1) the predicted VIP indices do not match the true ones, the
    function stops and returns the output of when_error_single.
    
    Parameters
    ----------
    exp_nums : iterable
        Identifiers of the experiments to score.
    track : int
        Track of the challenge. For track 1 only the VIP particles are scored.
    submit_dir : str
        Directory containing the predictions, organized as track_X/exp_Y/fov_Z.txt
    truth_dir : str
        Directory containing the groundtruth, organized as track_X/exp_Y/traj_labs_fov_Z.txt
        (and vip_idx_fov_Z.txt for track 1).
    
    Returns
    -------
    tuple
        - average of every metric over experiments, weighted by their number of trajectories
        - dataframe with the metrics of every experiment
    
    Raises
    ------
    FileNotFoundError
        If no groundtruth file is found for an experiment.
    '''
    data_metrics = []

    for exp in exp_nums:

        path_pred = submit_dir+f'/track_{track}/exp_{exp}/'
        path_true = truth_dir+f'/track_{track}/exp_{exp}/'
        prefix_true = 'traj_labs_'

        # Get the number of FOVs from the trues
        fov_nums = sum(filename.startswith(prefix_true) for filename in os.listdir(path_true))
        if fov_nums == 0:
            raise FileNotFoundError(f'No groundtruth file starting with {prefix_true} found in {path_true}')

        # Dataframes of the full experiment and index shift applied to the next FOV
        df_pred_exp, df_true_exp = None, None
        offset = 0

        for fov in range(fov_nums):
            # Predictions
            corresponding_submission_file = path_pred+f'fov_{fov}.txt'
            
            ### If one file does not exist, abort and return max errors ###
            if not os.path.isfile(corresponding_submission_file):
                wrn_str = f'Failed to compute metrics at: -- Track {track} | Task SingleTraj  | Experiment {exp} | FOV {fov} -- this is probably caused by missing files.'
                return when_error_single(wrn_str)

            preds_fov = load_file_to_df(corresponding_submission_file)

            # Groundtruths
            trues_fov = load_file_to_df(path_true+prefix_true+f'fov_{fov}.txt')

            if track == 1:
                # atleast_1d: loadtxt returns a 0-d array when the file holds a single index
                vip_idx = np.atleast_1d(np.loadtxt(path_true + f'vip_idx_fov_{fov}.txt')).astype(int)
                pred_vip_idx = preds_fov.traj_idx.values.astype(int)

                if len(vip_idx) != len(pred_vip_idx) or (np.sort(vip_idx) != np.sort(pred_vip_idx)).any():
                    wrn_str = f'Index of predicted VIP  particles does not correspond to true values (Track {track}, Exp {exp}, FOV {fov}). Be sure to correctly extract the correct index from the .tiff file'
                    return when_error_single(wrn_str)
                
                # Take only VIP particles here (pred already contains only vip)
                trues_fov = trues_fov[trues_fov['traj_idx'].isin(vip_idx)]

            # Sort dataframes by the traj idx (so that index and rows correspond)
            trues_fov = trues_fov.sort_values('traj_idx')
            preds_fov = preds_fov.sort_values('traj_idx')

            # Full experiment dataframe
            if df_true_exp is None:
                df_pred_exp, df_true_exp = preds_fov, trues_fov
            else:
                # Groundtruth and predictions MUST be shifted by the same amount: trajectories
                # are later paired by traj_idx, so independent shifts would misalign every FOV
                # following one whose last predicted and true indices differ.
                trues_fov['traj_idx'] = trues_fov['traj_idx'] + offset
                preds_fov['traj_idx'] = preds_fov['traj_idx'] + offset

                df_pred_exp = pandas.concat([df_pred_exp, preds_fov])
                df_true_exp = pandas.concat([df_true_exp, trues_fov]) 

            # Next FOV starts after the largest index used so far (empty frames are ignored)
            last_indices = [frame['traj_idx'].max() for frame in (df_true_exp, df_pred_exp) if frame.shape[0] > 0]
            if last_indices:
                offset = max(last_indices) + 1

        # Calculate error for each experiment
        rmse_CP_exp, JI, error_alpha_exp, error_D_exp, error_s_exp = error_SingleTraj_dataset(df_pred_exp, df_true_exp, prints = False, disable_tqdm=True)

        # Save errors and number of trajectories for later doing average
        data_metrics.append([exp, df_true_exp.shape[0], rmse_CP_exp, JI, error_alpha_exp, error_D_exp, error_s_exp])

    # Put all results in dataframe    
    data_metrics = pandas.DataFrame(data = data_metrics, columns = ['Exp', 'num_trajs', 'RMSE CP', 'JSC CP', 'alpha', 'K', 'state'])
    # Calculate weighted averages of every metric (columns after 'Exp' and 'num_trajs')
    avg_metrics = []
    for key in data_metrics.keys()[2:]:
        avg_metrics.append(np.average(data_metrics[key], weights=data_metrics.num_trajs))

    return avg_metrics, data_metrics
    

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 145
def run_ensemble_task(exp_nums, # experiment numbers to evaluate (used to build the `exp_{exp}` folder names)
                      track, # track number (used to build the `track_{track}` folder name)
                      submit_dir, # directory containing the predicted labels
                      truth_dir # directory containing the reference labels
                     ):
    '''
    Computes the ensemble task metrics for every experiment of a given track.

    For each experiment, the file `ensemble_labels.txt` is loaded from both the reference and the
    submission folders and compared by means of `error_Ensemble_dataset`.

    Returns a tuple `(avg_alpha, avg_K)` with the metrics averaged over experiments and a dataframe
    with columns `['Exp', 'alpha', 'K']` containing the per-experiment values. If the metrics can't
    be computed for at least one experiment (e.g. missing file), a warning is raised and the
    function returns the maximum errors given by `_get_error_bounds` together with a
    single-row dataframe filled with None.
    '''

    filename = 'ensemble_labels.txt'

    # Materialize so that the experiment numbers can be reused to label the rows of the dataframe
    exp_nums = list(exp_nums)
    avg_alpha, avg_d = [], []

    for exp in exp_nums:

        path_pred = submit_dir+f'/track_{track}/exp_{exp}/'
        path_true = truth_dir+f'/track_{track}/exp_{exp}/'

        try:
            true = np.loadtxt(path_true+filename, skiprows = 1, delimiter = ';')
            pred = np.loadtxt(path_pred+filename, skiprows = 1, delimiter = ';')

            distance_a_exp, distance_d_exp, _ = error_Ensemble_dataset(true, pred, return_distributions = True)

        # `Exception` (and not a bare except) so that KeyboardInterrupt / SystemExit are not swallowed
        except Exception:
            wrn_str = f'Failed to compute metrics at: -- Track {track} | Task Ensemble | Experiment {exp} -- this is probably caused by missing file.'
            warnings.warn(wrn_str)
            # Get the max error possible for the task
            _,_,_,_, max_error_a, max_error_D = _get_error_bounds()

            # A single failing experiment invalidates the whole task
            return (max_error_a, max_error_D), pandas.DataFrame(data = np.array([None, None, None]).reshape(1,3),
                                              columns = ['Exp', 'alpha', 'K'])

        avg_alpha.append(distance_a_exp)
        avg_d.append(distance_d_exp)

    # Rows are labelled with the actual experiment number (as done in the single trajectory task),
    # not with their position in the loop, so that they stay correct for any ordering of exp_nums.
    data_metrics = pandas.DataFrame(data = np.vstack((exp_nums, avg_alpha, avg_d)).transpose(),
                                    columns = ['Exp', 'alpha', 'K'])
    data_metrics['Exp'] = data_metrics['Exp'].values.astype(int)

    return (np.mean(avg_alpha), np.mean(avg_d)),  data_metrics

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 148
import os
import re


def listdir_nohidden(path):
    '''
    Lazily yields the entries of `path` whose name does not start with '.' or '_'.
    '''
    skipped_prefixes = ('.','_')
    yield from (entry for entry in os.listdir(path) if not entry.startswith(skipped_prefixes))

def codalab_scoring(INPUT_DIR = None, # directory to where to find the reference and predicted labels
                    OUTPUT_DIR = None # directory where the scores will be saved (scores.txt)
                   ):
    '''
    Scores a submission organized with the Codalab folder layout.

    Predictions are read from `INPUT_DIR/res` and reference labels from `INPUT_DIR/ref`. The
    leaderboard scores are written to `OUTPUT_DIR/scores.txt` and the detailed per-experiment
    tables are appended to `OUTPUT_DIR/html/scores.html`. If an argument is None, it is taken
    from the command line (`sys.argv[1]` for INPUT_DIR, `sys.argv[2]` for OUTPUT_DIR).

    If the folder of a track is missing in the submission, a warning is raised and the bounds
    given by `_get_error_bounds` are written as scores for that track.
    '''
    # Imported locally so that the command line fallback does not depend on a module-level import
    import sys

    if INPUT_DIR is None:
        INPUT_DIR = sys.argv[1]
    if OUTPUT_DIR is None:
        OUTPUT_DIR = sys.argv[2]

    submit_dir = os.path.join(INPUT_DIR, 'res')
    truth_dir = os.path.join(INPUT_DIR, 'ref')

    # Creating the html folder also creates OUTPUT_DIR if needed
    htmlOutputDir = os.path.join(OUTPUT_DIR, "html")
    os.makedirs(htmlOutputDir, exist_ok = True)
    html_filename = os.path.join(htmlOutputDir, 'scores.html')
    output_filename = os.path.join(OUTPUT_DIR, 'scores.txt')

    if not os.path.isdir(submit_dir):
        print("%s doesn't exist" % submit_dir)

    # Context managers ensure both files are closed even if the scoring raises
    with open(html_filename, 'a', encoding="utf-8") as html_file, open(output_filename, 'w') as output_file:

        html_file.write('<h1>Submission detailed results </h1>')

        # Track 1: videos
        # Track 2: trajectories
        for track, name_track in zip([1,2], ['videos', 'trajectories']):

            ##### ----- In case the whole track is missing, give max errors to both tasks ----- #####
            path_preds = os.path.join(INPUT_DIR, f'res/track_{track}')

            if not os.path.exists(path_preds):
                wrn_str = f'No submission for track {track} found.'
                warnings.warn(wrn_str)

                for task in ['single', 'ensemble']:
                    # Codalab naming:
                    # task 1 : single traj
                    # task 2: ensemble
                    idx_task = 1 if task == 'single' else 2

                    # single trajectories
                    if task == 'single':
                        for name, max_error in zip(['alpha','D','state', 'cp','JI'], list(_get_error_bounds()[:-2])+[0]): # This names must be the same as used in the yaml leaderboard
                            output_file.write(f'tr{track}.ta{idx_task}.'+name+': '+str(max_error) +'\n')
                    elif task == 'ensemble':
                        # NOTE(review): when the track is present, the 'alpha' and 'D' names are swapped
                        # for the ensemble task because of the leaderboard labels. Confirm whether the
                        # same swap is wanted here.
                        for name, max_error in zip(['alpha','D'], _get_error_bounds()[-2:]): # This names must be the same as used in the yaml leaderboard
                            output_file.write(f'tr{track}.ta{idx_task}.'+name+': '+str(max_error) +'\n')
                continue
            ##### ------------------------------------------------------------------------------ #####

            html_file.write(f'<h2> Track {track}: '+name_track+' </h2>')

            # Get the number of experiments from the true directory (same for both tasks)
            exp_folders = sorted(listdir_nohidden(truth_dir+f'/track_{track}'))
            exp_nums = [int(re.findall(r'\d+', name)[0]) for name in exp_folders]

            for task in ['ensemble', 'single']:

                # Codalab naming:
                # task 1 : single traj
                # task 2: ensemble
                idx_task = 1 if task == 'single' else 2

                if task == 'single':
                    html_file.write(f'<h3> Single Trajectory Task </h3>')

                    avg_metrics, df = run_single_task(exp_nums, track, submit_dir, truth_dir )

                    for name, res in zip(['cp','JI','alpha','D','state'], avg_metrics): # This names must be the same as used in the yaml leaderboard
                        output_file.write(f'tr{track}.ta{idx_task}.'+name+': '+str(res) +'\n')

                    # To keep consistency with leaderboard display, we swap the K and alpha columns that
                    # get printed in the detailed results. Moreover, we change the names to match the
                    # leaderboard (e.g. JI is named JSC to match paper nomenclature).
                    df_swapped = df.iloc[:,[0,1,2,3,5,4,6]]
                    df_swapped = df_swapped.rename(columns = {'alpha': 'MAE (alpha)', 'K': 'MSLE (K)',
                                                              'RMSE CP': 'RMSE (CP)', 'JSC CP': 'JSC (CP)',
                                                              'state': 'F1 (diff. type)'})
                    html_file.write(df_swapped.to_html(index = False).replace('\n',''))

                elif task == 'ensemble':
                    html_file.write(f'<h3> Ensemble Task </h3>')

                    avg_metrics, df = run_ensemble_task(exp_nums, track, submit_dir, truth_dir)

                    # There was a problem with the leaderboard labels and we had to SWAP alpha and D in the
                    # first element of the zip, i.e. the list is now ['D', 'alpha'] but avg_metrics is [alpha, D]
                    for name, res in zip(['D','alpha'], avg_metrics):
                        output_file.write(f'tr{track}.ta{idx_task}.'+name+': '+str(res) +'\n')

                    # To keep consistency with leaderboard display, we swap the K and alpha columns that
                    # get printed in the detailed results. Moreover, we change the names to match leaderboard.
                    df_swapped = df.iloc[:,[0,2,1]]
                    df_swapped = df_swapped.rename(columns = {'alpha': r'W1 (alpha)', 'K': 'W1 (K)'})

                    html_file.write(df_swapped.to_html(index = False).replace('\n',''))
        

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 152
import os
import re

def codalab_scoring_local(submit_dir, # directory to where to find the predicted labels (i.e. folder containing folders track_1 and/or track_2)
                          truth_dir, # directory to where to find the reference labels (i.e. folder containing folders track_1 and track_2)
                          output_dir, # directory where the scores will be saved 
                          scores_filename = 'scores.txt', # name of the txt scores file
                          html_filename = 'scores.html', # name of the html scores file
                          dfs_suffix = None # if str, suffix of the df filename: df_track_{1|2}_task_{1|2}_{dfs_suffix}.csv
                         ):
    ''' 
    Local version of codalab_scoring, allowing for custom savings and without df swapping.
    Labelling is as: 
    Track 1: videos, 2: trajectories; 
    Task 1: Single, 2: Ensemble

    The scores are written to `output_dir/scores_filename` and the detailed per-experiment tables
    are appended to `output_dir/html_filename`. If `dfs_suffix` is given, every results dataframe
    with more than one row is also saved as csv in `output_dir`. If the folder of a track is
    missing in the submission, a warning is raised and the bounds given by `_get_error_bounds`
    are written as scores for that track.
    '''

    os.makedirs(output_dir, exist_ok = True)

    html_path = os.path.join(output_dir, html_filename)
    scores_path = os.path.join(output_dir, scores_filename)

    # Context managers ensure both files are closed even if the scoring raises
    with open(html_path, 'a', encoding="utf-8") as html_file, open(scores_path, 'w') as output_file:

        html_file.write('<h1>Submission detailed results </h1>')

        # Track 1: videos
        # Track 2: trajectories
        for track, name_track in zip([1,2], ['videos', 'trajectories']):

            ##### ----- In case the whole track is missing, give max errors to both tasks ----- #####
            path_preds = os.path.join(submit_dir, f'track_{track}')

            if not os.path.exists(path_preds):
                wrn_str = f'No submission for track {track} found.'
                warnings.warn(wrn_str)

                for task in ['single', 'ensemble']:
                    # Codalab naming:
                    # task 1 : single traj
                    # task 2: ensemble
                    idx_task = 1 if task == 'single' else 2

                    # single trajectories
                    if task == 'single':
                        for name, max_error in zip(['alpha','D','state', 'cp','JI'], list(_get_error_bounds()[:-2])+[0]): # This names must be the same as used in the yaml leaderboard
                            output_file.write(f'tr{track}.ta{idx_task}.'+name+': '+str(max_error) +'\n')
                    elif task == 'ensemble':
                        for name, max_error in zip(['alpha','D'], _get_error_bounds()[-2:]): # This names must be the same as used in the yaml leaderboard
                            output_file.write(f'tr{track}.ta{idx_task}.'+name+': '+str(max_error) +'\n')
                continue
            ##### ------------------------------------------------------------------------------ #####

            html_file.write(f'<h2> Track {track}: '+name_track+' </h2>')

            # Get the number of experiments from the true directory (same for both tasks)
            exp_folders = sorted(listdir_nohidden(truth_dir+f'/track_{track}'))
            exp_nums = [int(re.findall(r'\d+', name)[0]) for name in exp_folders]

            for task in ['ensemble', 'single']:

                # Codalab naming:
                # task 1 : single traj
                # task 2: ensemble
                idx_task = 1 if task == 'single' else 2

                if task == 'single':
                    html_file.write(f' Single Trajectory Task ')

                    avg_metrics, df = run_single_task(exp_nums, track, submit_dir, truth_dir )

                    for name, res in zip(['cp','JI','alpha','D','state'], avg_metrics): # This names must be the same as used in the yaml leaderboard
                        output_file.write(f'tr{track}.ta{idx_task}.'+name+': '+str(res) +'\n')

                elif task == 'ensemble':
                    html_file.write(f' Ensemble Task ')

                    avg_metrics, df = run_ensemble_task(exp_nums, track, submit_dir, truth_dir)

                    for name, res in zip(['alpha','D'], avg_metrics):
                        output_file.write(f'tr{track}.ta{idx_task}.'+name+': '+str(res) +'\n')

                # Detailed results are written as returned by the task (no column swapping / renaming)
                html_file.write(df.to_html(index = False).replace('\n',''))

                # Only dataframes with more than one row are saved, which skips the single-row
                # dataframe returned by run_ensemble_task when the task could not be computed
                if dfs_suffix is not None and df.shape[0] > 1:
                    df.to_csv(os.path.join(output_dir, f'df_track_{track}_task_{idx_task}_{dfs_suffix}.csv'))

# %% ../source_nbs/lib_nbs/utils_challenge.ipynb 156
import glob
# Function to rename and delete files as required
def transform_ref_to_res(base_path : str, # path where to find the folder to reorganize
                         track : str, # either 'track_1' or 'track_2'
                         num_fovs : int # number of FOVs; files traj_labs_fov_{0..num_fovs-1}.txt are renamed
                        ):
    '''
    Reorganizes, in place, a reference dataset so that it has the layout of a submission.
    In every experiment folder of the given track, the files `traj_labs_fov_X.txt` are renamed to
    `fov_X.txt` and every other entry, except those ending in `ensemble_labels.txt`, is removed.
    Note that VIP indices of track_1 are not accounted for, so scoring that track will later
    yield an error.
    '''

    track_root = os.path.join(base_path, track)

    # Loop over the experiment folders of the track
    for exp_name in os.listdir(track_root):
        exp_folder = os.path.join(track_root, exp_name)

        # traj_labs_fov_X.txt -> fov_X.txt (only for the files that actually exist)
        for idx_fov in range(num_fovs):
            source = os.path.join(exp_folder, f'traj_labs_fov_{idx_fov}.txt')
            if not os.path.exists(source):
                continue
            os.rename(source, os.path.join(exp_folder, f'fov_{idx_fov}.txt'))

        # Remove everything which is neither an ensemble labels file nor a fov_* file
        for entry in glob.glob(os.path.join(exp_folder, '*')):
            is_ensemble = entry.endswith('ensemble_labels.txt')
            is_fov = os.path.basename(entry).startswith('fov_')
            if is_ensemble or is_fov:
                continue
            os.remove(entry)
