import random

import numpy as np
from flask import Flask, request, jsonify
from flask_restx import Api, Resource, fields
import datetime
import psycopg2
import time
import base64
import json
import cv2
import requests
import typing
from requests_toolbelt import MultipartEncoder


# Module-level debug switch.
debug = True

# PostgreSQL connection settings, passed as keyword arguments to
# psycopg2.connect() whenever an event is written. JSON is UTF-8 by spec, so
# decode it explicitly instead of relying on the platform's locale encoding.
with open('config_files/MAIN_DB_ENDPOINT.json', 'r', encoding='utf-8') as file:
    db_config = json.load(file)

# Flask application and the Flask-RESTX API (Swagger UI) mounted on it.
app = Flask(__name__)
api = Api(
    app,
    version='1.0',
    title='CCTV Image Upload API',
    description='A postprocessing and adaptive rate mainserver data pusher',
)

# Swagger documentation of the upload request. Field names mirror the headers
# and multipart files that PostProcesser.post() reads. The model is registered
# on the API but, in this file, not attached to the route via @ns.expect, so it
# documents requests rather than validating them.
image_upload_model = api.model('ImageUpload', {
    'image': fields.String(required=True, description='Image file', dt='File'),
    'seg_image': fields.String(required=False,
                               description="Segmentation image file (needed when X-Flag-Detected is 'True')",
                               dt='File'),
    'x-cctv-name': fields.String(required=False, description='CCTV identifier (base64 of the UTF-8 name)'),
    'x-time-sent': fields.String(required=True, description='Time image was sent (%Y-%m-%dT%H:%M:%SZ)'),
    'x-cctv-latitude': fields.String(required=False, description='Latitude of CCTV'),
    'x-cctv-longitude': fields.String(required=False, description='Longitude of CCTV'),
    'X-Flag-Detected': fields.String(required=True, description="If detected ('True' or 'False')"),
    'x-area-percentage': fields.String(required=True, description='Detected area percentage (float)'),
})

# Namespace definition
ns = api.namespace('postprocess', description='Postprocessing of inference results')

