# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_vision_data.ipynb.

# %% auto 0
__all__ = ['pred_to_multiclass_mask', 'batch_pred_to_multiclass_mask', 'pred_to_binary_mask', 'MedDataBlock', 'MedMaskBlock',
           'MedImageDataLoaders', 'show_batch', 'show_results', 'plot_top_losses']

# %% ../nbs/02_vision_data.ipynb 2
import torch
from fastai.data.all import *
from fastai.vision.data import *
from .vision_core import *
from .vision_plot import find_max_slice
from plum import dispatch

# %% ../nbs/02_vision_data.ipynb 5
def pred_to_multiclass_mask(pred: torch.Tensor) -> torch.Tensor:
    """Turn a single multi-class activation tensor into a label mask.

    The activations are first normalised with a softmax over the class axis
    (values in [0, 1] that sum to 1). The class with the highest score is then
    selected for every element.

    Args:
        pred: [C,W,H,D] activation tensor.

    Returns:
        Predicted mask. The class axis is kept as a singleton dimension.
    """
    probabilities = torch.softmax(pred, dim=0)
    # Softmax is monotonic, so it does not alter which class wins; the argmax
    # keeps a singleton class axis in the output.
    return torch.argmax(probabilities, dim=0, keepdim=True)

# %% ../nbs/02_vision_data.ipynb 6
def batch_pred_to_multiclass_mask(pred: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Convert a batch of predicted activation tensors to masks.

    This is equivalent to applying `pred_to_multiclass_mask` to every sample and
    stacking the results. Here it is done in one vectorised call over the class
    axis (dim 1), which also works for an empty batch.

    Args:
        pred: [B, C, W, H, D] batch of activations.

    Returns:
        Tuple of the [B, 1, W, H, D] batch of predicted masks and the number of
        classes C.
    """
    n_classes = pred.shape[1]
    # Softmax over the class axis mirrors `pred_to_multiclass_mask`; keepdim keeps
    # the singleton class axis that the per-sample version produces.
    masks = pred.softmax(dim=1).argmax(dim=1, keepdim=True)

    return masks, n_classes

# %% ../nbs/02_vision_data.ipynb 7
def pred_to_binary_mask(pred: torch.Tensor) -> torch.Tensor:
    """Threshold sigmoid-squashed activations into a 0/1 mask.

    A sigmoid first maps the activations into (0, 1). Every element whose value
    is at least 0.5 becomes 1, and every other element becomes 0.

    Args:
        pred: [B, C, W, H, D] or [C, W, H, D] activation tensor

    Returns:
        Predicted binary mask(s) with the same shape as `pred`.
    """
    probabilities = pred.sigmoid()
    is_foreground = probabilities >= 0.5

    return torch.where(is_foreground, 1, 0)

# %% ../nbs/02_vision_data.ipynb 9
class MedDataBlock(DataBlock):
    """Container to quickly build dataloaders.

    A fastai `DataBlock` that, on construction, also forwards `reorder` and
    `resample` to `MedBase.item_preprocessing`.
    """
    #TODO add get_x
    def __init__(self, blocks: list = None, dl_type: TfmdDL = None, getters: list = None,
                 n_inp: int | None = None, item_tfms: list = None, batch_tfms: list = None,
                 reorder: bool = False, resample: (int, list) = None, **kwargs):
        """Build the data block and register the medical preprocessing options.

        Args:
            blocks, dl_type, getters, n_inp, item_tfms, batch_tfms, **kwargs:
                Passed unchanged to `DataBlock.__init__` (e.g. `get_x`, `get_y`,
                `splitter`).
            reorder: Forwarded to `MedBase.item_preprocessing`.
            resample: Forwarded to `MedBase.item_preprocessing`.
        """
        super().__init__(blocks=blocks, dl_type=dl_type, getters=getters, n_inp=n_inp,
                         item_tfms=item_tfms, batch_tfms=batch_tfms, **kwargs)

        # NOTE(review): this is called on the `MedBase` class itself, so the options
        # presumably affect every MedBase-derived object, not only this block — confirm.
        MedBase.item_preprocessing(resample, reorder)

# %% ../nbs/02_vision_data.ipynb 11
def MedMaskBlock():
    """Return a `TransformBlock` whose type transform builds masks with `MedMask.create`."""
    mask_factory = MedMask.create
    return TransformBlock(type_tfms=mask_factory)

# %% ../nbs/02_vision_data.ipynb 13
class MedImageDataLoaders(DataLoaders):
    """Higher-level `MedDataBlock` API."""
    
    @classmethod
    @delegates(DataLoaders.from_dblock)
    def from_df(cls, df, valid_pct=0.2, seed=None, fn_col=0, folder=None, suff='',
                label_col=1, label_delim=None, y_block=None, valid_col=None,
                item_tfms=None, batch_tfms=None, reorder=False, resample=None, **kwargs):
        """Create from DataFrame.

        Args:
            df: DataFrame holding the image file names and labels.
            valid_pct: Fraction of rows used for validation by `RandomSplitter`
                (ignored when `valid_col` is given).
            seed: Random seed for `RandomSplitter`.
            fn_col: Column with the image file names, read by `ColReader`.
            folder: NOTE(review): currently unused; it is not prepended to
                the file names.
            suff: Suffix appended to every file name.
            label_col: Column(s) holding the labels.
            label_delim: Delimiter that splits a label string into several labels.
            y_block: Target block. Defaults to `MultiCategoryBlock` when several
                label columns or a `label_delim` are given, else `CategoryBlock`.
            valid_col: Column flagging validation rows; selects `ColSplitter`.
            item_tfms: Item transforms passed to `MedDataBlock`.
            batch_tfms: Batch transforms passed to `MedDataBlock`.
            reorder: Forwarded to `MedDataBlock`.
            resample: Forwarded to `MedDataBlock`.
            **kwargs: Forwarded to `DataLoaders.from_dblock`.
        """
        if y_block is None:
            is_multi = (is_listy(label_col) and len(label_col) > 1) or label_delim is not None
            y_block = MultiCategoryBlock if is_multi else CategoryBlock

        splitter = (RandomSplitter(valid_pct, seed=seed) 
                    if valid_col is None else ColSplitter(valid_col))

        dblock = MedDataBlock(
            blocks=(ImageBlock(cls=MedImage), y_block),
            get_x=ColReader(fn_col, suff=suff),
            get_y=ColReader(label_col, label_delim=label_delim),
            splitter=splitter,
            item_tfms=item_tfms,
            # `batch_tfms` must be forwarded; otherwise the argument is silently ignored.
            batch_tfms=batch_tfms,
            reorder=reorder,
            resample=resample
        )

        return cls.from_dblock(dblock, df, **kwargs)

# %% ../nbs/02_vision_data.ipynb 16
@dispatch
def show_batch(x: MedImage, y, samples, ctxs=None, max_n=6, nrows=None, 
               ncols=None, figsize=None, channel=0, slice_index=None,
               anatomical_plane=0, **kwargs):
    """Draw up to `max_n` classification/regression samples on a grid of axes.

    Item 0 of every sample is always drawn. Item 1 (the target) is drawn on the
    same axis only when `y` is not None.
    """
    if ctxs is None:
        n_panels = min(len(samples), max_n)
        ctxs = get_grid(n_panels, nrows=nrows, ncols=ncols, figsize=figsize)

    n_items = 2 if y is not None else 1
    show_kwargs = dict(channel=channel, slice_index=slice_index,
                       anatomical_plane=anatomical_plane, **kwargs)

    for item_idx in range(n_items):
        ctxs = [item.show(ctx=ctx, **show_kwargs)
                for item, ctx, _ in zip(samples.itemgot(item_idx), ctxs, range(max_n))]

    plt.tight_layout()
    return ctxs

# %% ../nbs/02_vision_data.ipynb 17
@dispatch
def show_batch(x: MedImage, y: MedMask, samples, ctxs=None, max_n=6, nrows=None,
               ncols=None, figsize=None, channel=0, slice_index=None,
               anatomical_plane=0, **kwargs):
    """Show decoded segmentation samples.

    Each sample gets one row: one panel per image channel followed by a panel
    for its mask. All panels in the row are shown at the same slice.
    `**kwargs` go to `subplots`.
    """
    n_rows = min(len(samples), max_n)
    n_cols = x.shape[1] + 1  # one panel per image channel plus one for the mask

    fig, axs = subplots(n_rows, n_cols, figsize=figsize, **kwargs)
    axs = axs.flatten()

    panels, panel_slices = [], []
    for img, mask in zip(x, y):
        panels += [MedImage(channel_img[None]) for channel_img in img]
        panels.append(MedMask(mask))

        if slice_index is None:
            chosen_slice = find_max_slice(mask[0].numpy(), anatomical_plane)
        else:
            chosen_slice = slice_index
        panel_slices += [chosen_slice] * (img.shape[0] + 1)

    ctxs = [panel.show(ax=ax, slice_index=idx, anatomical_plane=anatomical_plane)
            for panel, ax, idx in zip(panels, axs, panel_slices)]

    plt.tight_layout()
    return ctxs

# %% ../nbs/02_vision_data.ipynb 19
@dispatch
def show_results(x: MedImage, y: torch.Tensor, samples, outs, ctxs=None, max_n: int = 6,
                 nrows: int | None = None, ncols: int | None = None, figsize=None, channel: int = 0,
                 slice_index: int | None = None, anatomical_plane: int = 0, **kwargs):
    """Draw regression samples and then their predictions on a shared grid of axes.

    Every element of each sample is drawn with the slice-selection arguments.
    Every element of each prediction tuple is then drawn on the same axes with
    only `**kwargs`.
    """
    if ctxs is None:
        n_panels = min(len(samples), max_n)
        ctxs = get_grid(n_panels, nrows=nrows, ncols=ncols, figsize=figsize)

    sample_kwargs = dict(channel=channel, slice_index=slice_index,
                         anatomical_plane=anatomical_plane, **kwargs)

    # Samples get the slice-selection arguments; predictions only the extra kwargs.
    for items, show_kwargs in ((samples, sample_kwargs), (outs, kwargs)):
        for item_idx in range(len(items[0])):
            ctxs = [obj.show(ctx=ctx, **show_kwargs)
                    for obj, ctx, _ in zip(items.itemgot(item_idx), ctxs, range(max_n))]

    return ctxs

# %% ../nbs/02_vision_data.ipynb 20
@dispatch
def show_results(x: MedImage, y: TensorCategory, samples, outs, ctxs=None, 
                 max_n: int = 6, nrows: int | None = None, ncols: int | None = None, figsize=None, channel: int = 0, 
                 slice_index: int | None = None, anatomical_plane: int = 0, **kwargs):
    """Draw classification samples with their targets and predictions.

    Items 0 and 1 of each sample are drawn first. The prediction is then added
    on the same axis, coloured green when it equals the target and red otherwise.
    """
    if ctxs is None:
        n_panels = min(len(samples), max_n)
        ctxs = get_grid(n_panels, nrows=nrows, ncols=ncols, figsize=figsize)

    show_kwargs = dict(channel=channel, slice_index=slice_index,
                       anatomical_plane=anatomical_plane, **kwargs)
    for item_idx in (0, 1):
        ctxs = [item.show(ctx=ctx, **show_kwargs)
                for item, ctx, _ in zip(samples.itemgot(item_idx), ctxs, range(max_n))]

    targets = samples.itemgot(1)
    predictions = outs.itemgot(0)
    ctxs = [pred.show(ctx=ctx, color='green' if target == pred else 'red', **kwargs)
            for target, pred, ctx, _ in zip(targets, predictions, ctxs, range(max_n))]

    return ctxs

# %% ../nbs/02_vision_data.ipynb 21
@dispatch
def show_results(x: MedImage, y: MedMask, samples, outs, ctxs=None, max_n: int = 6, 
                 nrows: int | None = None, ncols: int = 3, figsize=None, channel: int = 0, 
                 slice_index: int | None = None, anatomical_plane: int = 0, **kwargs):
    """Showing decoded samples and their corresponding predictions for segmentation tasks.

    Each sample uses three consecutive axes (Image/Target/Prediction):
    the input image from `x`; the sample image with the target mask shown on the
    same axis; and the sample image with the predicted mask shown on the same axis.

    Args:
        x: Batch of input images; iterated per sample for the first panel.
        y: Batch of target masks; only used to choose a slice per sample when
            `slice_index` is None.
        samples: Per-sample tuples; item 0 is shown as the image and item 1 as the target.
        outs: Per-sample tuples; item 0 is shown as the prediction.
        ctxs: Axes to draw on. When None, a grid of `3 * min(len(samples), max_n)`
            axes is created with `get_grid`.
        max_n: Maximum number of samples to display.
        nrows: Number of grid rows passed to `get_grid`.
        ncols: Number of grid columns. Axes are consumed in consecutive triples,
            so the default of 3 presumably gives one sample per row.
        figsize: Figure size passed to `get_grid`.
        channel: Channel used when showing images; masks are always shown with channel 0.
        slice_index: Slice to display. When None, `find_max_slice` chooses one per
            sample from the first channel (`mask[0]`) of its target mask.
        anatomical_plane: Plane along which slices are taken.
        **kwargs: Forwarded to every `show` call.

    Returns:
        `ctxs`, with every triple updated by the `show` calls.
    """
    
    if ctxs is None: 
        # Three panels per displayed sample.
        total_slots = 3 * min(len(samples), max_n)
        ctxs = get_grid(total_slots, nrows=nrows, ncols=ncols, 
                        figsize=figsize, double=False, title='Image/Target/Prediction')
    
    # One slice index per target mask, reused for all three panels of that sample.
    # NOTE(review): computed for every mask in `y`, including masks beyond the displayed samples.
    slice_idxs = [find_max_slice(mask[0].numpy(), anatomical_plane) if slice_index is None else slice_index for mask in y]

    # First panel of each triple: the input image alone.
    ctxs[::3] = [b.show(ctx=c, channel=channel, slice_index=idx, anatomical_plane=anatomical_plane, **kwargs)
                 for b, c, idx in zip(x, ctxs[::3], slice_idxs)]

    # Second panel: the sample image (i=0), then the target mask (i=1, channel 0) on the same axis.
    # NOTE(review): if `ctxs` is a list, these extended-slice assignments raise ValueError
    # when the right-hand side is shorter than the slice being replaced.
    for i in range(2):
        current_channel = 0 if i == 1 else channel
        ctxs[1::3] = [b.show(ctx=c, channel=current_channel, slice_index=idx, 
                            anatomical_plane=anatomical_plane, **kwargs) 
                     for b, c, _, idx in zip(samples.itemgot(i), ctxs[1::3], range(2 * max_n), slice_idxs)]

    # Third panel: the sample image (from `samples`), then the prediction (from `outs`, channel 0).
    for index, o in enumerate([samples, outs]):
        current_channel = 0 if index == 1 else channel
        ctxs[2::3] = [b.show(ctx=c, channel=current_channel, slice_index=idx, 
                             anatomical_plane=anatomical_plane, **kwargs) 
                      for b, c, _, idx in zip(o.itemgot(0), ctxs[2::3], range(2 * max_n), slice_idxs)]

    return ctxs

# %% ../nbs/02_vision_data.ipynb 23
@dispatch
def plot_top_losses(x: MedImage, y: TensorCategory, samples, outs, raws, losses, nrows: int | None = None, 
                    ncols: int | None = None, figsize=None, channel: int = 0, slice_index: int | None = None, 
                    anatomical_plane: int = 0, **kwargs):
    """Show images in top_losses along with their prediction, actual, loss, and highest score.

    The highest score is `raws.max()`; for softmax outputs this is the probability
    of the predicted class.

    Args:
        x: Input images (only used for dispatch).
        y: Targets; a `TensorCategory` selects the classification layout.
        samples: Per-sample tuples of (image, actual label).
        outs: Per-sample tuples whose item 0 is the decoded prediction.
        raws: Raw model outputs per sample.
        losses: Loss value per sample.
        nrows, ncols, figsize: Grid layout passed to `get_grid`.
        channel, slice_index, anatomical_plane, **kwargs: Forwarded to the image `show` call.
    """
    # `TensorCategory` subclasses `torch.Tensor`, so an `isinstance(y, torch.Tensor)`
    # check is always true here and cannot tell classification from regression.
    is_classification = isinstance(y, TensorCategory)
    title = 'Prediction/Actual/Loss/Probability' if is_classification else 'Prediction/Actual/Loss'
    axs = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize, title=title)

    for ax, s, o, r, loss in zip(axs, samples, outs, raws, losses):
        s[0].show(ctx=ax, channel=channel, slice_index=slice_index, anatomical_plane=anatomical_plane, **kwargs)

        if is_classification:
            ax.set_title(f'{o[0]}/{s[1]} / {loss.item():.2f} / {r.max().item():.2f}')
        else:
            ax.set_title(f'{r.max().item():.2f}/{s[1]} / {loss.item():.2f}')

# %% ../nbs/02_vision_data.ipynb 24
@dispatch
def plot_top_losses(x: MedImage, y: TensorMultiCategory, samples, outs, raws, 
                    losses, nrows: int | None = None, ncols: int | None = None, figsize=None, 
                    channel: int = 0, slice_index: int | None = None, 
                    anatomical_plane: int = 0, **kwargs):
    """Show top-loss images for multi-label tasks, plus a summary table.

    Each image is drawn on its own axis, titled `Image {i}`. A DataFrame with one
    row per sample (target, predicted, probabilities, loss) is then displayed.
    """
    # TODO: not tested yet
    axs = get_grid(len(samples), nrows=nrows, ncols=ncols, figsize=figsize)
    show_kwargs = dict(channel=channel, slice_index=slice_index,
                       anatomical_plane=anatomical_plane, **kwargs)

    for img_idx, (ax, sample) in enumerate(zip(axs, samples)):
        sample[0].show(ctx=ax, title=f'Image {img_idx}', **show_kwargs)

    rows = get_empty_df(len(samples))
    table_items = L(sample[1:] + out + (TitledStr(raw), TitledFloat(loss.item()))
                    for sample, out, raw, loss in zip(samples, outs, raws, losses))

    for col_idx, label in enumerate(["target", "predicted", "probabilities", "loss"]):
        rows = [item.show(ctx=row, label=label, **show_kwargs)
                for item, row in zip(table_items.itemgot(col_idx), rows)]

    display_df(pd.DataFrame(rows))
