import torch
from torchvision.utils import save_image

from model import Generator
from model import Discriminator
from tools.argparser import get_param
from tools.logger import Logger
from tools.dataloader import Dataset

# this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

args = get_param()

epochs = args.epochs
batch_size = args.batch_size
save_interval = args.save_interval
sample_interval = args.sample_interval
device = args.device
load = args.load
generator_learning_rate = args.generator_learning_rate
generator_learning_miniepoch = args.generator_learning_miniepoch
generator_attentivernn_blocks = args.generator_attentivernn_blocks
generator_resnet_depth = args.generator_resnet_depth
discriminator_learning_rate = (
    args.discriminator_learning_rate
    if args.discriminator_learning_rate is not None
    else args.generator_learning_rate
)

logger = Logger()

# Prefer the device requested via the CLI args, falling back to CUDA when available.
device = torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))

# NOTE: the parsed architecture arguments (generator_attentivernn_blocks,
# generator_resnet_depth) would presumably be forwarded here, if Generator accepts them.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

if load:
    # Placeholder checkpoint paths; substitute real checkpoint files.
    generator.load_state_dict(torch.load("example_path"))
    discriminator.load_state_dict(torch.load("example_path"))
else:
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

# Build the paired-image dataset and wrap it in a standard PyTorch DataLoader.
dataset = Dataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)

for epoch in range(epochs):
    for i, batch in enumerate(dataloader):
        logger.print_training_log(epoch, epochs, i, len(dataloader))

        # Each batch is assumed to be a dict of paired clean/rainy tensors.
        clean_img = batch["clean_image"].to(device)
        rainy_img = batch["rainy_image"].to(device)

        optimizer_G.zero_grad()
        # The generator consumes the rainy image and predicts a derained result
        # plus the attention maps produced by its attentive-RNN stage.
        generator_outputs = generator(rainy_img)
        generator_result = generator_outputs["x"]
        generator_attention_map = generator_outputs["attention_map_list"]
        generator_loss = generator.loss(clean_img, rainy_img)
        generator_loss.backward()
        optimizer_G.step()
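        # Hedged note: `generator_learning_miniepoch` is parsed above but never
        # used; presumably the generator update block would be repeated that
        # many times per batch, e.g. wrapped in
        #     for _ in range(generator_learning_miniepoch): ...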

        optimizer_D.zero_grad()
        real_clean_prediction = discriminator(clean_img)
        # Detach the generator outputs: their graph was freed by the backward
        # above, and the discriminator step must not update the generator.
        discriminator_loss = discriminator.loss(
            real_clean_prediction,
            generator_result.detach(),
            [m.detach() for m in generator_attention_map],
        )
        discriminator_loss.backward()
        optimizer_D.step()
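        # Hedged sketch: periodically dump sample outputs via the imported
        # save_image. Assumes generator_result is an image batch in a range
        # save_image can handle, and that the samples/ directory exists;
        # the filename pattern is a placeholder.
        if sample_interval and i % sample_interval == 0:
            save_image(generator_result.detach(), f"samples/epoch{epoch}_iter{i}.png")

    # Hedged sketch: checkpoint both networks every `save_interval` epochs;
    # the .pth paths are placeholders.
    if save_interval and (epoch + 1) % save_interval == 0:
        torch.save(generator.state_dict(), f"generator_epoch{epoch + 1}.pth")
        torch.save(discriminator.state_dict(), f"discriminator_epoch{epoch + 1}.pth")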

# Export the attentive-RNN weights on their own (placeholder path).
torch.save(generator.attentionRNN.state_dict(), "attentionRNN_model_path")

## Could we run the RNN on its own and move its memory to the CPU,
## then run the Autoencoder on its own and move its memory over as well?

## TODO: code that properly assembles the GAN
## TODO: code that exports the weights, split into an inference set and a training set
## TODO: the inference export comes in two flavors: one that stops at the attention map, one that performs the full deraining
## the training export therefore keeps the full weights
## if we train the GAN, can it be driven toward a Nash equilibrium?
## TODO: figure out how the training should actually be run
## do the outputs coming out of the generator need to be saved separately?
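# A minimal sketch of the export split described in the notes above, assuming
# the attentive-RNN stage is exposed as generator.attentionRNN. All paths are
# placeholders.

# Inference export 1: attention-map stage only.
torch.save(generator.attentionRNN.state_dict(), "inference_attention_map.pth")

# Inference export 2: the full generator, which performs the complete deraining.
torch.save(generator.state_dict(), "inference_deraining.pth")

# Training export: full weights for both networks plus optimizer state,
# so a run can be resumed.
torch.save(
    {
        "generator": generator.state_dict(),
        "discriminator": discriminator.state_dict(),
        "optimizer_G": optimizer_G.state_dict(),
        "optimizer_D": optimizer_D.state_dict(),
    },
    "training_full.pth",
)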
