import copy

import torch
from torch import nn
from torch.nn import functional as F

# nn.Sequential does not handle multiple inputs by design; this subclass is a workaround
# https://github.com/pytorch/pytorch/issues/19808#
class mySequential(nn.Sequential):
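    """Sequential container that unpacks a tuple returned by one module into the next module's positional arguments."""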
    def forward(self, *input):
        for module in self._modules.values():
            input = module(*input)
        return input

def conv3x3(in_ch, out_ch, stride=1, padding=1, groups=1, dilation=1):
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=padding, groups=groups, dilation=dilation)


def conv1x1(in_ch, out_ch, stride=1):
    return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)


class ResNetBlock(nn.Module):
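    """Stack of conv + LeakyReLU layers with a residual shortcut taken after the first convolution."""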
    def __init__(self, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
                 dilation=1):
        """
        :type kernel_size: iterable or int
        """
        super(ResNetBlock, self).__init__()
        if kernel_size is None:
            kernel_size = [3, 3]
        self.conv1 = nn.Conv2d(
            input_ch, out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                out_ch, out_ch,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                groups=groups,
                dilation=dilation
            ),
            nn.LeakyReLU()
        )
        # Give each hidden layer its own copy of the conv block; appending the
        # same self.conv2 instance repeatedly would make every block and layer
        # share one set of weights.
        self.conv_hidden = nn.ModuleList()
        for _ in range(blocks * layers):
            self.conv_hidden.append(copy.deepcopy(self.conv2))
        self.blocks = blocks
        self.layers = layers

    def forward(self, x):
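        """Apply the first conv, then the hidden conv stack, adding the shortcut back every `layers` hidden layers."""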
        x = self.conv1(x)
        shortcut = x
        for i, hidden_layer in enumerate(self.conv_hidden):
            x = hidden_layer(x)
            if (i % self.layers == 0) and (i != 0):
                x = F.leaky_relu(x)
                x = x + shortcut
        return x


class ConvLSTM(nn.Module):
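    """Convolutional LSTM-style cell that emits a single-channel attention map.

    Note that the gate convolutions only see the current input features; the
    previous hidden state is not fed back into the gates, only the cell state
    is carried between steps.
    """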
    def __init__(self, ch, kernel_size=3):
        super(ConvLSTM, self).__init__()
        # Accept an int or an iterable kernel size and derive "same" padding from it
        if isinstance(kernel_size, (tuple, list)):
            kernel_size = kernel_size[0]
        self.padding = (kernel_size - 1) // 2
        self.conv_i = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size, stride=1,
                                padding=self.padding, bias=False)
        self.conv_f = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size, stride=1,
                                padding=self.padding, bias=False)
        self.conv_c = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size, stride=1,
                                padding=self.padding, bias=False)
        self.conv_o = nn.Conv2d(in_channels=ch, out_channels=ch, kernel_size=kernel_size, stride=1,
                                padding=self.padding, bias=False)
        self.conv_attention_map = nn.Conv2d(in_channels=ch, out_channels=1, kernel_size=kernel_size, stride=1,
                                            padding=self.padding, bias=False)
        self.ch = ch

    def init_hidden(self, batch_size, image_size, init=0.5):
        height, width = image_size
        return torch.ones(batch_size, self.ch, height, width).to(
            dtype=self.conv_i.weight.dtype, device=self.conv_i.weight.device) * init

    def forward(self, input_tensor, input_cell_state=None):
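        """Compute the gates from the input features, update the cell state,
        and return (attention_map, cell_state, lstm_feats)."""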
        if input_cell_state is None:
            batch_size, _, height, width = input_tensor.size()
            input_cell_state = self.init_hidden(batch_size, (height, width))

        conv_i = self.conv_i(input_tensor)
        sigmoid_i = torch.sigmoid(conv_i)

        conv_f = self.conv_f(input_tensor)
        sigmoid_f = torch.sigmoid(conv_f)

        cell_state = sigmoid_f * input_cell_state + sigmoid_i * torch.tanh(self.conv_c(input_tensor))

        conv_o = self.conv_o(input_tensor)
        sigmoid_o = torch.sigmoid(conv_o)

        lstm_feats = sigmoid_o * torch.tanh(cell_state)

        attention_map = self.conv_attention_map(lstm_feats)
        attention_map = torch.sigmoid(attention_map)

        return attention_map, cell_state, lstm_feats


class AttentiveRNNBLCK(nn.Module):
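    """One attentive recurrence step: ResNet features -> ConvLSTM attention cell -> attention-weighted input image."""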
    def __init__(self, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1, groups=1,
                 dilation=1):
        """
        :type kernel_size: iterable or int
        """
        super(AttentiveRNNBLCK, self).__init__()
        if kernel_size is None:
            kernel_size = [3, 3]
        self.blocks = blocks
        self.layers = layers
        self.input_ch = input_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        self.sigmoid = nn.Sigmoid()
        self.resnet = nn.Sequential(
            ResNetBlock(
                blocks=self.blocks,
                layers=self.layers,
                input_ch=self.input_ch,
                out_ch=self.out_ch,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                groups=self.groups,
                dilation=self.dilation
            )
        )
        self.LSTM = mySequential(
            ConvLSTM(
                ch=out_ch, kernel_size=kernel_size,
            )
        )

    def forward(self, original_image, prev_cell_state=None):
        x = self.resnet(original_image)
        attention_map, cell_state, lstm_feats = self.LSTM(x, prev_cell_state)
        x = attention_map * original_image
        ret = {
            'x': x,
            'attention_map': attention_map,
            'cell_state': cell_state,
            'lstm_feats': lstm_feats
        }
        return ret


class AttentiveRNN(nn.Module):
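    """Attentive-recurrent generator: applies the (weight-shared) AttentiveRNNBLCK
    `repetition` times, threading the ConvLSTM cell state from step to step.
    """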
    def __init__(self, repetition, blocks=3, layers=1, input_ch=3, out_ch=32, kernel_size=None, stride=1, padding=1,
                 groups=1, dilation=1):
        """
        :type kernel_size: iterable or int
        """
        super(AttentiveRNN, self).__init__()
        if kernel_size is None:
            kernel_size = [3, 3]
        self.blocks = blocks
        self.layers = layers
        self.input_ch = input_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.groups = groups
        self.dilation = dilation
        self.repetition = repetition
        self.generator_block = mySequential(
            AttentiveRNNBLCK(blocks=blocks,
                             layers=layers,
                             input_ch=input_ch,
                             out_ch=out_ch,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=padding,
                             groups=groups,
                             dilation=dilation)
        )
        # The same block instance (shared weights) is applied at every recurrence step
        self.generator_blocks = nn.ModuleList()
        for _ in range(repetition):
            self.generator_blocks.append(
                self.generator_block
            )

    def forward(self, x):
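        """Run the generator block recurrently, collecting the attention map and
        LSTM features produced at each step."""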
        cell_state = None
        attention_map = []
        lstm_feats = []
        for generator_block in self.generator_blocks:
            generator_block_return = generator_block(x, cell_state)
            attention_map_i = generator_block_return['attention_map']
            lstm_feats_i = generator_block_return['lstm_feats']
            cell_state = generator_block_return['cell_state']
            x = generator_block_return['x']

            attention_map.append(attention_map_i)
            lstm_feats.append(lstm_feats_i)
        ret = {
            'x': x,
            'attention_map_list': attention_map,
            'lstm_feats': lstm_feats
        }
        return ret

    def loss(self, input_image_tensor, difference_maskmap, theta=0.8):
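        """Sum of MSE losses between each step's attention map and the target
        mask, weighted by a decaying power of theta so later steps count more."""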
        self.theta = theta
        # Run the attentive RNN forward pass
        inference_ret = self.forward(input_image_tensor)
        loss = 0.0
        n = len(inference_ret['attention_map_list'])
        for index, attention_map in enumerate(inference_ret['attention_map_list']):
            mse_loss = (self.theta ** (n - index + 1)) * nn.MSELoss()(attention_map, difference_maskmap)
            loss += mse_loss
        return loss

# TODO: needs work

if __name__ == "__main__":
    from torchinfo import summary

    torch.set_default_dtype(torch.float32)

    generator = AttentiveRNN(3, blocks=2)
    batch_size = 5
    summary(generator, input_size=(batch_size, 3, 960, 540))
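
    # Minimal sketch of a full forward pass and the attention loss, using a
    # random image and an all-ones mask purely as illustrative dummy data.
    dummy_image = torch.rand(batch_size, 3, 128, 128)
    dummy_mask = torch.ones(batch_size, 1, 128, 128)
    out = generator(dummy_image)
    print(out['x'].shape, len(out['attention_map_list']))
    print('attention loss:', generator.loss(dummy_image, dummy_mask).item())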
