import torch
from torch.utils.data import DataLoader

from model import Generator
from model import Discriminator
# kept for the planned per-module split described in the notes at the bottom
from model import AttentiveRNN
from model import Autoencoder
from tools.argparser import get_param
from tools.logger import Logger
from tools.dataloader import Dataset


# this function is from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dualgan/models.py
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


args = get_param()

epochs = args.epochs
batch_size = args.batch_size
save_interval = args.save_interval
sample_interval = args.sample_interval
load = args.load

generator_learning_rate = args.generator_learning_rate
generator_learning_miniepoch = args.generator_learning_miniepoch
generator_attentivernn_blocks = args.generator_attentivernn_blocks
generator_resnet_depth = args.generator_resnet_depth

# fall back to the generator learning rate when no discriminator rate is given
discriminator_learning_rate = (
    args.discriminator_learning_rate
    if args.discriminator_learning_rate is not None
    else args.generator_learning_rate
)

logger = Logger()

# honor an explicit device argument, otherwise pick CUDA when available
device = torch.device(
    args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")
)

# build the networks
generator = Generator().to(device)
discriminator = Discriminator().to(device)

if load:
    generator.load_state_dict(torch.load("example_path", map_location=device))
    discriminator.load_state_dict(torch.load("example_path", map_location=device))
else:
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

dataset = Dataset()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)

for epoch in range(epochs):
    for i, batch in enumerate(dataloader):
        logger.print_training_log(epoch, epochs, i, len(dataloader))

        clean_img = batch["clean_image"].to(device)
        rainy_img = batch["rainy_image"].to(device)

        # --- generator step: derain the rainy image ---
        optimizer_G.zero_grad()
        generator_outputs = generator(rainy_img)
        generator_result = generator_outputs["x"]
        generator_attention_map = generator_outputs["attention_map_list"]
        generator_loss = generator.loss(clean_img, rainy_img)
        generator_loss.backward()
        optimizer_G.step()

        # --- discriminator step: real clean images vs. detached generator output ---
        optimizer_D.zero_grad()
        real_clean_prediction = discriminator(clean_img)
        discriminator_loss = discriminator.loss(
            real_clean_prediction,
            generator_result.detach(),
            [m.detach() for m in generator_attention_map],
        )
        discriminator_loss.backward()
        optimizer_D.step()

    # periodically checkpoint the attention stage
    if (epoch + 1) % save_interval == 0:
        torch.save(generator.attentionRNN.state_dict(), "attentionRNN_model_path")

## Run the AttentiveRNN on its own and move its memory to the CPU,
## then run the autoencoder on its own and move its memory too.
## Would that not work?
## TODO: code that properly assembles the GAN
## TODO: code that exports the weights, split into an inference set and a training set
## TODO: the inference export should yield two models: one that stops at the attention map
##       and one that carries through the full deraining
## The training export therefore contains the full weights.
## If we train the GAN, can it be made to reach a Nash equilibrium ...?
## TODO: code for how training should actually be run
## Do the outputs that come out of the generator need to be saved separately?
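
# A minimal sketch of the inference/training weight-export idea from the notes
# above. It assumes Generator exposes the `attentionRNN` submodule (as the
# checkpoint call in the training loop already does); the `export_weights`
# helper and the file names are hypothetical, not part of the existing codebase.
def export_weights(generator, discriminator, prefix="weights"):
    # inference export 1: only the attention-map stage
    torch.save(generator.attentionRNN.state_dict(), f"{prefix}_attention_inference.pt")
    # inference export 2: the whole generator, i.e. attention map plus deraining
    torch.save(generator.state_dict(), f"{prefix}_deraining_inference.pt")
    # training export: full weights of both networks so training can resume
    torch.save(
        {
            "generator": generator.state_dict(),
            "discriminator": discriminator.state_dict(),
        },
        f"{prefix}_training_full.pt",
    )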
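
# A minimal sketch for the "save the generator's outputs separately" question
# above, wiring up the parsed-but-unused `sample_interval`. `save_image` tiles a
# 4D batch tensor into a grid automatically; the output directory, file names,
# and the `save_samples` helper are hypothetical, and the attention maps are
# assumed to be image-shaped tensors in [0, 1].
import os
from torchvision.utils import save_image

def save_samples(step, generator_result, attention_maps, out_dir="samples"):
    if step % sample_interval != 0:
        return
    os.makedirs(out_dir, exist_ok=True)
    save_image(generator_result, os.path.join(out_dir, f"derained_{step}.png"))
    # dump the final attention map of the recurrent stack
    save_image(attention_maps[-1], os.path.join(out_dir, f"attention_{step}.png"))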