import base64
import os
import threading
import time
import typing
from datetime import datetime

import cv2
import numpy as np
import requests
from flask import Flask, request, jsonify
from flask_restx import Api, Resource, fields
from requests_toolbelt import MultipartEncoder

from yoloseg.inference_ import Inference, overlay_mask

# from config_files import API_ENDPOINT_MAIN

app = Flask(__name__)
api = Api(app, version='1.0', title='CCTV Image Upload API',
          description='A postprocessing and adaptive rate mainserver data pusher')

# Swagger model documenting the inputs consumed by the /postprocess endpoint.
# NOTE(review): the endpoint reads the x-* / X-* values from request HEADERS and
# the three files from the multipart body; this model is not attached with
# ``@ns.expect`` in this file, so it serves as documentation only.
image_upload_model = api.model('ImageUpload', {
    'image': fields.String(required=True, description='Image file', dt='File'),
    'mask': fields.String(required=True, description='Mask file', dt='File'),
    'seg_mask': fields.String(required=True, description='Segmentation overlay file', dt='File'),
    # Kept for backward compatibility; the endpoint identifies sources via x-cctv-name.
    'x-cctv-info': fields.String(required=False, description='CCTV identifier'),
    'x-cctv-name': fields.String(required=False, description='Base64-encoded (UTF-8) CCTV name'),
    'x-time-sent': fields.String(required=False, description='Time image was sent'),
    'x-cctv-latitude': fields.String(required=False, description='Latitude of CCTV'),
    'x-cctv-longitude': fields.String(required=False, description='Longitude of CCTV'),
    'x-area-percentage': fields.String(required=True, description='Detected area percentage (parsed as float)'),
    'X-Flag-Detected': fields.String(required=True, description='If detected ("True" or "False")'),
})

# Namespace definition
ns = api.namespace('postprocess', description='Postprocessing of inference results')

