import time

import cv2
import numpy as np
import random
import onnxruntime as ort
from config_files.yolo_config import CLASS_NUM
from typing import List, Tuple


class Inference:
    """Run a YOLO-style instance-segmentation model exported to ONNX.

    ``run_inference`` returns a list of detection dicts (box, score, class,
    mask coefficients, drawing colour) and a merged binary mask of every kept
    detection.
    """

    def __init__(self, onnx_model_path, model_input_shape, classes_txt_file, run_with_cuda):
        """Load the ONNX model and the class-name list.

        Args:
            onnx_model_path: path of the ``.onnx`` model file.
            model_input_shape: network input size, passed as the ``size`` argument of
                ``cv2.dnn.blobFromImage`` (OpenCV orders it as (width, height)).
            classes_txt_file: text file with one class name per line.
            run_with_cuda: use the CUDA execution provider (device 0) instead of the CPU.
        """
        self.model_path = onnx_model_path
        self.model_shape = model_input_shape
        self.classes_path = classes_txt_file
        self.cuda_enabled = run_with_cuda
        # When the network input is square, zero-pad the image to a square (bottom/right)
        # before resizing so its aspect ratio is preserved.
        self.letter_box_for_square = True
        self.model_score_threshold = 0.3
        self.model_nms_threshold = 0.6
        # Candidates narrower or shorter than this many network-input pixels are dropped before NMS.
        self.min_box_size = 20
        self.classes = []
        self.session = None

        self.load_onnx_network()
        self.load_classes_from_file()

    def sigmoid(self, x):
        """Element-wise logistic function 1 / (1 + exp(-x))."""
        return 1 / (1 + np.exp(-x))

    def run_inference(self, input_image):
        """Run the model on one image.

        Args:
            input_image: image array fed to ``cv2.dnn.blobFromImage`` (channels are swapped
                RB -> the blob is built with ``swapRB=True``).

        Returns:
            ``(detections, mask_maps)``: the list produced by ``process_detections`` and either
            ``[]`` (no detections) or a uint8 ``(H, W, 1)`` mask the size of ``input_image``.

        Raises:
            RuntimeError: if the ONNX session could not be created in ``load_onnx_network``.
        """
        if self.session is None:
            raise RuntimeError(f"ONNX Runtime session is not initialised; loading '{self.model_path}' failed.")

        model_input = input_image
        if self.letter_box_for_square and self.model_shape[0] == self.model_shape[1]:
            model_input = self.format_to_square(model_input)

        blob = cv2.dnn.blobFromImage(model_input, 1.0 / 255.0, self.model_shape, (0, 0, 0), True, False)

        input_name = self.session.get_inputs()[0].name
        outputs = self.session.run(None, {input_name: blob})
        outputs_bbox, outputs_mask = outputs[0], outputs[1]

        detections = self.process_detections(outputs_bbox, model_input)
        mask_maps = self.process_mask_output(detections, outputs_mask, model_input.shape)

        # format_to_square pads only at the bottom/right, so the original image sits in the
        # top-left corner: crop the padding off so the mask matches input_image.
        if len(mask_maps) != 0:
            mask_maps = mask_maps[:input_image.shape[0], :input_image.shape[1]]
        return detections, mask_maps

    def load_onnx_network(self):
        """Create the ONNX Runtime session for ``self.model_path``.

        Uses CUDA device 0 when ``self.cuda_enabled`` is true, otherwise the CPU provider.
        On failure the error is printed and ``self.session`` stays ``None`` (``run_inference``
        then raises ``RuntimeError``).
        """
        try:
            if self.cuda_enabled:
                providers = [('CUDAExecutionProvider', {'device_id': 0})]
            else:
                providers = ['CPUExecutionProvider']

            self.session = ort.InferenceSession(self.model_path, providers=providers)
            print(f"Running on {'CUDA' if self.cuda_enabled else 'CPU'}")
            print(f"Model loaded successfully. Input name: {self.session.get_inputs()[0].name}")
        except Exception as e:
            print(f"Failed to load the ONNX model: {e}")
            self.session = None

    def load_classes_from_file(self):
        """Read one class name per line from ``self.classes_path`` into ``self.classes``."""
        with open(self.classes_path, 'r', encoding='utf-8') as f:
            self.classes = f.read().strip().split('\n')

    def format_to_square(self, source):
        """Zero-pad ``source`` at the bottom/right into a square of side max(height, width).

        The channel layout and dtype of ``source`` are preserved.
        """
        row, col = source.shape[0], source.shape[1]
        max_side = max(col, row)
        result = np.zeros((max_side, max_side) + source.shape[2:], dtype=source.dtype)
        result[0:row, 0:col] = source
        return result

    @staticmethod
    def _decode_raw_predictions(outputs_bbox, num_classes, x_factor, y_factor, min_size, score_threshold):
        """Filter raw candidates by size/score and convert them to pixel boxes.

        Assumes ``outputs_bbox`` is laid out as (batch, 4 + num_classes + num_protos, num_anchors)
        with the first four channels being (center_x, center_y, width, height) in network-input
        pixels — TODO confirm for other exports.

        Returns:
            ``(boxes, confidences, class_ids, mask_coefficients)`` with K rows each, in the same
            (batch-major) order; ``boxes`` is an int ``(K, 4)`` array of (left, top, width, height)
            scaled by ``x_factor`` / ``y_factor``.
        """
        cx, cy = outputs_bbox[:, 0], outputs_bbox[:, 1]
        w, h = outputs_bbox[:, 2], outputs_bbox[:, 3]
        scores = outputs_bbox[:, 4:4 + num_classes]
        confidences = np.max(scores, axis=1)
        class_ids = np.argmax(scores, axis=1)

        valid = (w >= min_size) & (h >= min_size) & (confidences > score_threshold)

        # (batch, protos, anchors) -> (batch, anchors, protos): the (batch, anchors) boolean mask then
        # selects rows in the same batch-major order as every per-anchor array above.
        mask_coefficients = np.transpose(outputs_bbox[:, 4 + num_classes:], (0, 2, 1))[valid]

        cx, cy, w, h = cx[valid], cy[valid], w[valid], h[valid]
        left = (cx - 0.5 * w) * x_factor
        top = (cy - 0.5 * h) * y_factor
        # Integer boxes are needed for the NMS call and for slicing masks.
        boxes = np.stack([left, top, w * x_factor, h * y_factor], axis=1).astype(int)
        return boxes, confidences[valid], class_ids[valid], mask_coefficients

    def _class_name(self, class_id):
        """Return the label for ``class_id``, or its string form if the classes file lacks it."""
        if 0 <= class_id < len(self.classes):
            return self.classes[class_id]
        return str(class_id)

    def process_detections(self, outputs_bbox, model_input):
        """Decode the detection output into NMS-filtered detection dicts.

        Args:
            outputs_bbox: first model output (see ``_decode_raw_predictions`` for the assumed layout).
            model_input: the (possibly letterboxed) image given to the network; boxes are scaled
                from network-input pixels to this image's pixels.

        Returns:
            list of dicts with keys 'class_id', 'confidence', 'mask_coefficients',
            'box' ([left, top, width, height] ints), 'class_name' and 'color'.
        """
        # model_shape is OpenCV's (width, height), so index 0 pairs with image columns.
        x_factor = model_input.shape[1] / self.model_shape[0]
        y_factor = model_input.shape[0] / self.model_shape[1]

        boxes, confidences, class_ids, mask_coefficients = self._decode_raw_predictions(
            outputs_bbox, CLASS_NUM, x_factor, y_factor, self.min_box_size, self.model_score_threshold)

        if len(boxes) == 0:
            return []

        box_list = boxes.tolist()
        confidence_list = confidences.tolist()
        indices = cv2.dnn.NMSBoxes(box_list, confidence_list, self.model_score_threshold, self.model_nms_threshold)
        # Older OpenCV returns an (N, 1) array, newer a flat array (or an empty tuple): normalise.
        keep = np.asarray(indices, dtype=int).reshape(-1)

        detections = []
        for i in keep:
            class_id = int(class_ids[i])
            detections.append({
                'class_id': class_id,
                'confidence': confidence_list[i],
                'mask_coefficients': np.array(mask_coefficients[i]),
                'box': box_list[i],
                'class_name': self._class_name(class_id),
                'color': (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255)),
            })
        return detections

    def process_mask_output(self, detections, proto_masks, image_shape):
        """Merge the instance masks of all detections into one binary mask.

        Args:
            detections: output of ``process_detections``.
            proto_masks: second model output, unpacked as (batch, num_protos, proto_h, proto_w).
                Only the first batch entry is used; batched inputs are not supported yet.
            image_shape: shape of the image the boxes refer to; only (height, width) are used.

        Returns:
            ``[]`` when ``detections`` is empty, otherwise a uint8 ``(H, W, 1)`` array holding 1
            wherever any detection's mask is set inside its (clamped) box.
        """
        if not detections:
            return []

        img_h, img_w = image_shape[0], image_shape[1]
        protos = proto_masks[0]
        merged = np.zeros((img_h, img_w), dtype=np.uint8)

        for det in detections:
            x1, y1, w, h = self.adjust_box_coordinates(det['box'], (img_h, img_w))
            if w <= 1 or h <= 1:
                continue

            # Linear combination of the prototypes: (num_protos,) . (num_protos, ph, pw) -> (ph, pw).
            logits = np.tensordot(det['mask_coefficients'], protos, axes=[0, 0])
            # cv2.resize takes dsize as (width, height).
            logits = cv2.resize(logits, (img_w, img_h))
            # sigmoid(x) > 0.5 is equivalent to x > 0, so threshold the logits directly
            # (avoids an exp per pixel and overflow warnings).
            inside = logits[y1:y1 + h, x1:x1 + w] > 0
            merged[y1:y1 + h, x1:x1 + w][inside] = 1

        # Single-channel layout so cv2 treats the result as an image. This application has a
        # single class; multi-class output would need one channel per class here.
        return merged.reshape((img_h, img_w, 1))

    def adjust_box_coordinates(self, box: List[int], image_shape: Tuple[int, int]) -> Tuple[int, int, int, int]:
        """Clamp a (left, top, width, height) box to an image of shape (height, width).

        The returned width/height may be zero or negative when the box lies fully outside
        the image; callers must check them.
        """
        x1, y1, w, h = box
        x2, y2 = x1 + w, y1 + h

        x1 = max(0, x1)
        y1 = max(0, y1)
        x2 = min(image_shape[1], x2)
        y2 = min(image_shape[0], y2)

        return x1, y1, x2 - x1, y2 - y1

