import base64
import os
import time
from datetime import datetime, timezone

import cv2
import numpy as np
import requests
from flask import Flask, request
from flask_restx import Api, Resource, fields
from requests_toolbelt import MultipartEncoder

from config_files import API_ENDPOINT_MAIN
from yoloseg.inference_ import Inference, overlay_mask

app = Flask(__name__)
api = Api(app, version='1.0', title='CCTV Image Upload API',
          description='A simple API for receiving CCTV images')

# Namespace definition; the resources below are routed under 'cctv'.
ns = api.namespace('cctv', description='CCTV operations')

# Model weights and class-label file. These are relative paths, presumably
# resolved against the process working directory — run from the repo root.
model_path = 'yoloseg/weight/best.onnx'
classes_txt_file = 'config_files/yolo_config.txt'
# NOTE(review): not referenced in this file (looks like a local test image);
# confirm nothing imports it before removing.
image_path = 'yoloseg/img3.jpg'

model_input_shape = (640, 640)
# Built once at import time and shared by every request handled below.
inference_engine = Inference(
    onnx_model_path=model_path,
    model_input_shape=model_input_shape,
    classes_txt_file=classes_txt_file,
    run_with_cuda=True
)

# Swagger description of an upload. NOTE(review): this model is not attached
# to any route, and ImageUpload.post actually reads the 'file' form part and
# the 'x-cctv-name' header — align these fields before enabling validation.
image_upload_model = api.model('ImageUpload', {
    'image': fields.String(required=True, description='Image file', dt='File'),
    'x-cctv-info': fields.String(required=False, description='CCTV identifier'),
    'x-time-sent': fields.String(required=False, description='Time image was sent'),
    'x-cctv-latitude': fields.String(required=False, description='Latitude of CCTV'),
    'x-cctv-longitude': fields.String(required=False, description='Longitude of CCTV')
})

# Directory where images will be saved. exist_ok avoids the race between
# checking for the directory and creating it.
IMAGE_DIR = "network_test"
os.makedirs(IMAGE_DIR, exist_ok=True)