class StreamSources:
    """Track OK/FAIL history per source and switch between normal and force-send mode.

    Every registered source keeps a bounded history of ``"OK"``/``"FAIL"``
    statuses (at most ``buffer_size`` entries; the oldest entry is dropped).

    * A source enters force-send mode when at least ``failure_mode_thres``
      of its last ``failure_mode_check_past_n`` statuses are ``"FAIL"``.
    * While in force-send mode, it returns to normal mode when at least
      ``normal_mode_thres`` of its last ``normal_mode_check_past_n``
      statuses are ``"OK"``.

    When a threshold equals its window size the rule becomes "N consecutive
    statuses", which is what the ``switching_*_consecutive_mode`` flags record.

    Args:
        buffer_size: Maximum number of statuses retained per source.
        normal_send_interval: Stored on the instance; not used by this class.
        failure_mode_thres: FAILs required inside the failure window.
        failure_mode_check_past_n: Size of the failure window.
        normal_mode_thres: OKs required inside the normal window.
        normal_mode_check_past_n: Size of the normal window.

    Raises:
        ValueError: If a threshold exceeds its window, or a window exceeds
            ``buffer_size``.
    """

    def __init__(self, buffer_size, normal_send_interval, failure_mode_thres, failure_mode_check_past_n, normal_mode_thres, normal_mode_check_past_n):
        # Raise instead of ``assert`` so the validation survives ``python -O``.
        if failure_mode_thres > failure_mode_check_past_n:
            raise ValueError(
                f"failure_mode checker condition is invalid! failure_mode needs {failure_mode_thres} "
                f"FAILs in the last {failure_mode_check_past_n} statuses, which is not possible!")
        if normal_mode_thres > normal_mode_check_past_n:
            raise ValueError(
                f"normal_mode checker condition is invalid! normal_mode needs {normal_mode_thres} "
                f"OKs in the last {normal_mode_check_past_n} statuses, which is not possible!")
        if buffer_size < failure_mode_check_past_n:
            raise ValueError(
                "'buffer_size' is smaller than failure_mode_check_past_n! The program could never "
                f"enter failure mode.\n buffer_size : {buffer_size}\n "
                f"failure_mode_check_past_n : {failure_mode_check_past_n}")
        if buffer_size < normal_mode_check_past_n:
            raise ValueError(
                "'buffer_size' is smaller than normal_mode_check_past_n! The program could never "
                f"revert back to normal mode.\n buffer_size : {buffer_size}\n "
                f"normal_mode_check_past_n : {normal_mode_check_past_n}")

        self.sources = {}
        self.buffer_size = buffer_size
        self.normal_send_interval = normal_send_interval  # stored only; unused here

        # A threshold equal to its window means "N consecutive" statuses.
        self.switching_fail_consecutive_mode = failure_mode_thres == failure_mode_check_past_n
        self.switching_normal_consecutive_mode = normal_mode_thres == normal_mode_check_past_n

        self.failure_mode_thres = failure_mode_thres
        self.failure_mode_check_past_n = failure_mode_check_past_n
        self.normal_mode_thres = normal_mode_thres
        self.normal_mode_check_past_n = normal_mode_check_past_n

    def __setitem__(self, key, value):
        """Register a new source ``key``; ``value`` is stored under ``"cctv_name"``.

        Raises:
            KeyError: If ``key`` is already registered.
        """
        if key not in self.sources:
            self.sources[key] = {
                "status_counts": [],
                "ok_counts": 0,
                "force_send_mode": False,
                "most_recent_image" : None,
                "most_recent_mask" : None,
                # NOTE: key spelling kept as-is for compatibility with existing readers.
                "most_recent_seg_iamge" : None,
                "cctv_name" : value,
            }
        else :
            raise KeyError(f"Error! Source {key} already initialized.")

    def __getitem__(self, key):
        """Return the state dict of a registered source (KeyError if unknown)."""
        return self.sources[key]

    def __contains__(self, key):
        """Return True if ``key`` has been registered as a source."""
        return key in self.sources

    def _count_recent(self, source, status, window):
        """Return how many of the last ``window`` statuses of ``source`` equal ``status``."""
        history = self.sources[source]["status_counts"]
        # max(..., 0) keeps a window of 0 empty instead of slicing the whole list.
        return history[max(len(history) - window, 0):].count(status)

    def add_status(self, source, status):
        """Append ``status`` to ``source``'s history and apply mode transitions.

        Raises:
            ValueError: If ``status`` is not "OK"/"FAIL" or ``source`` is unknown.
        """
        if status not in ("OK", "FAIL"):
            raise ValueError(f"Invalid status was given!, status must be one of 'OK' or 'FAIL', but given '{status}'!")

        if source not in self.sources:
            raise ValueError(f"No key found for source. Did you forgot to add it? \n source : {source}")

        entry = self.sources[source]
        history = entry["status_counts"]
        history.append(status)
        if len(history) > self.buffer_size:
            history.pop(0)

        if status == 'OK':
            entry["ok_counts"] += 1
            if (entry["force_send_mode"]
                    and self._count_recent(source, 'OK', self.normal_mode_check_past_n) >= self.normal_mode_thres):
                entry["force_send_mode"] = False
                self.send_message(source, "NORMAL SEND")
        else:
            entry["ok_counts"] = 0  # consecutive-OK counter restarts on every FAIL
            self.check_failures(source)

    def check_failures(self, source):
        """Enter force-send mode if the failure window holds enough FAILs.

        This is evaluated on every FAIL, so "FORCE SEND" is emitted again for
        each FAIL while the condition still holds.
        """
        window = self.failure_mode_check_past_n
        fails = self._count_recent(source, 'FAIL', window)
        if fails >= self.failure_mode_thres:
            print(f"Source {source} has {fails} FAILs in the last {window} statuses!")
            self.sources[source]["force_send_mode"] = True
            self.send_message(source, "FORCE SEND")

    def send_message(self, source, message_type):
        """Report a mode message for ``source`` (currently printed) and reset its OK count."""
        print(f"Sending message for {source} - Status: {message_type}")
        # Reset the count after sending message
        self.sources[source]["ok_counts"] = 0


