import torch
from torch import nn, clamp
from torch.functional import F

class DiscriminativeNet(nn.Module):
    """Attention-guided image discriminator.

    ``conv1``-``conv6`` extract features (``conv1``-``conv4`` use stride 2),
    ``conv_attention`` predicts a single-channel attention map that re-weights
    those features, and ``conv7``-``conv9`` (stride 4 each) followed by two
    linear layers reduce them to one real/fake score per image.
    """

    def __init__(self, image_h=None, image_w=None):
        """Create the layers.

        :param image_h: Optional input height. Given together with ``image_w``
            it sizes ``fc1`` to the flattened ``conv9`` output for that image
            size. If both are omitted ``fc1`` takes 32 features, i.e. a 1x1
            ``conv9`` output (the case e.g. for 960x540 inputs).
        :param image_w: Optional input width; see ``image_h``.
        :raises ValueError: If only one of ``image_h``/``image_w`` is given, or
            the image is too small to survive the downsampling convolutions.
        """
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2, padding=2)
        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=2, padding=2)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=1, padding=2)
        self.conv6 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=5, stride=1, padding=2)
        self.conv_attention = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=5, stride=1, padding=2, bias=False)
        self.conv7 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=5, stride=4, padding=2)
        self.conv8 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=4, padding=2)
        self.conv9 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, stride=4, padding=2)
        self.fc1 = nn.Linear(self._flat_features(image_h, image_w), 1)
        self.fc2 = nn.Linear(1, 1)

    def _feature_convs(self):
        """Return conv1-conv6, the layers applied before the attention map."""
        return (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6)

    def _head_convs(self):
        """Return conv7-conv9, the layers applied to the attended features."""
        return (self.conv7, self.conv8, self.conv9)

    def _flat_features(self, image_h, image_w):
        """Return the number of features in the flattened conv9 output per image."""
        if image_h is None and image_w is None:
            return 32  # legacy default: 32 channels at a 1x1 spatial size
        if image_h is None or image_w is None:
            raise ValueError("image_h and image_w must be given together")
        size = [image_h, image_w]
        # conv_attention keeps conv6's spatial size (stride 1, padding 2 for a
        # 5x5 kernel), so only the trunk and head convolutions change it.
        for conv in self._feature_convs() + self._head_convs():
            size = [
                (s + 2 * conv.padding[d] - conv.dilation[d] * (conv.kernel_size[d] - 1) - 1)
                // conv.stride[d] + 1
                for d, s in enumerate(size)
            ]
        out_h, out_w = size
        if out_h < 1 or out_w < 1:
            raise ValueError(f"input size {image_h}x{image_w} is too small for this network")
        return self.conv9.out_channels * out_h * out_w

    def forward(self, x):
        """Score a batch of images.

        :param x: Image batch in NCHW layout with 3 channels.
        :return: dict with keys
            ``"fc_out"`` -- sigmoid of ``fc_raw``, clamped to [1e-7, 1 - 1e-7]
            so a later ``log`` never sees exactly 0 or 1;
            ``"attention_map"`` -- single-channel output of ``conv_attention``;
            ``"fc_raw"`` -- the pre-sigmoid output of ``fc2``, shape (N, 1).
        """
        features = x
        for conv in self._feature_convs():
            features = F.leaky_relu(conv(features))
        attention_map = self.conv_attention(features)
        attended = attention_map * features
        for conv in self._head_convs():
            attended = F.leaky_relu(conv(attended))
        fc_raw = self.fc2(self.fc1(torch.flatten(attended, 1)))
        # Keep the probability strictly inside (0, 1) for the logs in loss().
        fc_out = clamp(torch.sigmoid(fc_raw), min=1e-7, max=1 - 1e-7)
        return {"fc_out": fc_out, "attention_map": attention_map, "fc_raw": fc_raw}

    def loss(self, real_clean, generated_clean, attention_map):
        """Compute the discriminator loss.

        Suffix ``_r`` marks quantities computed on the real clean image R and
        ``_o`` those computed on the generator output O::

            loss  = mean(-log D(R) - log(1 - D(O))) + 0.05 * l_map
            l_map = MSE(attention_map, D_map(O)) + MSE(D_map(R), 0)

        Gradients also flow into ``generated_clean``; detach it at the call
        site if this loss must not update the generator.

        :param real_clean: Real clean image batch (NCHW, 3 channels).
        :param generated_clean: Generator output batch, same layout.
        :param attention_map: Final attention map from the generator. If its
            spatial size differs from this network's attention map it is
            area-resampled to match. NOTE(review): the generator map is
            presumably full resolution (N, 1, H, W) while ours is
            downsampled by conv1-conv4 -- confirm against the generator.
        :return: ``(fc_out_o, loss)`` -- the clamped probability D(O) for the
            generated batch, shape (N, 1), and the scalar loss.
        """
        ret_r = self.forward(real_clean)
        ret_o = self.forward(generated_clean)
        fc_out_r, attention_mask_r = ret_r["fc_out"], ret_r["attention_map"]
        fc_out_o, attention_mask_o = ret_o["fc_out"], ret_o["attention_map"]

        target_size = attention_mask_o.shape[-2:]
        if attention_map.shape[-2:] != target_size:
            attention_map = F.interpolate(attention_map, size=target_size, mode="area")

        # zeros_like matches the NCHW shape, dtype and device of our map.
        l_map = F.mse_loss(attention_map, attention_mask_o) + F.mse_loss(
            attention_mask_r, torch.zeros_like(attention_mask_r)
        )

        # -log D(R) - log(1 - D(O)); forward() keeps both away from 0 and 1.
        entropy_loss = torch.mean(-torch.log(fc_out_r) - torch.log1p(-fc_out_o))

        loss = entropy_loss + 0.05 * l_map
        return fc_out_o, loss

if __name__ == "__main__":
    from torchinfo import summary

    batch_size = 1
    image_h, image_w = 960, 540
    # The constructor needs no arguments; its default fc1 sizing (32 features)
    # matches the 1x1 conv9 output produced by a 960x540 image.
    discriminator = DiscriminativeNet()
    summary(discriminator, input_size=(batch_size, 3, image_h, image_w))
