from AttentiveRNN import AttentiveRNN
from Autoencoder import AutoEncoder
from torch import nn


class Generator(nn.Module):
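    """Generator that chains an AttentiveRNN with an AutoEncoder: the RNN produces
    attention maps and the autoencoder reconstructs the image from the
    attention-weighted input."""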
    def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None,
                 stride=1, padding=1, groups=1, dilation=1):
        super(Generator, self).__init__()
        if kernel_size is None:
            kernel_size = [3, 3]
        self.attentiveRNN = AttentiveRNN(
            repetition,
            blocks=blocks, layers=layers, input_ch=input_ch, out_ch=out_ch,
            kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=dilation
        )
        self.autoencoder = AutoEncoder()
        self.blocks = blocks
        self.layers = layers
        self.input_ch = input_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
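        # Run the recurrent attention module; it returns its output image ('x')
        # and the attention maps produced at each recurrence step.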
        attentiveRNNresults = self.attentiveRNN(x)
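        # Reconstruct the image from the attention-weighted input; the last entry
        # in the list is the attention map from the final recurrence step.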
        x = self.autoencoder(attentiveRNNresults['x'] * attentiveRNNresults['attention_map_list'][-1])
        ret = {
            'x': x,
            'attention_maps': attentiveRNNresults['attention_map_list']
        }
        return ret

    def loss(self, x, diff):
        # Return the combined loss of both sub-networks (assumes each .loss()
        # returns a scalar tensor rather than only updating internal state).
        attention_loss = self.attentiveRNN.loss(x, diff)
        autoencoder_loss = self.autoencoder.loss(x)
        return attention_loss + autoencoder_loss

if __name__ == "__main__":
    import torch
    from torchinfo import summary

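    # Quick sanity check: build a small generator and print a layer summary
    # for a dummy 720x720 RGB input.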
    torch.set_default_tensor_type(torch.FloatTensor)
    generator = Generator(3, blocks=2)
    batch_size = 2
    summary(generator, input_size=(batch_size, 3, 720, 720))
