from flask_restx import Resource, Namespace
from flask import request, jsonify
import os
import json
from database.database import DB
import torch
from torchvision.transforms import ToTensor
from datetime import datetime
from model.AttentiveRNN import AttentiveRNN
from model.Classifier import Resnet as Classifier
from subfuction.image_crop import crop_image
import numpy as np
import cv2

db = DB()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_model(model, weight_path):
    """Load the state dict at *weight_path* into *model*, move it to ``device``
    and switch it to eval mode.

    ``map_location=device`` is required so that weights saved from CUDA
    tensors can still be loaded when this process falls back to the CPU
    (plain ``torch.load`` raises in that case).

    Returns the same *model* instance.
    """
    # NOTE(review): torch.load unpickles the file — only load trusted weights.
    state_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device=device)
    model.eval()
    return model


# Pre-load the models once at import time so every request reuses them.
arnn = _load_model(AttentiveRNN(6, 3, 2), "weights/ARNN_trained_weight_6_3_2.pt")
classifier = _load_model(Classifier(), "weights/Classifier_512.pt")

tf_toTensor = ToTensor()
# Crop geometry handed to crop_image; the exact (width/height, x/y) convention
# is defined in subfuction.image_crop — TODO confirm there.
crop_size = (512, 512)
start_point = (750, 450)
root_dir = os.getcwd()

Action = Namespace(
    name="Action",
    # English: "API used for node analysis."
    description="노드 분석을 위해 사용하는 api.",
)


# @Action.route('/image_summit')
# class fileUpload(Resource):
#     @Action.doc(responses={200: 'Success'})
#     @Action.doc(responses={500: 'Register Failed'})
#     def post(self):
#         if request.method == 'POST':
#             f = request.files['file']
#             f.save(secure_filename(f.filename))
#             return {
#                 'save': 'done'  # str으로 반환하여 return
#             }, 200


@Action.route('/image_anal')
class fileUpload(Resource):
    """Classify an uploaded image as rain / no-rain and record the action.

    Expects a multipart/form-data request containing:
      * ``data`` — a JSON object string with ``gps_x``, ``gps_y``,
        ``filename`` and ``file_type`` keys;
      * ``file`` — encoded image bytes decodable by ``cv2.imdecode``.

    NOTE(review): this class shares the name ``fileUpload`` with the
    ``/action_display`` resource in this module, so the module-level name is
    shadowed. Renaming would presumably change the endpoint name flask-restx
    derives from the class, so it is left unchanged.
    """

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={400: 'Invalid request'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Run the ARNN + classifier pipeline on the uploaded image.

        Returns:
            ``({'node': (lat, lon), 'rain': bool}, 200)`` on success, where
            ``rain`` is False when the classifier picks class 0 and True
            otherwise; ``({'message': str}, 400)`` when the form data or the
            image is missing or malformed.
        """
        coords, error = self._parse_metadata(request.form.get('data'))
        if error is not None:
            # Plain dict, not jsonify(): flask-restx serializes the return
            # value itself and cannot serialize a Response object.
            return {"message": error}, 400
        lat, lon = coords

        image = self._decode_upload(request.files.get('file'))
        if image is None:
            return {"message": "Missing or undecodable image file"}, 400

        image = crop_image(image, crop_size, start_point)
        rain = self._predict_rain(image)

        # TODO(review): user_id / action_id are hard-coded placeholders and
        # action_success is always True; wire these to real values.
        user_id = 'test'
        action_success = True
        action_id = 'test'
        db.db_add_action(action_id, lat, lon, user_id, action_success)
        return {
            'node': (lat, lon),
            'rain': rain,
        }, 200

    @staticmethod
    def _parse_metadata(raw):
        """Parse the ``data`` form field into ``(lat, lon)``.

        Returns ``((lat, lon), None)`` on success, or ``(None, message)``
        explaining why the payload was rejected.
        """
        if not raw:
            return None, "Missing JSON data"
        try:
            data = json.loads(raw)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None, "Invalid JSON data"
        if not isinstance(data, dict):
            return None, "JSON data must be an object"

        # filename / file_type are not used yet but remain required fields.
        required = ('gps_x', 'gps_y', 'filename', 'file_type')
        missing = [key for key in required if key not in data]
        if missing:
            return None, "Missing keys in JSON data: " + ", ".join(missing)

        try:
            lat = float(data['gps_x'])
            lon = float(data['gps_y'])
        except (TypeError, ValueError):
            return None, "gps_x and gps_y must be numbers"
        return (lat, lon), None

    @staticmethod
    def _decode_upload(uploaded_file):
        """Decode an uploaded file into an OpenCV colour image.

        Returns None when no file was sent, it is empty, or OpenCV cannot
        decode it (``cv2.imdecode`` returns None for undecodable data).
        """
        if uploaded_file is None:
            return None
        file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
        if file_bytes.size == 0:
            # cv2.imdecode asserts on an empty buffer instead of returning None.
            return None
        return cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)

    @staticmethod
    def _predict_rain(image):
        """Return True unless the classifier assigns *image* to class 0.

        The image becomes a batch of one; ARNN output ``['x']`` feeds the
        classifier. Intermediate tensors are released when this returns.
        """
        image_tensor = tf_toTensor(image).unsqueeze(0).to(device)
        with torch.no_grad():
            arnn_out = arnn(image_tensor)
            scores = classifier(arnn_out['x'])
        # Batch size is 1 (unsqueeze(0) above), so .item() is well-defined.
        predicted_class = int(scores.argmax(dim=1).item())
        return predicted_class != 0


@Action.route('/action_display')
class fileUpload(Resource):
    """Expose the recorded actions.

    NOTE(review): shares the class name ``fileUpload`` with the
    ``/image_anal`` resource in this module; renaming would presumably change
    the endpoint name flask-restx derives from it, so it is left unchanged.
    """

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def get(self):
        """Return the rows produced by ``DB.db_display_action``.

        flask-restx only dispatches GET (and HEAD) requests to this method, so
        no request-method check is needed; every call returns a response.

        Returns:
            ``({'report': [...]}, 200)`` with the database rows as a list.
        """
        # A new DB handle per request rather than the module-level ``db``
        # (possibly to avoid reusing one connection across requests —
        # confirm). Named so it no longer shadows the module-level ``db``.
        action_db = DB()
        rows = action_db.db_display_action()
        return {
            'report': list(rows)
        }, 200
