• Y
  • List All
  • Feedback
    • This Project
    • All Projects
Profile Account settings Log out
  • Favorite
  • Project
  • All
Loading...
  • Log in
  • Sign up
yjyoon / Raindrop_Detection star
  • Project homeH
  • CodeC
  • IssueI
  • Pull requestP
  • Review R
  • MilestoneM
  • BoardB
  • Files
  • Commit
  • Branches
Raindrop_Detectiontrain.py
Download as .zip file
File name
Commit message
Commit date
data
data cleaning
2023-07-12
model
typo
2023-07-24
tools
dataloader refactoring
2023-07-13
.gitignore
data cleaning
2023-07-12
README.md
readme update
2023-06-21
batchmix.png
theorizing train code for GAN
2023-06-22
binary mask map test.py
data cleaning
2023-07-12
datasetmananger.py
theorizing about dataset management
2023-06-23
inference.py
configuring saving method and inference method for model
2023-07-11
main.py
Hello YONA
2023-06-21
train.py
typo
2023-07-24
yjyoon 2023-07-24 97e1e42 typo UNIX
Raw Open in browser Change history
"""Training script for the raindrop-removal GAN (AttentiveRNN + autoencoder generator vs. discriminator)."""
import sys
import os
import torch
import glob
import numpy as np
import time
import pandas as pd
import subprocess
import atexit
import torchvision.transforms
import cv2
from visdom import Visdom
from torchvision.utils import save_image
from torchvision import transforms
from torchvision.transforms import RandomCrop, RandomPerspective, Compose
from torch.utils.data import DataLoader
from time import gmtime, strftime
from model.Autoencoder import AutoEncoder
from model.Generator import Generator
from model.Discriminator import DiscriminativeNet as Discriminator
from model.AttentiveRNN import AttentiveRNN
from tools.argparser import get_param
from tools.logger import Logger
from tools.dataloader import ImagePairDataset

args = get_param()

# Unpack args into module-level names for easier debugging: a missing
# attribute fails right here with a clear line number instead of deep
# inside the training loop.
epochs = args.epochs
batch_size = args.batch_size
save_interval = args.save_interval
sample_interval = args.sample_interval  # NOTE(review): currently unused below
num_worker = args.num_worker
device = args.device  # NOTE(review): overwritten by the torch.device(...) line below, so --device is ignored
load = args.load
generator_learning_rate = args.generator_learning_rate
# FIX: the fallback condition previously tested args.discriminator_learning_rate,
# so setting a discriminator LR without an ARNN LR propagated None as the
# ARNN learning rate. Test the ARNN argument itself, mirroring the
# discriminator fallback below.
generator_ARNN_learning_rate = args.generator_ARNN_learning_rate if args.generator_ARNN_learning_rate is not None else args.generator_learning_rate
generator_learning_miniepoch = args.generator_learning_miniepoch
generator_attentivernn_blocks = args.generator_attentivernn_blocks
generator_resnet_depth = args.generator_resnet_depth
discriminator_learning_rate = args.discriminator_learning_rate if args.discriminator_learning_rate is not None else args.generator_learning_rate

logger = Logger()

cuda = True if torch.cuda.is_available() else False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator(generator_attentivernn_blocks, generator_resnet_depth).to(device=device)
discriminator = Discriminator().to(device=device)

if load is not None:
    # NOTE(review): the same checkpoint file is loaded into all three modules.
    # This only works if `load` contains a state dict compatible with each of
    # them, which is unlikely — confirm the intended checkpoint layout.
    generator.attentiveRNN.load_state_dict(torch.load(load))
    generator.autoencoder.load_state_dict(torch.load(load))
    discriminator.load_state_dict(torch.load(load))
else:
    pass

# This is a stopgap; a proper data-management module is planned later.
rainy_data_path = glob.glob("data/source/Oxford_raindrop_dataset/dirty/*.png")
rainy_data_path = sorted(rainy_data_path)
clean_data_path = glob.glob("data/source/Oxford_raindrop_dataset/clean/*.png")
clean_data_path = sorted(clean_data_path)

height = 480
width = 720

transform = Compose([
    RandomPerspective(),
    RandomCrop((height, width)),
])
resize = torchvision.transforms.Resize((height, width), antialias=True)

dataset = ImagePairDataset(clean_data_path, rainy_data_path, transform=transform)
# FIX: honor the --num_worker argument instead of a hard-coded 4.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_worker)

# Optimizers: the AttentiveRNN and autoencoder halves of the generator are
# stepped separately; optimizer_G covers the whole generator for the combined loss.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=generator_learning_rate)
optimizer_G_ARNN = torch.optim.Adam(generator.attentiveRNN.parameters(), lr=generator_ARNN_learning_rate, betas=(0.5, 0.999))
optimizer_G_AE = torch.optim.Adam(generator.autoencoder.parameters(), lr=generator_learning_rate, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=discriminator_learning_rate)

# ------ visdom visualizer ----------
server_process = subprocess.Popen("python -m visdom.server", shell=True)

def cleanup():
    """Stop the visdom server whenever this script terminates."""
    server_process.terminate()

atexit.register(cleanup)

time.sleep(10)  # give the visdom server time to come up before connecting
vis = Visdom(server="http://localhost", port=8097)

ARNN_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AttentionRNN Loss'))
AE_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Generator-AutoEncoder Loss'))
Discriminator_loss_window = vis.line(Y=np.array([0]), X=np.array([0]), opts=dict(title='Discriminator Loss'))
Attention_map_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='Attention Map'))
Difference_mask_map_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='Mask Map'))
Generator_output_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='Generated Derain Output'))
Input_image_visualizer = vis.image(np.zeros((height, width)), opts=dict(title='input clean image'))