class StreamSources():
    """Per-CCTV rolling status memory that decides when events are written to the DB.

    Each registered source keeps its most recent ``buffer_size`` statuses as
    integers (1 for status 'N' / OK, 0 for status 'Y' / fail) and is in one of
    two modes:

    * normal mode: an event is sent once ``normal_send_interval`` calls have
      passed without a send, and the source switches to failure ("force send")
      mode as soon as at least ``failure_mode_thres`` of its last
      ``failure_mode_check_past_n`` statuses are failures;
    * failure mode: an event is sent on every call, and the source returns to
      normal mode once at least ``normal_mode_thres`` of its last
      ``normal_mode_check_past_n`` statuses are OK.

    Each mode change is reported on exactly one sent event via the
    ``norm_to_alert_flag`` / ``alert_to_norm_flag`` columns.
    """

    def __init__(self, buffer_size, normal_send_interval, failure_mode_thres, failure_mode_check_past_n, normal_mode_thres, normal_mode_check_past_n):
        """Validate and store the mode-switching configuration.

        Raises:
            AssertionError: if a threshold exceeds its window, or a window is
                larger than ``buffer_size``.
        """
        assert failure_mode_thres <= failure_mode_check_past_n, \
            (f"failure_mode checker condition is invalid!,"
             f" failure_mode needs {failure_mode_thres} fails in {failure_mode_check_past_n}, which is not possible!")
        assert normal_mode_thres <= normal_mode_check_past_n, \
            (f"normal_mode checker condition is invalid!,"
             f" normal_mode needs {normal_mode_thres} oks in {normal_mode_check_past_n}, which is not possible!")
        assert buffer_size >= failure_mode_check_past_n, \
            (f"'buffer_size' is smaller than failure_mode_check_past_n! This is not possible!"
             f" This will cause program to never enter failure mode!! \n"
             f"Printing relevant args and shutting down!\n"
             f" buffer_size : {buffer_size}\n failure_mode_check_past_n : {failure_mode_check_past_n}")
        assert buffer_size >= normal_mode_check_past_n, \
            (f"'buffer_size' is smaller than normal_mode_check_past_n!"
             f" This will cause the program to never revert back to normal mode!! \n"
             f"Printing relevant args and shutting down!\n"
             f" buffer_size : {buffer_size}\n normal_mode_check_past_n : {normal_mode_check_past_n}")

        self.sources = {}
        self.buffer_size = buffer_size
        self.normal_send_interval = normal_send_interval

        # Not read by this class; record whether a mode switch needs an
        # unbroken run (threshold == window size).
        self.switching_fail_consecutive_mode = failure_mode_thres == failure_mode_check_past_n
        self.switching_normal_consecutive_mode = normal_mode_thres == normal_mode_check_past_n

        self.failure_mode_thres = failure_mode_thres
        self.failure_mode_check_past_n = failure_mode_check_past_n
        self.normal_mode_thres = normal_mode_thres
        self.normal_mode_check_past_n = normal_mode_check_past_n

        # Per-frame payload of the latest add_status() call (any source);
        # send_event() falls back to it when no payload is passed explicitly.
        self.cache_cctv_info = None

    def __setitem__(self, key, value):
        """Register source ``key`` with its static CCTV info ``value``.

        Raises:
            KeyError: if ``key`` is already registered (entries are never overwritten).
        """
        if key in self.sources:
            raise KeyError(f"Error! Source {key} already initialized.")
        self.sources[key] = {
            "status_counts": [],
            "ok_counts": 0,
            "failure_counts": 0,
            "force_send_mode": False,
            "most_recent_image": None,
            "most_recent_mask": None,
            "most_recent_seg_image": None,
            "cctv_info": value,
            "last_send_before": 0,
            "normal_to_failure_mode_change_alert": False,
            "failure_to_normal_mode_change_alert": False,
        }

    def __getitem__(self, key):
        """Return the state dict of source ``key`` (KeyError if unregistered)."""
        return self.sources[key]

    def __contains__(self, key):
        """Return True if source ``key`` has been registered."""
        return key in self.sources

    def __call__(self):
        """Return the underlying ``{source: state}`` dict."""
        return self.sources

    def _recent_counts(self, source, past_n):
        """Return ``(ok_count, fail_count)`` over the last ``past_n`` statuses of ``source``."""
        statuses = self.sources[source]["status_counts"]
        # Guard: statuses[-0:] would be the whole list, not an empty window.
        recent = statuses[-past_n:] if past_n > 0 else []
        ok_count = sum(recent)
        return ok_count, len(recent) - ok_count

    def add_status(self, source, status, cctv_info):
        """Record one status for ``source``, update its mode, and send an event if due.

        Args:
            source: key of a source previously registered via ``self[source] = info``.
            status: 'N' (OK) or 'Y' (fail).
            cctv_info: per-frame payload written to the DB if an event is sent;
                also cached in ``self.cache_cctv_info``.

        Raises:
            AssertionError: on a status other than 'N' / 'Y'.
            ValueError: if ``source`` was never registered.
        """
        assert status in ["N", "Y"], \
            f"Invalid status was given!, status must be one of 'N'(== 'OK') or 'Y'(== 'Fail'), but given '{status}'!"

        if source not in self.sources:
            raise ValueError(f"No key found for source. Did you forgot to add it? \n source : {source}")

        state = self.sources[source]
        self.cache_cctv_info = cctv_info

        # Rolling buffer capped at buffer_size (oldest status dropped first).
        state["status_counts"].append(1 if status == "N" else 0)
        if len(state["status_counts"]) > self.buffer_size:
            state["status_counts"].pop(0)

        flag_send_event = False
        if state["force_send_mode"]:
            # Failure mode: send every frame; revert once enough recent frames are OK.
            state["ok_counts"], state["failure_counts"] = self._recent_counts(
                source, self.normal_mode_check_past_n)
            flag_send_event = True
            if state["ok_counts"] >= self.normal_mode_thres:
                state["force_send_mode"] = False
                # Keep flag_send_event True so this recovery alert is actually
                # written; it is cleared right after the send below.
                state["failure_to_normal_mode_change_alert"] = True
        else:
            # Normal mode: switch to failure mode once enough recent frames failed.
            state["ok_counts"], state["failure_counts"] = self._recent_counts(
                source, self.failure_mode_check_past_n)
            if state["failure_counts"] >= self.failure_mode_thres:
                state["force_send_mode"] = True
                flag_send_event = True
                state["normal_to_failure_mode_change_alert"] = True

            # Regular interval message logic.
            if state["last_send_before"] >= self.normal_send_interval:
                flag_send_event = True
            else:
                state["last_send_before"] += 1

        print(f"ok_counts : {state['ok_counts']}")
        print(f"last_send_before : {state['last_send_before']}")

        if flag_send_event:
            self.send_event(source, cctv_info)

        # Alerts are reported on a single event only.
        state["failure_to_normal_mode_change_alert"] = False
        state["normal_to_failure_mode_change_alert"] = False

    def send_event(self, source, cctv_info=None):
        """Insert one ``flooding_detect_event`` row for ``source`` and reset its send state.

        Args:
            source: registered source key; its stored static info supplies the
                CCTV name and coordinates.
            cctv_info: per-frame payload (time, result, area, frames). Defaults
                to ``self.cache_cctv_info``.

        The DB connection is always closed, even when the insert fails; DB
        errors propagate to the caller.
        """
        if cctv_info is None:
            cctv_info = self.cache_cctv_info
        print(f"{source} is now sending data!")
        source_data = self.sources[source]
        static_info = source_data["cctv_info"]

        insert_sql = """
        INSERT INTO flooding_detect_event (
            ocrn_dt, 
            eqpmn_nm, 
            flooding_result, 
            flooding_per, 
            image, 
            image_seg, 
            eqpmn_lat, 
            eqpmn_lon, 
            norm_to_alert_flag, 
            alert_to_norm_flag
        ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s);
        """
        print(cctv_info["time_sent"])
        # Build the row before connecting, so a malformed payload cannot leak a connection.
        data_tuple = (
            cctv_info["time_sent"],
            static_info["cctv_name"],
            cctv_info["detected"],
            cctv_info["area_percent"],
            cctv_info["source_frame"],
            cctv_info["seg_frame"],
            static_info["cctv_latitude"],
            static_info["cctv_longitude"],
            "Y" if source_data["normal_to_failure_mode_change_alert"] else "N",
            "Y" if source_data["failure_to_normal_mode_change_alert"] else "N",
        )

        conn = psycopg2.connect(**db_config)
        try:
            with conn.cursor() as cursor:
                cursor.execute("SET search_path TO ai_camera_v0_1;")
                cursor.execute(insert_sql, data_tuple)
            conn.commit()
        finally:
            conn.close()

        print(f"EVENT: Sent for {source} - Data inserted successfully.")

        # Reset the image data after sending to avoid re-sending the same image.
        source_data["most_recent_image"] = None
        source_data["most_recent_seg_image"] = None
        source_data["last_send_before"] = 0


