# Copyright (c) 2023-present, SUSTech-ML.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import torch.nn as nn
import torch.nn.functional as F


def build_DomainDetecor(in_dim, out_dim, device):
    """Create a ``MidFNN`` domain classifier and move it to ``device``.

    Args:
        in_dim: Number of input features.
        out_dim: Number of classifier outputs.
        device: Target device passed to ``nn.Module.to``.

    Returns:
        The ``MidFNN`` instance placed on ``device``.
    """
    detector = MidFNN(in_dim, out_dim)
    return detector.to(device)


class FNN(nn.Module):
    """Fully-connected feed-forward classifier that outputs probabilities.

    The network is ``n_layers`` blocks of ``Linear -> ReLU -> Dropout``,
    followed by a final ``Linear`` projection to ``n_out`` logits.

    Args:
        n_in: Number of input features.
        n_out: Number of outputs. When ``n_out == 1`` the logits go through a
            sigmoid; otherwise a softmax over the last dimension is applied.
        n_hiddens: Width of each hidden layer. Needed when ``n_layers > 0``;
            not used when ``n_layers == 0``.
        n_layers: Number of hidden blocks; ``0`` gives a single linear layer.
        dropout: Dropout probability applied after every hidden activation.
    """

    def __init__(self, n_in, n_out, n_hiddens, n_layers, dropout=0.5):
        super().__init__()

        layers = []
        in_features = n_in
        for _ in range(n_layers):
            layers.append(nn.Linear(in_features, n_hiddens))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            in_features = n_hiddens
        # The output layer consumes whatever the previous stage produced: the
        # raw input when there are no hidden blocks, otherwise the last hidden
        # width. (Using ``n_hiddens`` unconditionally mis-sizes the layer when
        # ``n_layers == 0`` and a hidden width is supplied.)
        layers.append(nn.Linear(in_features, n_out))
        self.model = nn.Sequential(*layers)

    def forward(self, x, training=False):
        """Return class probabilities for ``x``.

        Args:
            x: Input tensor whose last dimension has ``n_in`` features.
            training: If true, the inner network runs in train mode (dropout
                active); otherwise it is switched to eval mode.

        Returns:
            Sigmoid probabilities when the model has a single output, else a
            softmax over the last dimension. Same leading shape as ``x``.
        """
        # The mode is set explicitly on every call, so dropout is active only
        # when the caller asks for it.
        self.model.train(bool(training))
        logits = self.model(x)
        # Decide on the output (last) dimension, the same one softmax uses.
        if logits.shape[-1] == 1:
            return torch.sigmoid(logits)
        return F.softmax(logits, dim=-1)


class Linear(FNN):
    """``FNN`` with no hidden layers: one linear map followed by sigmoid/softmax.

    Args:
        n_in: Number of input features.
        n_out: Number of outputs.
        n_hiddens: Accepted only for signature compatibility with the other
            ``FNN`` variants. A model without hidden layers has no hidden
            width, so the value is not used.
    """

    def __init__(self, n_in, n_out, n_hiddens=None):
        # Pass ``None`` so the base class sizes the output layer from ``n_in``.
        # A hidden width is meaningless when there are no hidden layers.
        super().__init__(n_in, n_out, None, n_layers=0)


class SmallFNN(FNN):
    """``FNN`` with exactly one hidden ``Linear -> ReLU -> Dropout`` block.

    Args:
        n_in: Number of input features.
        n_out: Number of outputs.
        n_hiddens: Width of the hidden layer (default 500).
    """

    def __init__(self, n_in, n_out, n_hiddens=500):
        super(SmallFNN, self).__init__(
            n_in=n_in, n_out=n_out, n_hiddens=n_hiddens, n_layers=1
        )


class MidFNN(FNN):
    """``FNN`` with two hidden ``Linear -> ReLU -> Dropout`` blocks.

    Args:
        n_in: Number of input features.
        n_out: Number of outputs.
        n_hiddens: Width of each hidden layer (default 500).
    """

    def __init__(self, n_in, n_out, n_hiddens=500):
        super(MidFNN, self).__init__(
            n_in=n_in, n_out=n_out, n_hiddens=n_hiddens, n_layers=2
        )


class BigFNN(FNN):
    """``FNN`` with four hidden ``Linear -> ReLU -> Dropout`` blocks.

    Args:
        n_in: Number of input features.
        n_out: Number of outputs.
        n_hiddens: Width of each hidden layer (default 500).
    """

    def __init__(self, n_in, n_out, n_hiddens=500):
        super(BigFNN, self).__init__(
            n_in=n_in, n_out=n_out, n_hiddens=n_hiddens, n_layers=4
        )


class IW(nn.Module):
    """
    Compute importance weights from a probabilistic domain classifier.

    The weight of each sample is the ratio ``p(domain=1 | x) / p(domain=0 | x)``
    as estimated by ``domain_detector``. Which real domain index 1 denotes
    depends on how the detector was trained (presumably the target/test
    domain -- confirm with the caller).

    Args:
        domain_detector: Callable mapping a batch to probabilities of shape
            ``(N, 1)`` (interpreted as ``p(domain=1 | x)``) or ``(N, C)`` with
            ``C >= 2`` (columns 0 and 1 are used).
    """

    def __init__(self, domain_detector):
        super().__init__()

        self.domain_detector = domain_detector

    def forward(self, x_batch):
        """Return importance weights for ``x_batch``.

        Returns:
            Tensor of shape ``(N, 1)`` for a single-output detector, otherwise
            shape ``(N,)``.
        """
        prob = self.domain_detector(x_batch)
        if prob.shape[1] == 1:
            numerator, denominator = prob, 1 - prob
        else:
            numerator, denominator = prob[:, 1], prob[:, 0]
        # A saturated detector yields a denominator of exactly 0 (e.g. a sigmoid
        # that rounds to 1.0), which would turn into inf/NaN weights. Clamping
        # to the smallest positive normal value of the dtype only alters those
        # degenerate cases and keeps every weight finite.
        denominator = denominator.clamp_min(torch.finfo(denominator.dtype).tiny)
        return numerator / denominator
