import random
import time
import visdom
import glob
import torch
import cv2
from torchvision.transforms import ToTensor, Compose, Normalize
from flask import request
from model.AttentiveRNN import AttentiveRNN
from model.Classifier import Resnet as Classifier
from subfuction.image_crop import crop_image

# Device shared by every model and tensor in this script: the GPU when CUDA is
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# A visdom server must be running before any visdom visualisation is used.
# Install it with `pip install visdom`, then start it in a terminal with `python -m visdom.server`.

def _load_eval_model(model, weight_path):
    """Load ``weight_path`` into ``model``, move it to ``device`` and set eval mode."""
    # map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
    state_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device=device)
    model.eval()
    return model


def process_image():
    """Run the AttentiveRNN + classifier pipeline over a hard-coded image folder.

    For every PNG matched by the dataset glob the image is read, cropped with
    ``crop_image``, converted from BGR to RGB, and passed through the
    AttentiveRNN. The last entry of the network's ``'attention_map_list'``
    output is then fed to the classifier (built with ``in_ch=1``). The raw
    classifier output, the maximum score and the predicted class index are
    printed per image, and the list of all predicted class indices (plain
    ints) is printed once at the end.

    Images that cannot be read are reported and skipped; crops that contain
    only zeros are skipped silently.

    NOTE(review): class index 0 is presumed to mean clear weather and any
    other index rain -- confirm against the classifier's training labels.

    Returns:
        None. All results are written to stdout.
    """
    arnn = _load_eval_model(
        AttentiveRNN(6, 3, 2),
        "weights/ARNN_trained_weight_6_3_2.pt",
    )
    classifier = _load_eval_model(
        Classifier(in_ch=1),
        "weights/classifier_e19_weight_1080p_512512_fixed_wrong_resolution_and_ch.pt",
    )

    # Region handed to crop_image for every frame.
    # NOTE(review): presumably (width, height) and a top-left corner in pixels --
    # confirm against subfuction.image_crop.
    crop_size = (512, 512)
    start_point = (750, 450)
    to_tensor = ToTensor()

    # Alternative datasets that have been evaluated with this script:
    #   "/home/takensoft/Pictures/화창한날, 비오는날 프레임2000장/RAIN/**/**/*.png"
    #   "/home/takensoft/Pictures/폭우 빗방울 (475개)/*.png"
    dataset_glob = "/home/takensoft/Desktop/KOLAS_TEST/정상/*.png"
    image_paths = glob.glob(dataset_glob)
    if not image_paths:
        print(f"no images matched {dataset_glob}")
    random.shuffle(image_paths)

    predictions = []
    for path in image_paths:
        frame = cv2.imread(path)
        if frame is None:
            # cv2.imread signals an unreadable or corrupt file by returning
            # None instead of raising, so it has to be checked explicitly.
            print(f"skipping unreadable image: {path}")
            continue

        patch = crop_image(frame, crop_size, start_point)
        if not patch.any():
            # An all-zero crop carries no information to classify.
            continue

        rgb_patch = cv2.cvtColor(patch, cv2.COLOR_BGR2RGB)
        batch = to_tensor(rgb_patch).unsqueeze(0).to(device)

        with torch.no_grad():
            arnn_output = arnn(batch)
            # Only the final attention map is classified.
            scores = classifier(arnn_output['attention_map_list'][-1])
        scores = scores.to("cpu")

        max_score, predicted = torch.max(scores, 1)
        print(scores)
        print(max_score)
        print(predicted)
        # Store plain ints rather than 0-d tensors so the summary is readable.
        predictions.extend(predicted.tolist())

    print(predictions)

# Script entry point: run the batch classification only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    process_image()