# Shared status memory for every CCTV source served by this process.
memory = StreamSources(
    buffer_size=15,                  # statuses kept per source
    normal_send_interval=10,         # periodic-send spacing in normal mode
    failure_mode_thres=8,            # fails needed ...
    failure_mode_check_past_n=12,    # ... within this many recent frames
    normal_mode_thres=8,             # OKs needed ...
    normal_mode_check_past_n=12,     # ... within this many recent frames
)


def get_base64_encoded_image_from_file_binary(image):
    """Re-encode raw image-file bytes as a base64 JPEG string.

    Args:
        image: raw bytes of an image file in any format OpenCV can decode.

    Returns:
        The image re-encoded as JPEG, then base64-encoded, as a ``str``.

    Raises:
        ValueError: if ``image`` is empty, cannot be decoded, or cannot be
            re-encoded as JPEG.
    """
    raw = np.frombuffer(image, np.uint8)
    if raw.size == 0:
        # cv2.imdecode rejects an empty buffer with an opaque cv2.error.
        raise ValueError("Empty image data")
    decoded = cv2.imdecode(raw, cv2.IMREAD_COLOR)
    if decoded is None:
        # imdecode signals undecodable data by returning None, not by raising.
        raise ValueError("Could not decode image data")
    ok, jpeg = cv2.imencode('.jpg', decoded)
    if not ok:
        raise ValueError("Could not encode image as JPEG")
    return base64.b64encode(jpeg.tobytes()).decode('utf-8')