def overlay_mask(image, mask, color=(0, 255, 0), alpha=0.5):
    """
    Blend a solid-colour overlay onto an image wherever the mask is non-zero.

    Parameters:
        image (np.ndarray): 3-channel uint8 image (the overlay is built as uint8 with a
            3-value colour, and cv2.addWeighted needs both inputs to match).
        mask (np.ndarray): (H, W) or (H, W, 1) mask with the same height/width as ``image``;
            any value > 0 is treated as foreground.
        color (tuple): overlay colour in BGR order (default is green).
        alpha (float): overlay weight in [0, 1]; 0 is fully transparent, 1 is opaque.

    Returns:
        np.ndarray: ``alpha * overlay + (1 - alpha) * image``. Outside the mask the overlay
        is black, so pixels there are scaled by ``1 - alpha``.

    Raises:
        ValueError: if ``alpha`` is outside [0, 1] (or NaN).
    """
    # Explicit check instead of assert, so validation survives `python -O`.
    if not 0 <= alpha <= 1:
        raise ValueError(f"Error! invalid alpha value, it must be float, inbetween including 0 to 1, "
                         f"\n given alpha : {alpha}")

    mask = np.asarray(mask)
    # Accept a trailing singleton channel, e.g. the (H, W, 1) mask Inference.process_mask_output returns.
    if mask.ndim == 3 and mask.shape[2] == 1:
        mask = mask[:, :, 0]
    foreground = mask > 0

    overlay = np.zeros_like(image, dtype=np.uint8)
    overlay[foreground] = color

    return cv2.addWeighted(src1=overlay, alpha=alpha, src2=image, beta=1 - alpha, gamma=0)