for epoch_num, epoch in enumerate(range(epochs)):
    for i, imgs in enumerate(dataloader):
        img_batch = imgs
        # Images arrive as 0-255 tensors; scale to [0, 1] before the models.
        clean_img = img_batch["clean_image"] / 255
        rainy_img = img_batch["rainy_image"] / 255
        clean_img = clean_img.to(device=device)
        rainy_img = rainy_img.to(device=device)

        # --- AttentiveRNN step: learn the attention map against a binary rain mask ---
        optimizer_G_ARNN.zero_grad()
        optimizer_G_AE.zero_grad()
        attentiveRNNresults = generator.attentiveRNN(rainy_img)
        generator_attention_map = attentiveRNNresults['attention_map_list']
        # "thresold" is the keyword as declared by the project API — do not "fix"
        # the spelling here without renaming it in Generator.binary_diff_mask too.
        binary_difference_mask = generator.binary_diff_mask(clean_img, rainy_img, thresold=0.24)
        generator_loss_ARNN = generator.attentiveRNN.loss(rainy_img, binary_difference_mask)
        generator_loss_ARNN.backward()
        optimizer_G_ARNN.step()

        # --- Autoencoder step: derain the attention-weighted input ---
        generator_outputs = generator.autoencoder(rainy_img * attentiveRNNresults['attention_map_list'][-1])
        generator_result = generator_outputs['skip_3']
        generator_output = generator_outputs['output']
        # NOTE(review): loss is computed from (clean_img, rainy_img) rather than
        # the autoencoder outputs above — presumably it reruns the forward pass
        # internally; confirm against model/Autoencoder.py.
        generator_loss_AE = generator.autoencoder.loss(clean_img, rainy_img)
        generator_loss_AE.backward()
        optimizer_G_AE.step()

        # --- Discriminator step ---
        optimizer_D.zero_grad()
        # NOTE(review): unused, and unlike the fake path it is not indexed with
        # ["fc_out"] — verify which form Discriminator.forward returns.
        real_clean_prediction = discriminator(clean_img)
        fake_clean_prediction = discriminator(generator_result)["fc_out"]
        discriminator_loss = discriminator.loss(clean_img, generator_result, generator_attention_map)
        discriminator_loss.backward()
        optimizer_D.step()

        # --- Combined generator loss ---
        optimizer_G.zero_grad()
        generator_loss_whole = generator_loss_AE + generator_loss_ARNN + torch.mean(
            torch.log(torch.subtract(1, fake_clean_prediction))
        )
        # NOTE(review): generator_loss_whole is never backward()-ed, so this
        # step() applies empty gradients and is effectively a no-op. If the
        # combined adversarial update is intended, add
        # generator_loss_whole.backward() — the earlier backward() calls would
        # then need retain_graph=True. Left unchanged pending that decision.
        optimizer_G.step()

        losses = {
            "generator_loss_ARNN": generator_loss_ARNN,
            "generator_loss_AE": generator_loss_AE,
            "discriminator loss": discriminator_loss,
        }
        logger.print_training_log(epoch_num, epochs, i, len(dataloader), losses)

        # visdom logging.
        # FIX: the global step is epoch_num * len(dataloader) + i; the original
        # multiplied by `epochs`, which overlaps or gaps points whenever
        # len(dataloader) != epochs.
        global_step = epoch_num * len(dataloader) + i
        vis.line(Y=np.array([generator_loss_ARNN.item()]), X=np.array([global_step]), win=ARNN_loss_window, update='append')
        vis.line(Y=np.array([generator_loss_AE.item()]), X=np.array([global_step]), win=AE_loss_window, update='append')
        vis.line(Y=np.array([discriminator_loss.item()]), X=np.array([global_step]), win=Discriminator_loss_window, update='append')
        vis.image(generator_attention_map[-1][0, 0, :, :], win=Attention_map_visualizer, opts=dict(title="Attention Map"))
        vis.image(binary_difference_mask[-1], win=Difference_mask_map_visualizer, opts=dict(title="Binary Mask Map"))
        vis.image(generator_result[-1], win=Generator_output_visualizer, opts=dict(title="Generator Output"))
        vis.image(clean_img[-1], win=Input_image_visualizer, opts=dict(title="input clean image"))

    # Periodic checkpointing (skip epoch 0).
    day = strftime("%Y-%m-%d %H:%M:%S", gmtime())
    if epoch % save_interval == 0 and epoch != 0:
        torch.save(generator.attentiveRNN.state_dict(), f"weight/Attention_RNN_{epoch}_{day}.pt")
        torch.save(generator.state_dict(), f"weight/Generator_{epoch}_{day}.pt")
        torch.save(discriminator.state_dict(), f"weight/Discriminator_{epoch}_{day}.pt")

server_process.terminate()

# Planning notes (translated from Korean):
# - Run the RNN separately and move its memory to CPU, then run the autoencoder
#   separately and move memory likewise — would that not work?
# - Roughly: code that assembles the GAN.
# - Roughly: code that exports weights split into inference and training sets.
# - Roughly: the inference export yields two models — one producing the
#   attention map, one performing full deraining; the training export keeps
#   the full weights.
# - If we train the GAN, can it be made to reach a Nash equilibrium?
# - Roughly: code working out how training should be run.
# - Should the intermediate outputs from the generator be saved separately?

          
        
    
    
Copyright Yona authors & © NAVER Corp. & NAVER LABS Supported by NAVER CLOUD PLATFORM

or
Sign in with GitHub | Sign in with Google
Reset password | Sign up