import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models import vgg16, VGG16_Weights


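# Encoder-decoder network with a dilated-convolution bottleneck and three
# multi-scale RGB outputs, trained with a weighted multi-scale MSE loss (lm)
# plus a VGG16 perceptual loss (lp).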
class AutoEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder layers: conv2 and conv4 (stride 2) downsample to 1/4 resolution
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2, bias=False)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False)

        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False)

        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)

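        # Dilated 3x3 convolutions widen the receptive field without further
        # downsampling; padding equal to the dilation rate preserves spatial size.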
        self.dilated_conv1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=2, padding=2, bias=False)
        self.dilated_conv2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=4, padding=4, bias=False)
        self.dilated_conv3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=8, padding=8, bias=False)
        self.dilated_conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, dilation=16, padding=16, bias=False)

        self.conv7 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv8 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)

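        # Decoder: each transposed convolution doubles the spatial resolution;
        # the following average pooling smooths the upsampled feature maps.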
        self.deconv1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False)
        self.avg_pool1 = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

        self.deconv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False)
        self.avg_pool2 = nn.AvgPool2d(kernel_size=3, stride=1, padding=1)

        self.conv9 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv10 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False)

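        # Output heads projecting decoder features to RGB at 1/4, 1/2, and
        # full resolution for multi-scale supervision.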
        self.skip_output1 = nn.Conv2d(in_channels=256, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        self.skip_output2 = nn.Conv2d(in_channels=128, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)
        self.skip_output3 = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False)

        # Loss-specific definitions.
        # The paper uses VGG16 for feature extraction; since VGG16 is not a
        # lightweight model, it could be swapped for a smaller backbone.
        self.vgg = vgg16(weights=VGG16_Weights.DEFAULT).features
        self.vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        # Per-scale weights for the multi-scale MSE loss, ordered coarse to fine
        self.lambda_i = [0.6, 0.8, 1.0]

    def forward(self, input_tensor):
        # Encoder: extract features while downsampling to 1/4 resolution
        x = torch.relu(self.conv1(input_tensor))
        relu1 = x  # full-resolution features for the decoder skip connection
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        relu3 = x  # 1/2-resolution features for the decoder skip connection
        x = torch.relu(self.conv4(x))
        x = torch.relu(self.conv5(x))
        x = torch.relu(self.conv6(x))
        # Bottleneck: dilated convolutions at 1/4 resolution
        x = torch.relu(self.dilated_conv1(x))
        x = torch.relu(self.dilated_conv2(x))
        x = torch.relu(self.dilated_conv3(x))
        x = torch.relu(self.dilated_conv4(x))
        x = torch.relu(self.conv7(x))
        x = torch.relu(self.conv8(x))
        relu12 = x  # 1/4-resolution bottleneck features

        # Decoder stage 1: upsample to 1/2 resolution, smooth, fuse the skip connection
        deconv1 = self.deconv1(relu12)
        avg_pool1 = self.avg_pool1(deconv1)
        relu13 = torch.relu(avg_pool1)

        relu14 = torch.relu(self.conv9(relu13 + relu3))

        # Decoder stage 2: upsample to full resolution, smooth, fuse the skip connection
        deconv2 = self.deconv2(relu14)
        avg_pool2 = self.avg_pool2(deconv2)
        relu15 = torch.relu(avg_pool2)

        relu16 = torch.relu(self.conv10(relu15 + relu1))

        # RGB outputs at 1/4, 1/2, and full resolution; only the final,
        # full-resolution output is squashed with tanh
        skip_output_1 = self.skip_output1(relu12)
        skip_output_2 = self.skip_output2(relu14)
        skip_output_3 = torch.tanh(self.skip_output3(relu16))

        ret = {
            'skip_1': skip_output_1,
            'skip_2': skip_output_2,
            'skip_3': skip_output_3,
        }

        return ret

    def loss(self, input_tensor, label_tensor):
        ori_height, ori_width = label_tensor.shape[2:]

        # Rescale labels to match the scales of the outputs
        # (F.interpolate defaults to nearest-neighbour resampling)
        label_tensor_resize_2 = F.interpolate(label_tensor, size=(ori_height // 2, ori_width // 2))
        label_tensor_resize_4 = F.interpolate(label_tensor, size=(ori_height // 4, ori_width // 4))
        # Ordered coarse to fine to match skip_1, skip_2, skip_3
        label_list = [label_tensor_resize_4, label_tensor_resize_2, label_tensor]

        inference_ret = self(input_tensor)

        output_list = [inference_ret['skip_1'], inference_ret['skip_2'], inference_ret['skip_3']]

        # lm_loss: MSE at each scale, weighted by lambda_i
        lm_loss = 0.0
        for output, label, weight in zip(output_list, label_list, self.lambda_i):
            lm_loss = lm_loss + F.mse_loss(output, label) * weight

        # lp_loss: perceptual loss on VGG16 feature maps. self.vgg is an
        # nn.Sequential, so a single forward call yields one tensor (the output
        # of the last feature layer), and the MSE is taken over that map.
        # The label features carry no gradient, so compute them under no_grad.
        with torch.no_grad():
            src_vgg_feats = self.vgg(label_tensor)
        pred_vgg_feats = self.vgg(output_list[-1])
        lp_loss = F.mse_loss(pred_vgg_feats, src_vgg_feats)

        loss = lm_loss + lp_loss

        return loss, inference_ret['skip_3']
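

if __name__ == "__main__":
    # Minimal smoke test (not part of the original model definition): runs a
    # dummy batch through the forward pass and the loss. The 64x64 input size
    # and random tensors are assumptions for illustration; any spatial size
    # that is a multiple of 4 keeps the 1/2- and 1/4-scale outputs aligned.
    model = AutoEncoder()
    model.train()

    dummy_input = torch.randn(2, 3, 64, 64)  # stand-in for a rainy image batch
    dummy_label = torch.randn(2, 3, 64, 64)  # stand-in for a clean image batch

    outputs = model(dummy_input)
    print({k: tuple(v.shape) for k, v in outputs.items()})
    # expected: skip_1 -> (2, 3, 16, 16), skip_2 -> (2, 3, 32, 32), skip_3 -> (2, 3, 64, 64)

    loss, prediction = model.loss(dummy_input, dummy_label)
    loss.backward()
    print(float(loss))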