import cv2
import numpy as np
import random
from config import CLASS_NAME, CLASS_NUM

class Inference:
    """Instance-segmentation inference for a YOLOv8-seg style ONNX model via OpenCV DNN.

    The network is expected to produce two outputs: a detection tensor laid out
    as ``B x [X, Y, W, H, C1..Cn, P1..Pk] x num_candidates`` and a prototype-mask
    tensor of shape ``B x k x proto_h x proto_w``. Only the first batch element
    is processed.
    """

    def __init__(self, onnx_model_path, model_input_shape, classes_txt_file, run_with_cuda):
        """Store the configuration, then load the network and the class names.

        Parameters:
            onnx_model_path: Path of the ONNX model passed to ``cv2.dnn.readNetFromONNX``.
            model_input_shape: ``(width, height)`` handed to ``cv2.dnn.blobFromImage``.
            classes_txt_file: Text file with one class name per line.
            run_with_cuda: Select the CUDA backend/target when truthy, CPU otherwise.
        """
        self.model_path = onnx_model_path
        self.model_shape = model_input_shape
        self.classes_path = classes_txt_file
        self.cuda_enabled = run_with_cuda
        self.letter_box_for_square = True
        self.model_score_threshold = 0.3
        self.model_nms_threshold = 0.5
        self.classes = []

        self.load_onnx_network()
        self.load_classes_from_file()

    def sigmoid(self, x):
        """Return the logistic function of ``x`` element-wise.

        Uses ``exp(-log(1 + exp(-x)))`` so that large negative inputs do not
        overflow inside ``np.exp`` the way ``1 / (1 + np.exp(-x))`` does.
        """
        return np.exp(-np.logaddexp(0.0, -x))

    def run_inference(self, input_image):
        """Run the network on one image and return ``(detections, mask_maps)``.

        When letterboxing is enabled and the model input is square, the image is
        first padded to a square (see ``format_to_square``). Boxes and masks are
        expressed in the coordinates of that padded image, so ``mask_maps`` has
        the padded image's height and width.
        """
        model_input = input_image
        if self.letter_box_for_square and self.model_shape[0] == self.model_shape[1]:
            model_input = self.format_to_square(model_input)

        blob = cv2.dnn.blobFromImage(model_input, 1.0 / 255.0, self.model_shape, (0, 0, 0), True, False)
        self.net.setInput(blob)

        outputs = self.net.forward(self.net.getUnconnectedOutLayersNames())
        outputs_bbox = outputs[0]
        outputs_mask = outputs[1]

        detections = self.process_detections(outputs_bbox, model_input)
        mask_maps = self.process_mask_output(detections, outputs_mask, model_input.shape)

        return detections, mask_maps

    def process_detections(self, outputs_bbox, model_input):
        """Decode the detection tensor into a list of detection dicts.

        Candidates whose best class score does not exceed
        ``model_score_threshold`` are dropped, the rest go through
        (class-agnostic) non-maximum suppression.

        Returns:
            A list of dicts with keys ``class_id``, ``confidence``,
            ``mask_coefficients``, ``box`` (``[left, top, width, height]`` in
            ``model_input`` pixels), ``class_name`` and ``color``.
        """
        # Scale factors from network-input pixels back to model_input pixels.
        x_factor = model_input.shape[1] / self.model_shape[0]
        y_factor = model_input.shape[0] / self.model_shape[1]

        # One row per candidate: [X, Y, W, H, C1..Cn, P1..Pk]
        # (C = per-class score, P = coefficient of each prototype mask).
        predictions = np.asarray(outputs_bbox[0]).T
        class_scores = predictions[:, 4:4 + CLASS_NUM]
        best_class = np.argmax(class_scores, axis=1)
        best_score = class_scores[np.arange(class_scores.shape[0]), best_class]

        keep = best_score > self.model_score_threshold
        if not np.any(keep):
            return []

        kept = predictions[keep]
        class_ids = best_class[keep].tolist()
        confidences = best_score[keep].astype(float).tolist()
        mask_coefficients = kept[:, 4 + CLASS_NUM:]

        # Centre/size -> top-left/size, truncated to integer pixels.
        cx, cy, w, h = (kept[:, i].astype(np.float64) for i in range(4))
        boxes = np.stack(
            [(cx - 0.5 * w) * x_factor, (cy - 0.5 * h) * y_factor, w * x_factor, h * y_factor],
            axis=1,
        ).astype(int).tolist()

        indices = cv2.dnn.NMSBoxes(boxes, confidences, self.model_score_threshold, self.model_nms_threshold)

        detections = []
        # Depending on the OpenCV version the indices come back flat, as an
        # Nx1 array, or as an empty tuple; flatten to handle all of them.
        for i in np.asarray(indices, dtype=int).reshape(-1):
            i = int(i)
            class_id = class_ids[i]
            # Fall back to the numeric id if the classes file is shorter than CLASS_NUM.
            class_name = self.classes[class_id] if class_id < len(self.classes) else str(class_id)
            detections.append({
                'class_id': class_id,
                'confidence': confidences[i],
                'mask_coefficients': np.array(mask_coefficients[i]),
                'box': boxes[i],
                'class_name': class_name,
                'color': (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255))
            })

        return detections

    @staticmethod
    def _clip_box(box, image_shape):
        """Convert ``[left, top, width, height]`` to ``(x1, y1, x2, y2)`` clipped to the image."""
        img_h, img_w = image_shape[0], image_shape[1]
        left, top, width, height = box
        x1 = min(max(left, 0), img_w)
        y1 = min(max(top, 0), img_h)
        x2 = min(max(left + width, 0), img_w)
        y2 = min(max(top + height, 0), img_h)
        return x1, y1, x2, y2

    def process_mask_output(self, detections, proto_masks, image_shape):
        """Build one binary mask per detection from the prototype masks.

        Parameters:
            detections: Output of ``process_detections``.
            proto_masks: Prototype tensor; ``proto_masks[0]`` is indexed as
                ``(num_prototypes, proto_h, proto_w)``.
            image_shape: Shape of the image the boxes refer to (``(height, width, ...)``).

        Returns:
            ``float32`` array of shape ``(len(detections), height, width)``
            holding 0/1 values; empty along the first axis when there are no
            detections.
        """
        img_h, img_w = image_shape[0], image_shape[1]
        full_masks = np.zeros((len(detections), img_h, img_w), dtype=np.float32)
        if not detections:
            return full_masks

        # Batched inference is not supported yet: only the first batch element
        # is used. Iterating over proto_masks is where that support would start.
        prototypes = proto_masks[0]

        for idx, det in enumerate(detections):
            # Boxes may stick out of the image (including negative origins).
            x1, y1, x2, y2 = self._clip_box(det['box'], image_shape)
            if x2 <= x1 or y2 <= y1:
                continue  # nothing of the box lies inside the image

            # Linear combination of the prototypes; the result covers the
            # whole network input, not just the box.
            mask_logits = np.tensordot(det["mask_coefficients"], prototypes, axes=[0, 0])

            # Upsample to image size first, then keep only the box region.
            mask_logits = cv2.resize(mask_logits, (img_w, img_h))
            box_probabilities = self.sigmoid(mask_logits[y1:y2, x1:x2])
            full_masks[idx, y1:y2, x1:x2] = box_probabilities > 0.5

        return full_masks

    def load_classes_from_file(self):
        """Read class names (one per line) from ``classes_path`` into ``self.classes``."""
        with open(self.classes_path, 'r') as f:
            # Inner blank lines are kept so line numbers keep matching class ids.
            self.classes = [line.strip() for line in f.read().strip().splitlines()]

    def load_onnx_network(self):
        """Load the ONNX model into ``self.net`` and select the CUDA or CPU backend."""
        self.net = cv2.dnn.readNetFromONNX(self.model_path)
        if self.cuda_enabled:
            print("\nRunning on CUDA")
            self.net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
            self.net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
        else:
            print("\nRunning on CPU")
            self.net.setPreferableBackend(cv2.dnn.DNN_BACKEND_OPENCV)
            self.net.setPreferableTarget(cv2.dnn.DNN_TARGET_CPU)

    def format_to_square(self, source):
        """Return ``source`` zero-padded at the bottom/right to a square ``uint8`` image.

        Any trailing (channel) dimensions of ``source`` are preserved, so both
        single-channel 2-D images and multi-channel images are accepted.
        """
        rows, cols = source.shape[0], source.shape[1]
        side = max(rows, cols)
        result = np.zeros((side, side) + tuple(source.shape[2:]), dtype=np.uint8)
        result[:rows, :cols] = source
        return result