@ns.route('/infer')
class ImageUpload(Resource):
    """Receive a CCTV frame, segment it and forward the result.

    ``POST`` expects a multipart upload whose ``file`` part holds the encoded
    frame, plus optional headers: ``x-cctv-name`` (base64-encoded UTF-8),
    ``x-time-sent``, ``x-cctv-latitude`` and ``x-cctv-longitude``. The raw
    frame and a PNG of the predicted mask are then pushed to ``endpoint``.
    """

    # Analyzer server that receives detection results.
    endpoint = API_ENDPOINT_MAIN
    # Zone for the outgoing 'x-time-sent' header; the value is written with a
    # trailing 'Z', so this must stay UTC.
    time_zone = timezone.utc
    # Seconds to wait on the analyzer server before giving up.
    send_timeout = 10

    def __init__(self, api=None, *args, **kwargs):
        # flask-restx passes the Api instance (and any resource_class_args)
        # to the constructor, so they must be accepted and forwarded.
        super().__init__(api, *args, **kwargs)
        self.time_sent = None
        self.cctv_latitude = None
        self.cctv_longitude = None
        self.cctv_info = None
        self.mask = None
        self.mask_blob = None
        self.image = None       # raw encoded bytes of the uploaded frame
        self.image_type = None  # image subtype, e.g. 'jpeg' or 'png'
        self.area_percent = 0

    @staticmethod
    def _image_subtype(upload):
        """Return the image subtype (e.g. ``'jpeg'``) of an uploaded file.

        Uses the part's own MIME type when it is ``image/*``. Otherwise it
        falls back to the filename extension, then to ``'jpeg'``.
        """
        mimetype = upload.mimetype or ''
        if mimetype.startswith('image/'):
            return mimetype.split('/', 1)[1]
        extension = os.path.splitext(upload.filename or '')[1].lstrip('.').lower()
        # 'image/jpg' is not a registered MIME type; normalise to 'jpeg'.
        return {'jpg': 'jpeg'}.get(extension, extension) or 'jpeg'

    @staticmethod
    def _combine_masks(masks, shape):
        """Return the union of ``masks`` as a uint8 image (255 = masked).

        ``shape`` is the (height, width) used when there are no masks.
        NOTE(review): assumes ``masks`` is a sequence of equally sized 2-D
        masks, as implied by ``overlay_mask(image, mask[0])`` — confirm.
        """
        if masks is None or len(masks) == 0:
            return np.zeros(shape, dtype=np.uint8)
        stacked = np.asarray(masks)
        union = np.any(stacked != 0, axis=0) if stacked.ndim > 2 else stacked != 0
        return union.astype(np.uint8) * 255

    @ns.response(200, 'Success')
    @ns.response(400, 'Validation Error')
    def post(self):
        """Run segmentation on the uploaded frame and forward the result.

        Returns:
            A JSON message confirming the upload.

        Raises:
            HTTPException (400): if the ``file`` part is missing or empty,
                ``x-cctv-name`` is not valid base64-encoded UTF-8, or the
                upload cannot be decoded as an image.
        """
        upload = request.files.get('file')
        if upload is None:
            ns.abort(400, 'No image part in the request')

        # Keep the bytes rather than the FileStorage: its stream is exhausted
        # after read(), so forwarding it later would send an empty part.
        self.image = upload.read()
        if not self.image:
            ns.abort(400, 'Uploaded image is empty')
        self.image_type = self._image_subtype(upload)

        try:
            self.cctv_info = base64.b64decode(request.headers.get('x-cctv-name', '')).decode('UTF-8')
        except ValueError:  # binascii.Error and UnicodeDecodeError
            ns.abort(400, 'Header x-cctv-name must be base64-encoded UTF-8')
        self.time_sent = request.headers.get('x-time-sent', '')
        self.cctv_latitude = request.headers.get('x-cctv-latitude', 'Not provided')
        self.cctv_longitude = request.headers.get('x-cctv-longitude', 'Not provided')

        image = cv2.imdecode(np.frombuffer(self.image, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            ns.abort(400, 'Uploaded file is not a decodable image')

        t1 = time.perf_counter()
        _detections, self.mask = inference_engine.run_inference(image)
        t2 = time.perf_counter()
        print(t2 - t1)

        combined = self._combine_masks(self.mask, image.shape[:2])
        # Share of mask pixels covered by any detection, as a percentage.
        self.area_percent = 100.0 * np.count_nonzero(combined) / combined.size
        ok, encoded = cv2.imencode('.png', combined)
        self.mask_blob = encoded.tobytes() if ok else b''

        self.send_result()
        return {"message": f"Image {self.cctv_info} uploaded successfully!"}

    def send_result(self):
        """Push the frame, its mask and metadata to the analyzer server.

        This is best effort: connection errors and non-2xx replies are
        printed and do not fail the incoming request.
        """
        time_sent = datetime.now(self.time_zone).strftime('%Y-%m-%dT%H:%M:%SZ')
        header = {
            'x-time-sent': time_sent,
            'x-cctv-name': base64.b64encode(str(self.cctv_info).encode('utf-8')).decode('ascii'),
            'x-cctv-latitude': str(self.cctv_latitude),
            'x-cctv-longitude': str(self.cctv_longitude),
            'x-area-percentage': str(self.area_percent),
        }
        multipart_data = MultipartEncoder(
            fields={
                'image': (
                    f'frame_{self.cctv_info}.{self.image_type}',
                    self.image,
                    f'image/{self.image_type}',
                ),
                'mask': (
                    f'frame_mask_{self.cctv_info}.png',
                    self.mask_blob,
                    'image/png',
                ),
            }
        )
        # The multipart boundary lives in this header, so take it from the encoder.
        header['Content-Type'] = multipart_data.content_type

        try:
            with requests.Session() as session:
                response = session.post(self.endpoint, headers=header,
                                        data=multipart_data, timeout=self.send_timeout)
        except requests.RequestException as e:
            print(e)
            print("Can not connect to the analyzer server. Check the endpoint address or connection.\n"
                  f"Can not connect to : {self.endpoint}")
            return
        if not response.ok:
            print(f"Analyzer server {self.endpoint} responded with HTTP {response.status_code}")


if __name__ == '__main__':
    # Development server only. With debug on, Werkzeug's reloader re-imports
    # this module in a child process, so the model above is loaded twice.
    dev_server_options = {'port': 12345, 'debug': True}
    app.run(**dev_server_options)
