"""Module to define the labeler for the BERT model."""

from torch import nn
from transformers import AutoModel, BertModel


class BertEncoder(nn.Module):
    """BERT encoder with one linear classification head per labelled observation.

    The final hidden state of the first (CLS) token is either returned directly
    or, after dropout, fed through 14 independent linear heads: 13 four-way heads
    followed by one two-way head.
    """

    def __init__(  # pylint: disable=too-many-arguments
        self, logits, p=0.1, clinical=False, freeze_embeddings=False, pretrain_path=None
    ):
        """Init the labeler module.

        @param logits (boolean): True to make forward() return the outputs of the linear
                                 heads; False to make it return the CLS hidden state
        @param p (float): p to use for dropout in the linear heads, 0.1 by default is
                        consistent with transformers.BertForSequenceClassification
        @param clinical (boolean): True if Bio_Clinical BERT desired, False otherwise. Ignored if
                                   pretrain_path is not None
        @param freeze_embeddings (boolean): True to freeze bert embeddings during training
        @param pretrain_path (string): path to load checkpoint from
        """
        super().__init__()

        if pretrain_path is not None:
            self.bert = BertModel.from_pretrained(pretrain_path)
        elif clinical:
            self.bert = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        else:
            self.bert = BertModel.from_pretrained("bert-base-uncased")

        if freeze_embeddings:
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False
        self.logits = logits
        self.dropout = nn.Dropout(p)
        # Size of the output of the transformer's last layer. Read it from the model
        # config (the pooler's input size is the same value) so head construction does
        # not depend on the pooler layer, which this module never uses.
        hidden_size = self.bert.config.hidden_size
        # classes: present, absent, unknown, blank for 12 conditions + support devices
        self.linear_heads = nn.ModuleList(
            [nn.Linear(hidden_size, 4, bias=True) for _ in range(13)]
        )
        # classes: yes, no for the 'no finding' observation
        self.linear_heads.append(nn.Linear(hidden_size, 2, bias=True))

    def forward(self, source_padded, attention_mask):
        """Forward pass of the labeler.

        @param source_padded (torch.LongTensor): Tensor of word indices with padding,
                            shape (batch_size, max_len)
        @param attention_mask (torch.Tensor): Mask to avoid attention on padding
                            tokens, shape (batch_size, max_len)
        @returns out (List[torch.Tensor] or torch.Tensor): If self.logits is True, a
                            list with one tensor per linear head (14 in total): the
                            first 13 have shape (batch_size, 4) and the last has
                            shape (batch_size, 2). Otherwise the CLS hidden state
                            of shape (batch_size, hidden_size), without dropout.
        """
        # shape (batch_size, max_len, hidden_size)
        final_hidden = self.bert(source_padded, attention_mask=attention_mask)[0]
        # Hidden state of the first (CLS) token, shape (batch_size, hidden_size).
        # Integer indexing already removes the sequence axis, so no squeeze is needed.
        cls_hidden = final_hidden[:, 0, :]
        if not self.logits:
            return cls_hidden

        # Dropout is only applied on the way into the classification heads.
        cls_hidden = self.dropout(cls_hidden)
        # Iterate the heads themselves so the output always matches the head count.
        return [head(cls_hidden) for head in self.linear_heads]
