import random
import time
import visdom
import glob
import torch
import cv2
from torchvision.transforms import ToTensor, Compose, Normalize
from model.AttentiveRNN import AttentiveRNN
from model.Classifier import Resnet as Classifier
from subfuction.image_crop import crop_image
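# AttentiveRNN, the ResNet Classifier, and crop_image are local modules of this project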

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Start a visdom server before running this script:
# install visdom via pip, then run `python -m visdom.server` in a separate terminal
# (the dashboard is served at http://localhost:8097 by default)

def process_image():
    vis = visdom.Visdom()

    # Attention network; its forward pass returns the attention maps visualized below
    arnn = AttentiveRNN(6, 3, 2)
    arnn.load_state_dict(torch.load("weights/ARNN_trained_weight_6_3_2.pt", map_location=device))
    arnn.to(device=device)
    arnn.eval()
    # Fixed 512x512 crop taken from each frame before inference
    crop_size = (512, 512)
    start_point = (750, 450)
    tf_toTensor = ToTensor()

    # ResNet classifier that takes the single-channel final attention map as input
    classifier = Classifier(in_ch=1)
    classifier.load_state_dict(torch.load("weights/classifier_e19_weight_1080p_512512_fixed_wrong_resolution_and_ch.pt", map_location=device))
    classifier.to(device=device)
    classifier.eval()
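    # Frames to run inference on; point the glob below at the local dataset location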
    rainy_data_path = glob.glob("/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/SUNNY/**/**/*.png")
    # rainy_data_path = glob.glob("/home/takensoft/Pictures/폭우 빗방울 (475개)/*.png")
    img_path = rainy_data_path
    # clean_data_path = glob.glob("/home/takensoft/Documents/AttentiveRNNClassifier/output/original/*.png")

    # img_path = rainy_data_path + clean_data_path
    # normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    random.shuffle(img_path)

    # For each frame: crop -> ARNN attention maps -> classifier -> visdom display
    for path in img_path:
        ori_img = cv2.imread(path)
        if ori_img is None:  # skip unreadable files
            continue
        image = crop_image(ori_img, crop_size, start_point)
        if not image.any():  # skip empty crops
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the networks expect RGB
        image_tensor = tf_toTensor(image)
        image_tensor = image_tensor.unsqueeze(0)  # add a batch dimension -> (1, 3, H, W)
        image_tensor = image_tensor.to(device)
        # The ARNN returns a dict containing the processed image 'x' and 'attention_map_list'
        with torch.no_grad():
            image_arnn = arnn(image_tensor)

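        # Reuse fixed visdom window names so each plot updates in place instead of opening new windows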
        input_win = 'input_window'
        attention_map_wins = [f'attention_map_{i}' for i in range(6)]
        prediction_win = 'prediction_window'

        # Visualize attention maps using visdom
        vis.images(
            image_tensor,
            opts=dict(title="input"),
            win=input_win
        )
        for idx, attention_map in enumerate(image_arnn['attention_map_list']):
            # Show only the first and last of the six attention maps
            if idx in (0, 5):
                vis.images(
                    attention_map.cpu(),  # Expected shape: (batch_size, C, H, W)
                    opts=dict(title=f'Attention Map {idx + 1}'),
                    win=attention_map_wins[idx]
                )
        # arnn_result = normalize(image_arnn['x'])
        # Classify the frame from the final attention map (class 0: sunny, class 1: rain)
        with torch.no_grad():
            result = classifier(image_arnn['attention_map_list'][-1])
        result = result.to("cpu")
        max_score, predicted = torch.max(result.data, 1)
        print(result.data)
        print(max_score)
        print(predicted)
        # Pick and display the icon corresponding to the prediction
        if predicted == 0:
            icon_path = 'asset/sun-svgrepo-com.png'
        else:  # predicted == 1
            icon_path = 'asset/rain-svgrepo-com.png'

        # Load the icon (OpenCV returns BGR/BGRA) and convert it to an RGB tensor
        icon_image = cv2.imread(icon_path, cv2.IMREAD_UNCHANGED)
        if icon_image.ndim == 3 and icon_image.shape[2] == 4:
            icon_image = cv2.cvtColor(icon_image, cv2.COLOR_BGRA2RGB)
        elif icon_image.ndim == 3:
            icon_image = cv2.cvtColor(icon_image, cv2.COLOR_BGR2RGB)
        transform = Compose([
            ToTensor()
        ])
        icon_tensor = transform(icon_image).unsqueeze(0)  # Add batch dimension

        # Visualize icon using visdom
        vis.images(
            icon_tensor,
            opts=dict(title='Weather Prediction'),
            win=prediction_win
        )
        time.sleep(1)  # pause so the visdom windows are readable between frames

    # result = classifier(image_arnn['x'])
    # result = result.to("cpu")
    # _, predicted = torch.max(result.data, 1)
    # if predicted == 0:
    #     rain = False
    # else:  # elif result == 1
    #     rain = True
    # return {
    #     'rain': rain,
    # }, 200

if __name__ == "__main__":
    process_image()