from model.AttentiveRNN import AttentiveRNN
from model.Autoencoder import AutoEncoder
from torch import nn
from torch import sum, pow, abs


class Generator(nn.Module):
    """Generator that chains an ``AttentiveRNN`` with an ``AutoEncoder``.

    ``forward`` runs the attentive RNN, multiplies its ``'x'`` output by the
    last attention map it produced, and feeds that product to the
    autoencoder.

    NOTE(review): what the two sub-networks compute internally is defined in
    ``model.AttentiveRNN`` / ``model.Autoencoder`` and is not visible from
    this file.
    """

    def __init__(
        self,
        repetition,
        blocks=3,
        layers=1,
        input_ch=3,
        out_ch=32,
        kernel_size=None,
        stride=1,
        padding=1,
        groups=1,
        dilation=1,
    ):
        """Build the two sub-networks and record the configuration.

        Args:
            repetition: Passed as the first positional argument to
                ``AttentiveRNN``.
            blocks, layers, input_ch, out_ch, stride, padding, groups,
            dilation: Forwarded by keyword to ``AttentiveRNN`` and stored on
                the instance.
            kernel_size: Forwarded as given to ``AttentiveRNN``; stored on the
                instance as ``[3, 3]`` when ``None``.
        """
        super().__init__()

        # Forward the caller's kernel_size as given (including None) instead
        # of a hard-coded None, so an explicit kernel size is no longer
        # silently dropped. With the default, AttentiveRNN still receives
        # None exactly as before.
        self.attentiveRNN = AttentiveRNN(
            repetition,
            blocks=blocks,
            layers=layers,
            input_ch=input_ch,
            out_ch=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
        )
        self.autoencoder = AutoEncoder()

        self.blocks = blocks
        self.layers = layers
        self.input_ch = input_ch
        self.out_ch = out_ch
        # A fresh list per instance avoids sharing a mutable default.
        self.kernel_size = [3, 3] if kernel_size is None else kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        # Not used by any method in this class; kept so existing code that
        # accesses ``generator.sigmoid`` keeps working.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the attentive RNN, then the autoencoder on the attended input.

        Returns:
            dict with
            ``'x'``: the autoencoder output, and
            ``'attention_maps'``: the ``'attention_map_list'`` entry returned
            by the attentive RNN.
        """
        rnn_out = self.attentiveRNN(x)
        attention_maps = rnn_out['attention_map_list']
        # Only the final attention map weights the autoencoder input.
        attended = rnn_out['x'] * attention_maps[-1]
        return {
            'x': self.autoencoder(attended),
            'attention_maps': attention_maps,
        }

    def binary_diff_mask(self, clean, dirty, thresold=0.1):
        """Return a 0/1 mask of where ``clean`` and ``dirty`` nearly agree.

        Both tensors are raised to the power 0.45, their absolute difference
        is summed over dim 1, and the result is compared against ``thresold``.

        Args:
            clean: Tensor. Assumes non-negative values (a fractional power of
                a negative number is NaN) -- TODO confirm input range.
            dirty: Tensor compatible with ``clean`` for subtraction.
            thresold: Upper bound (exclusive) on the summed difference for a
                position to be marked 1. The spelling is kept because it is
                part of the public signature.

        Returns:
            Tensor with dim 1 reduced away and the dtype of the
            gamma-corrected ``clean``: 1 where the summed difference is
            strictly below ``thresold``, 0 elsewhere.

        NOTE(review): the mask is 1 for *unchanged* positions. Confirm this is
        the polarity ``AttentiveRNN.loss`` expects.
        """
        # Gamma-correct first: sRGB values are not linear in light intensity.
        clean = clean.pow(0.45)
        dirty = dirty.pow(0.45)

        # Total absolute difference across dim 1 (presumably channels).
        diff = (clean - dirty).abs().sum(dim=1)

        return (diff < thresold).to(clean.dtype)

    def loss(self, clean, dirty, thresold=0.1):
        """Collect the losses reported by the two sub-networks.

        Args:
            clean: Passed to both sub-network ``loss`` methods and used to
                build the difference mask.
            dirty: Used to build the difference mask and passed to
                ``AutoEncoder.loss``.
            thresold: Forwarded to ``binary_diff_mask``.

        Returns:
            dict with ``"attentive_rnn_loss"`` (from
            ``attentiveRNN.loss(clean, diff_mask)``) and
            ``"autoencoder_loss"`` (from ``autoencoder.loss(clean, dirty)``).
        """
        # TODO: verify that the diff mask behaves as intended.
        diff_mask = self.binary_diff_mask(clean, dirty, thresold)

        return {
            "attentive_rnn_loss": self.attentiveRNN.loss(clean, diff_mask),
            "autoencoder_loss": self.autoencoder.loss(clean, dirty),
        }

if __name__ == "__main__":
    # Smoke check: build a small generator and print its layer summary.
    import torch
    from torchinfo import summary

    # torch.set_default_tensor_type is deprecated in recent PyTorch;
    # set_default_dtype is the supported way to make float32 the default.
    torch.set_default_dtype(torch.float32)

    batch_size = 2
    generator = Generator(3, blocks=2)
    summary(generator, input_size=(batch_size, 3, 720, 720))