def overlay_mask(image, mask, color=(0, 255, 0), alpha=0.5):
    """
    Overlays a mask onto an image using a specified color and transparency level.

    Only pixels where the mask is positive are tinted; every other pixel is
    returned unchanged. The input image is not modified.

    Parameters:
        image (np.ndarray): The original image.
        mask (np.ndarray): The mask to overlay. Must be the same size as the image.
        color (tuple): The color for the mask overlay in BGR format (default is green).
        alpha (float): Transparency factor for the mask; 0 is fully transparent, 1 is opaque.

    Returns:
        np.ndarray: A copy of the image with the overlay applied.
    """
    # Any positive value counts as "inside the mask".
    region = np.asarray(mask) > 0

    result = image.copy()
    if not region.any():
        return result

    # Blend in float to avoid integer wrap-around, and only inside the mask so
    # the rest of the image keeps its original brightness.
    tint = np.asarray(color, dtype=np.float32)
    blended = alpha * tint + (1.0 - alpha) * image[region].astype(np.float32)

    if np.issubdtype(image.dtype, np.integer):
        limits = np.iinfo(image.dtype)
        blended = np.clip(np.rint(blended), limits.min, limits.max)

    result[region] = blended.astype(image.dtype)
    return result


def test():
    """Run inference on a single sample image and show the first segmentation mask.

    Loads the model and image from hard-coded paths, prints the inference
    time in seconds, draws every detected box with its label, and, if at least
    one mask was produced, displays the first mask blended over the image.
    """
    import time

    # Path to your ONNX model and classes text file
    model_path = 'yoloseg/weight/best.onnx'
    classes_txt_file = 'yoloseg/config/classes.txt'
    image_path = 'yoloseg/img3.jpg'

    model_input_shape = (640, 640)
    inference_engine = Inference(
        onnx_model_path=model_path,
        model_input_shape=model_input_shape,
        classes_txt_file=classes_txt_file,
        run_with_cuda=True
    )

    # Load an image
    img = cv2.imread(image_path)
    if img is None:
        print("Error loading image_binary")
        return
    # Resizing to the model input keeps the image square, so the masks
    # returned by run_inference match the image size.
    img = cv2.resize(img, model_input_shape)

    # Run inference (perf_counter is monotonic, unlike time.time)
    t1 = time.perf_counter()
    detections, mask_maps = inference_engine.run_inference(img)
    t2 = time.perf_counter()

    print(t2-t1)

    # Draw boxes and labels
    for detection in detections:
        x, y, w, h = detection['box']
        class_name = detection['class_name']
        confidence = detection['confidence']
        cv2.rectangle(img, (x, y), (x+w, y+h), detection['color'], 2)
        label = f"{class_name}: {confidence:.2f}"
        cv2.putText(img, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, detection['color'], 2)

    # Show the first mask, if any. run_inference never returns None for the
    # masks, so check the length rather than comparing against None.
    if len(mask_maps) > 0:
        seg_image = overlay_mask(img, mask_maps[0], color=(0, 255, 0), alpha=0.3)
        cv2.imshow("segmentation", seg_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

def test2():
    """Time inference over every PNG in a hard-coded sample directory.

    Prints the inference time in seconds for each image. Images that cannot
    be read are reported and skipped so that one bad file does not abort the
    whole run.
    """
    import time
    import glob

    # Path to your ONNX model and classes text file
    model_path = 'yoloseg/weight/best.onnx'
    classes_txt_file = 'yoloseg/config/classes.txt'

    model_input_shape = (640, 640)
    inference_engine = Inference(
        onnx_model_path=model_path,
        model_input_shape=model_input_shape,
        classes_txt_file=classes_txt_file,
        run_with_cuda=True
    )

    # Sorted so that the iteration index maps to the same file on every run.
    image_paths = sorted(glob.glob("/home/juni/사진/sample_data/ex1/*.png"))

    for iteration, image_path in enumerate(image_paths):
        img = cv2.imread(image_path)
        if img is None:
            print("Error loading image_binary")
            continue
        img = cv2.resize(img, model_input_shape)

        # Run inference (perf_counter is monotonic, unlike time.time)
        t1 = time.perf_counter()
        detections, mask_maps = inference_engine.run_inference(img)
        t2 = time.perf_counter()

        print(t2-t1)

        # Optional visualisation, disabled while benchmarking:
        # for detection in detections:
        #     x, y, w, h = detection['box']
        #     class_name = detection['class_name']
        #     confidence = detection['confidence']
        #     cv2.rectangle(img, (x, y), (x+w, y+h), detection['color'], 2)
        #     label = f"{class_name}: {confidence:.2f}"
        #     cv2.putText(img, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, detection['color'], 2)
        #
        # if len(mask_maps) > 0 :
        #     seg_image = overlay_mask(img, mask_maps[0], color=(0, 255, 0), alpha=0.3)
        #     cv2.imwrite(f"result/{iteration}.png", seg_image)

# Script entry point: run the directory benchmark when executed directly.
if __name__ == "__main__":
    test2()