def test(image_glob="/home/juni/사진/flood/out-/*.jpg", output_dir="/home/juni/사진/flood/infer"):
    """Manual smoke test: segment a folder of images and save the overlays.

    Loads the model from the project paths below and checks that the sample image is
    readable. Then, for every file matching ``image_glob`` (sorted), it resizes the image to
    the model input size, draws boxes/labels and the mask overlay, and writes
    ``<output_dir>/<index>.jpg``. Per-image processing time is printed.
    """
    import time
    import glob
    import os

    model_path = 'yoloseg/weight/best.onnx'
    classes_txt_file = 'config_files/yolo_config.txt'
    image_path = 'yoloseg/img.jpg'

    model_input_shape = (480, 480)
    inference_engine = Inference(
        onnx_model_path=model_path,
        model_input_shape=model_input_shape,
        classes_txt_file=classes_txt_file,
        run_with_cuda=True
    )

    # Sanity check that the sample image is readable before processing the folder.
    if cv2.imread(image_path) is None:
        print("Error loading image_binary")
        return

    # cv2.imwrite fails silently (returns False) when the directory is missing.
    os.makedirs(output_dir, exist_ok=True)

    for k, image_file in enumerate(sorted(glob.glob(image_glob))):
        image = cv2.imread(image_file)
        if image is None:
            print(f"Skipping unreadable image: {image_file}")
            continue

        t1 = time.time()
        image = cv2.resize(image, model_input_shape)
        detections, mask_maps = inference_engine.run_inference(image)

        for detection in detections:
            x, y, w, h = detection['box']
            cv2.rectangle(image, (x, y), (x + w, y + h), detection['color'], 2)
            label = f"{detection['class_name']}: {detection['confidence']:.2f}"
            cv2.putText(image, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, detection['color'], 2)

        # Start from this frame's annotated image so frames without masks are still written
        # and never reuse an earlier frame's overlay; overlay each mask channel in turn.
        seg_image = image
        if len(mask_maps) != 0:
            for i in range(mask_maps.shape[2]):
                seg_image = overlay_mask(seg_image, mask_maps[:, :, i], color=(0, 255, 0), alpha=0.3)

        t2 = time.time()
        print(t2 - t1)
        cv2.imwrite(os.path.join(output_dir, f"{k}.jpg"), seg_image)


if __name__ == "__main__":
    test()