from AttentiveRNN import AttentiveRNN
from Autoencoder import AutoEncoder
from torch import nn


class Generator(nn.Module):
    """Attentive-RNN front end followed by an autoencoder.

    The input is first passed through ``AttentiveRNN``, which yields a
    feature tensor and an attention map. Their elementwise product is then
    fed to ``AutoEncoder``, whose output is returned.

    Args:
        repetition: Forwarded as the first positional argument of
            ``AttentiveRNN`` (presumably the number of recurrent steps --
            confirm against ``AttentiveRNN``).
        blocks, layers, input_ch, out_ch, stride, padding, groups, dilation:
            Forwarded unchanged to ``AttentiveRNN``.
        kernel_size: Forwarded to ``AttentiveRNN``. Defaults to ``[3, 3]``
            when ``None``.
    """

    def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
                 dilation=1):
        super().__init__()
        # Build a fresh list per instance rather than sharing a mutable default.
        if kernel_size is None:
            kernel_size = [3, 3]
        # Forward the *resolved* kernel size. Previously a literal ``None`` was
        # passed here, so a caller-supplied kernel_size was silently ignored.
        self.attentiveRNN = AttentiveRNN(
            repetition,
            blocks=blocks, layers=layers, input_ch=input_ch, out_ch=out_ch,
            kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, dilation=dilation,
        )
        self.autoencoder = AutoEncoder()
        # Hyper-parameters recorded on the instance for inspection.
        self.blocks = blocks
        self.layers = layers
        self.input_ch = input_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        # NOTE(review): not used by forward(); kept so existing references to
        # `generator.sigmoid` keep working (it holds no parameters).
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Run the attentive RNN, gate its output by the attention map, decode.

        Args:
            x: Input batch (assumed ``(batch, input_ch, H, W)`` as in the
               ``__main__`` smoke test -- TODO confirm).

        Returns:
            The autoencoder output. The attention map itself is not returned.
        """
        features, attention_map = self.attentiveRNN(x)
        return self.autoencoder(features * attention_map)

if __name__ == "__main__":
    # Smoke test: build a small generator and print a torchinfo layer summary
    # for a batch of two 3-channel inputs of spatial size 960 x 540.
    import torch
    from torchinfo import summary

    # `torch.set_default_tensor_type` is deprecated (PyTorch >= 2.1);
    # `set_default_dtype(torch.float32)` gives the same float32-on-CPU default.
    torch.set_default_dtype(torch.float32)
    generator = Generator(3, blocks=2)
    batch_size = 2
    summary(generator, input_size=(batch_size, 3, 960, 540))
