from flask_restx import Resource, Namespace
from flask import request
from werkzeug.utils import secure_filename
import os
from database.database import DB
import torch
from torchvision.transforms import ToTensor
from datetime import datetime
from model.AttentiveRNN import AttentiveRNN
from model.Classifier import Resnet as Classifier
from subfuction.image_crop import crop_image

# Run the models on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Working directory of the process, captured once at import time.
paths = os.getcwd()

# flask-restx namespace that the resource classes below register their
# routes on.
Action = Namespace(name="Action", description="노드 분석을 위해 사용하는 api.")


@Action.route('/image_summit')
class fileUpload(Resource):
    """Receive an uploaded image and store it in the working directory."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Save the multipart ``file`` field under a sanitized file name.

        Returns:
            ``({'save': 'done'}, 200)`` once the file has been written, or
            ``({'save': 'failed'}, 500)`` when the client-supplied name is
            empty after sanitizing and therefore cannot be used as a path.
        """
        # This method is only dispatched for POST requests, so no additional
        # request.method check is needed (the old guard left an implicit
        # ``None`` return path).
        upload = request.files['file']

        # secure_filename() drops path separators and non-ASCII characters; a
        # name consisting solely of such characters collapses to ''. Saving
        # to '' would raise, so report the failure explicitly instead.
        safe_name = secure_filename(upload.filename or '')
        if not safe_name:
            return {'save': 'failed'}, 500

        upload.save(safe_name)
        # The status is returned as a str.
        return {'save': 'done'}, 200


# torchvision transform applied to the cropped image before inference.
tf_toTensor = ToTensor()
# Third argument handed to crop_image().
# NOTE(review): presumably the pixel offset where the crop begins - confirm
# the axis order against crop_image().
start_point = (750, 450)
# Second argument handed to crop_image(); the weight file name
# "Classifier_512.pt" suggests a 512x512 crop - TODO confirm.
crop_size = (512, 512)

@Action.route('/image_anal')
class fileUpload(Resource):
    """Classify a previously uploaded image as rain / no rain."""

    # (AttentiveRNN, Classifier) pair, loaded on first use and then shared by
    # every request so the weight files are not re-read each time.
    _models = None

    @classmethod
    def _load_models(cls):
        """Return the (AttentiveRNN, Classifier) pair, loading weights once."""
        if cls._models is None:
            arnn = AttentiveRNN(6, 3, 2)
            # map_location lets weights saved on a GPU load on a CPU-only host.
            arnn.load_state_dict(torch.load(
                "weights/ARNN_trained_weight_6_3_2.pt", map_location=device))
            arnn.to(device=device)
            # Inference only: switch dropout / batch-norm layers to eval mode.
            arnn.eval()

            classifier = Classifier()
            classifier.load_state_dict(torch.load(
                "weights/Classifier_512.pt", map_location=device))
            classifier.to(device=device)
            classifier.eval()

            cls._models = (arnn, classifier)
        return cls._models

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Analyse the image named in the JSON body and record the action.

        The JSON body must provide ``gps_x``, ``gps_y``, ``filename`` and
        ``file_type``; the image is looked up in the current working
        directory as ``filename + file_type``.

        Returns:
            ``({'node': (lat, lon), 'rain': <bool>}, 200)`` on success, or
            ``({'node': (lat, lon), 'rain': 'rain'}, 500)`` when the image
            could not be cropped.
        """
        payload = request.json
        lat = float(payload['gps_x'])
        lon = float(payload['gps_y'])
        failure = {
            'node': (lat, lon),
            'rain': 'rain',
        }, 500

        # The name comes from the client: sanitize it the same way the upload
        # endpoint does so it cannot point outside the working directory and
        # so it matches the name the file was stored under.
        safe_name = secure_filename(payload['filename'] + payload['file_type'])
        if not safe_name:
            return failure
        total_path = os.path.join(os.getcwd(), safe_name)

        image = crop_image(total_path, crop_size, start_point)
        # NOTE(review): relies on crop_image() returning a falsy value on
        # failure - confirm against its implementation.
        if not image:
            return failure

        arnn, classifier = self._load_models()
        with torch.no_grad():
            # Tensor.to() is not in-place; keep the tensor it returns.
            image_tensor = tf_toTensor(image).to(device)
            # NOTE(review): no batch dimension is added here - confirm the
            # models accept a single (C, H, W) tensor.
            # Call the loaded model instances, not the model classes.
            image_arnn = arnn(image_tensor)
            result = classifier(image_arnn).to("cpu")

        # NOTE(review): assumes the classifier yields a single class index
        # where 0 means "no rain" - confirm.
        rain = not bool(result == 0)

        # Placeholder identifiers until real user / action ids are available.
        user_id = 'test'
        action_success = True
        action_id = 'test'
        db = DB()
        db.db_add_action(action_id, lat, lon, user_id, action_success)
        return {
            'node': (lat, lon),
            'rain': rain,
        }, 200


@Action.route('/action_display')
class fileUpload(Resource):
    """Report actions stored in the database."""

    @Action.doc(responses={200: 'Success'})
    @Action.doc(responses={500: 'Register Failed'})
    def post(self):
        """Return the result of ``DB.db_display_action`` for the current time.

        Returns:
            ``({'report': [...]}, 200)`` where the list holds whatever
            ``db_display_action`` yields for the current timestamp.
        """
        # The previous ``request.method == 'GET'`` guard could never be true
        # inside a POST handler, so the report was never produced and the
        # handler always returned ``None``.
        db = DB()
        # Spell the time format out explicitly: ``%X`` is locale-dependent and
        # only equals HH:MM:SS under the default "C" locale.
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        rows = db.db_display_action(timestamp)
        return {
            'report': list(rows)
        }, 200