@ns.route('/postprocess', )
class PostProcesser(Resource):
    """POST endpoint that ingests one inference result per request.

    Reads CCTV metadata from request headers and the source image (plus the
    segmentation image when a detection is flagged) from multipart files, then
    registers the CCTV in the shared ``memory`` and records a pass/fail status.
    """

    def __init__(self, *args, **kargs):
        super().__init__(*args, **kargs)
        # Per-request state, filled in by post().
        self.time_sent = None
        self.cctv_latitude = None
        self.cctv_longitude = None
        self.cctv_name = None
        self.cctv_info = None
        self.mask = None
        self.mask_blob = None
        self.image = None
        self.image_type = None
        self.seg_image = None
        self.area_percent = 0
        self.detected = False

    @ns.response(200, 'Success')
    @ns.response(400, 'Validation Error')
    def post(self):
        """Validate the request, then update ``memory`` with this frame's status.

        Headers read: ``x-cctv-name`` (base64 of a UTF-8 name), ``x-time-sent``
        (``%Y-%m-%dT%H:%M:%SZ``), ``X-Flag-Detected`` ('True'/'False'),
        ``x-area-percentage`` (float), and optionally ``x-cctv-latitude`` /
        ``x-cctv-longitude``. Files read: ``image`` always, ``seg_image`` when
        a detection is flagged.

        Malformed input is answered with HTTP 400, as documented above, rather
        than an unhandled exception (HTTP 500).
        """
        try:
            self._read_headers()
            self._read_files()
            image_b64 = get_base64_encoded_image_from_file_binary(self.image)
            seg_image_b64 = (get_base64_encoded_image_from_file_binary(self.seg_image)
                             if self.detected else None)
        except ValueError as err:
            ns.abort(400, str(err))

        pass_fail = self.pass_fail()
        self.cctv_info = {
            'cctv_name': self.cctv_name,
            'cctv_latitude': self.cctv_latitude,
            'cctv_longitude': self.cctv_longitude,
            'source_frame': image_b64,
            'seg_frame': seg_image_b64,
            'time_sent': self.time_sent,
            'area_percent': self.area_percent,
            'detected': pass_fail,
        }

        # Register the CCTV on its first frame; StreamSources raises KeyError
        # for an already-registered source, whose first-frame info is kept.
        try:
            memory[self.cctv_name] = self.cctv_info
        except KeyError:
            pass

        memory.add_status(self.cctv_name, pass_fail, self.cctv_info)

    def _read_headers(self):
        """Parse and validate the metadata headers into instance attributes.

        Raises:
            ValueError: on a malformed name, timestamp, detection flag or area
                percentage.
        """
        self.image_type = request.headers.get('Content-Type')
        try:
            # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
            self.cctv_name = base64.b64decode(request.headers.get('x-cctv-name', '')).decode('UTF-8')
        except ValueError as err:
            raise ValueError(f"Invalid value for x-cctv-name: {err}") from err

        raw_time = request.headers.get('x-time-sent', '')
        try:
            self.time_sent = datetime.datetime.strptime(raw_time, '%Y-%m-%dT%H:%M:%SZ')
        except ValueError as err:
            raise ValueError(f"Invalid value for x-time-sent: {raw_time!r}") from err

        self.cctv_latitude = request.headers.get('x-cctv-latitude', 'Not provided')
        self.cctv_longitude = request.headers.get('x-cctv-longitude', 'Not provided')

        raw_detected = request.headers.get('X-Flag-Detected')
        flag_values = {"True": True, "False": False}
        if raw_detected not in flag_values:
            raise ValueError(f"Invalid value for x-flag-detected: {raw_detected}")
        self.detected = flag_values[raw_detected]

        raw_area = request.headers.get('x-area-percentage')
        try:
            self.area_percent = float(raw_area)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Invalid value for x-area-percentage: {raw_area}") from err

    def _read_files(self):
        """Read the uploaded ``image`` bytes, plus ``seg_image`` when a detection is flagged.

        Raises:
            ValueError: if a required multipart file is missing.
        """
        # A missing file raises werkzeug's BadRequestKeyError, a KeyError subclass.
        try:
            self.image = request.files['image'].read()
        except KeyError as err:
            raise ValueError("Error reading 'image'") from err

        if self.detected:
            try:
                self.seg_image = request.files['seg_image'].read()
            except KeyError as err:
                raise ValueError("Error reading 'seg_image'") from err

    def pass_fail(self, thres=0.1):
        """Return 'Y' (fail) if ``self.area_percent`` exceeds ``thres``, else 'N' (OK)."""
        # TODO: temporary pass/fail threshold.
        return 'Y' if self.area_percent > thres else 'N'


if __name__ == "__main__":
    print("Postprocess Online")
    # Werkzeug's interactive debugger allows code execution from the browser,
    # so follow the module-level `debug` switch instead of forcing it on.
    app.run(debug=debug, port=13579)