@ns.route('/postprocess')
class PostProcesser(Resource):
    """POST endpoint that records an OK/FAIL status per CCTV source.

    Each request carries source metadata in headers plus three uploaded files
    (``image``, ``mask``, ``seg_mask``). The frame is classified with
    :meth:`pass_fail` and the result is pushed into :attr:`memory`, which
    decides when a source switches between normal and force-send mode.

    Flask creates a new view instance for every request, so the per-source
    history lives on the class (``memory``) rather than on ``self``.
    ``_memory_lock`` serializes updates to it because the development server
    can serve requests from several threads.
    """

    # Shared by every request/instance; see class docstring.
    memory = StreamSources(
        buffer_size=15,
        normal_send_interval=10,
        failure_mode_thres=8,
        failure_mode_check_past_n=12,
        normal_mode_thres=8,
        normal_mode_check_past_n=12,
    )
    _memory_lock = threading.Lock()
    # Directory the received files are dumped into for network testing.
    save_dir = "network_test"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-request state, filled in by post().
        self.time_sent = None
        self.cctv_latitude = None
        self.cctv_longitude = None
        self.cctv_name = None
        self.cctv_info = None
        self.mask = None
        self.mask_blob = None
        self.image = None
        self.image_type = None
        self.seg_image = None
        self.area_percent = 0
        self.detected = False

    @ns.response(200, 'Success')
    @ns.response(400, 'Validation Error')
    @ns.response(500, 'Internal Error')
    def post(self):
        """Validate the request, record the source's OK/FAIL status and report it.

        Returns:
            ``(payload, 200)`` with the source name, its status and whether it
            is in force-send mode; ``(payload, 400)`` when a header or file is
            missing/invalid; ``(payload, 500)`` on any other error.
        """
        try:
            self._parse_headers()
            self._load_files()
            self._dump_files()

            # NOTE(review): replaces the x-time-sent header value with the server
            # receive time (as the original code did) — confirm which is wanted.
            self.time_sent = time.time()

            self.cctv_info = {
                'cctv_name': self.cctv_name,
                'cctv_latitude': self.cctv_latitude,
                'cctv_longitude': self.cctv_longitude,
                'source_frame': self.image,
                'frame_mask': self.mask,
                'seg_frame': self.seg_image,
                'time_sent': self.time_sent
            }

            status = self.pass_fail()
            with self._memory_lock:
                # StreamSources raises KeyError on re-registration, so only the
                # first request of a source registers it.
                if self.cctv_name not in self.memory.sources:
                    self.memory[self.cctv_name] = self.cctv_info
                self.memory.add_status(self.cctv_name, status)
                force_send_mode = self.memory[self.cctv_name]["force_send_mode"]

            return {
                'cctv_name': self.cctv_name,
                'status': status,
                'force_send_mode': force_send_mode,
            }, 200
        except ValueError as e:
            print(e)
            return {'message': str(e)}, 400
        except Exception as e:
            print(e)
            return {'message': 'Internal server error'}, 500

    def _parse_headers(self):
        """Populate request metadata from headers.

        Raises:
            ValueError: On invalid base64/UTF-8 in x-cctv-name, or an invalid
                X-Flag-Detected or x-area-percentage value.
        """
        self.image_type = request.headers.get('Content-Type')
        # binascii.Error and UnicodeDecodeError are both ValueError subclasses.
        self.cctv_name = base64.b64decode(request.headers.get('x-cctv-name', '')).decode('UTF-8')
        self.time_sent = request.headers.get('x-time-sent', '')
        self.cctv_latitude = request.headers.get('x-cctv-latitude', 'Not provided')
        self.cctv_longitude = request.headers.get('x-cctv-longitude', 'Not provided')

        detected = request.headers.get('X-Flag-Detected')
        if detected == "True":
            self.detected = True
        elif detected == "False":
            self.detected = False
        else:
            raise ValueError(f"Invalid value for x-flag-detected: {detected}")

        area = request.headers.get('x-area-percentage')
        try:
            self.area_percent = float(area)
        except (TypeError, ValueError) as e:
            raise ValueError(f"Invalid value for x-area-percentage: {area}") from e

    def _load_files(self):
        """Fetch the three required uploads, raising ValueError if any is missing."""
        self.image = request.files.get('image')
        self.mask = request.files.get('mask')
        self.seg_image = request.files.get('seg_mask')
        # Checked before any .save() so a missing file is reported as a 400.
        if not self.image or not self.mask or not self.seg_image:
            raise ValueError("Missing one or more required files: 'image', 'mask', 'seg_mask'")

    def _dump_files(self):
        """Save the uploads under ``save_dir`` with one shared timestamp suffix."""
        os.makedirs(self.save_dir, exist_ok=True)
        stamp = time.time()
        self.image.save(os.path.join(self.save_dir, f"image_p{stamp}.png"))
        self.mask.save(os.path.join(self.save_dir, f"mask_p{stamp}.png"))
        self.seg_image.save(os.path.join(self.save_dir, f"seg_p{stamp}.png"))

    def pass_fail(self, thres=0.1):
        """Classify the current frame: 'FAIL' if ``area_percent`` > ``thres``, else 'OK'."""
        # TODO: temporal pass_fail threshold
        return 'FAIL' if self.area_percent > thres else 'OK'


if __name__ == "__main__":
    # Start the Flask development server for the postprocess API on port 13579, debug on.
    _run_options = {"port": 13579, "debug": True}
    print("Postprocess Online")
    app.run(**_run_options)