import time
import av
import cv2
import numpy as np
import requests
from zoneinfo import ZoneInfo
from datetime import datetime
from threading import Lock, Thread, Event


class FrameCapturer:
    '''
    Sample frames from an HLS video stream and POST them to an analyzer endpoint.

    Two background threads cooperate:

    * ``receive_stream_packet`` decodes the stream into ``frame_buffer``, keeping
      at most ``buffer_size`` of the newest frames.
    * ``process_frames`` waits until ``buffer_duration`` seconds have elapsed since
      construction. After that, every ``interval`` seconds it takes the newest
      buffered frame, stores it as ``current_frame`` (BGR numpy array), PNG-encodes
      it and sends it to ``endpoint``.

    Because the two threads are decoupled, an image is sent every ``interval``
    seconds regardless of how frames arrive, as long as the buffer has frames.
    ``start`` launches both threads and ``stop`` halts them. Calling the
    instance returns the most recently captured frame (an empty list before
    the first capture).
    '''

    def __init__(self, hls_url, cctv_id, interval=5, buffer_duration=15, buffer_size=600, time_zone="Asia/Seoul", endpoint="http://localhost:12345/cctv/infer"):
        '''
        :param hls_url: HLS stream address
        :param cctv_id: CCTV identifier sent with every image, so the server can tell
            where a frame came from (its exact format still needs discussion with
            the frontend developers)
        :param interval: sampling interval in seconds
        :param buffer_duration: seconds to buffer before the first capture; 15 seconds
            is the default for ITS HLS video streaming
        :param buffer_size: maximum number of decoded frames kept in memory
        :param time_zone: IANA time zone name used for the ``x-time-sent`` header
            (default Asia/Seoul)
        :param endpoint: API endpoint the captured images are POSTed to
        :raises ValueError: if the stream has no video track or no usable frame rate
        '''
        self.hls_url = hls_url
        self.interval = interval
        self.buffer_duration = buffer_duration
        self.buffer_size = buffer_size
        self.frame_buffer = []
        self.current_frame = []
        # Guards frame_buffer, which is shared between receive_stream_packet and process_frames.
        self.frame_buffer_lock = Lock()
        self.captured_frame_count = 0
        self.last_capture_time = 0
        self.start_time = time.time()
        self.stop_event = Event()
        # Worker threads; created by start(). They are kept separate from the method
        # names so that start() does not shadow receive_stream_packet.
        self.receive_thread = None
        self.process_thread = None

        self.input_stream = av.open(self.hls_url)
        try:
            self.video_stream = next((s for s in self.input_stream.streams if s.type == 'video'), None)
            if self.video_stream is None:
                raise ValueError(f"No video stream found in {self.hls_url}")
            self.fps = self._frame_rate(self.video_stream)
        except Exception:
            # Do not leak the opened container when the stream is unusable.
            self.input_stream.close()
            raise
        # Receiver pacing: sleep roughly one frame period after each decoded frame.
        self.capture_interval = 1 / self.fps

        self.cctvid = cctv_id
        self.time_zone = ZoneInfo(time_zone)
        self.endpoint = endpoint

    def __call__(self, *args, **kwargs):
        '''Return the most recently captured frame (BGR array), or [] before the first capture.'''
        return self.current_frame

    @staticmethod
    def _frame_rate(video_stream):
        '''
        Return the stream's frame rate in frames per second as a float.

        Rates such as NTSC's 30000/1001 are fractions, so the whole ratio is used
        rather than only its numerator. Falls back to ``average_rate`` when
        ``guessed_rate`` is unavailable.

        :raises ValueError: if neither rate is available
        '''
        rate = video_stream.guessed_rate or video_stream.average_rate
        if not rate:
            raise ValueError("Could not determine the frame rate of the video stream.")
        return float(rate)

    def receive_stream_packet(self):
        '''
        Decode frames from the stream into ``frame_buffer`` until stopped.

        Only the newest ``buffer_size`` frames are kept, so memory stays bounded.
        Returns when ``stop_event`` is set or the stream ends. Runs in its own
        thread (see ``start``).
        '''
        for packet in self.input_stream.demux(self.video_stream):
            # Without this check a live stream never ends and stop() would block forever.
            if self.stop_event.is_set():
                return
            for frame in packet.decode():
                with self.frame_buffer_lock:
                    self.frame_buffer.append(frame)
                    overflow = len(self.frame_buffer) - self.buffer_size
                    if overflow > 0:
                        del self.frame_buffer[:overflow]
                time.sleep(self.capture_interval)

    def process_frames(self):
        '''
        Every ``interval`` seconds, capture the newest buffered frame and send it.

        Captures begin only after ``buffer_duration`` seconds have elapsed since
        construction. The lock is held only while picking the frame, so the slow
        conversion, encoding and network request do not stall the receiver.
        Runs in its own thread until ``stop_event`` is set.
        '''
        while not self.stop_event.is_set():
            current_time = time.time()
            warmed_up = current_time - self.start_time >= self.buffer_duration
            due = self.last_capture_time == 0 or current_time - self.last_capture_time >= self.interval
            if warmed_up and due:
                with self.frame_buffer_lock:
                    latest = self.frame_buffer[-1] if self.frame_buffer else None
                if latest is not None:
                    self._capture(latest, current_time)
            time.sleep(0.1)

    def _capture(self, frame, current_time):
        '''Convert ``frame`` to BGR, publish it as ``current_frame``, PNG-encode it and send it.'''
        # Build the array locally and assign once, so __call__ never sees a half-converted frame.
        bgr = frame.to_ndarray(format='bgr24')
        self.current_frame = bgr
        frame_name = f"captured_frame_{self.captured_frame_count}.png"
        ok, encoded = cv2.imencode('.png', bgr)
        if ok:
            self.send_image_to_server(encoded.tobytes(), self.endpoint)
        else:
            print(f"Failed to encode {frame_name}; skipping upload.")
        self.last_capture_time = current_time
        print(f"Captured {frame_name} at time: {current_time - self.start_time:.2f}s")
        self.captured_frame_count += 1

    def _time_sent(self):
        '''Return the current time in ``time_zone``, formatted as ``YYYY-MM-DDTHH:MM:SSZ``.'''
        # Python's strftime needs %-directives; a Java-style pattern is emitted verbatim.
        # The trailing 'Z' is a literal, mirroring the pattern "yyyy-MM-dd'T'HH:mm:ss'Z'".
        # NOTE(review): the time is local to time_zone, not UTC, despite the 'Z' suffix — confirm with the backend.
        return datetime.now(self.time_zone).strftime("%Y-%m-%dT%H:%M:%SZ")

    def send_image_to_server(self, image, endpoint, image_type="png", timeout=10):
        '''
        POST an encoded image to ``endpoint`` as multipart/form-data (field ``image``).

        :param image: encoded image bytes
        :param endpoint: URL to POST to
        :param image_type: image subtype used for the file name and the part's content type
        :param timeout: seconds to wait for the server before giving up
        :return: the ``requests.Response``, or None if the request could not be completed
        '''
        header = {
            # No explicit Content-Type: requests must set multipart/form-data with its
            # boundary itself; an image/<type> override would make the body unparseable.
            'x-time-sent': self._time_sent(),
            'x-cctv-info': str(self.cctvid),
            'x-cctv-latitude': '',
            'x-cctv-longitude': '',
        }
        files = {
            'image': (f'frame_{self.cctvid}.{image_type}', image, f'image/{image_type}'),
        }
        try:
            response = requests.post(endpoint, headers=header, files=files, timeout=timeout)
        except requests.RequestException as err:
            print("Can not connect to the analyzer server. Check the endpoint address or connection.\n"
                  f"Can not connect to : {endpoint} ({err})")
            return None
        if not response.ok:
            print(f"Analyzer server at {endpoint} responded with HTTP {response.status_code}.")
        return response

    def start(self):
        '''Launch the receiver and processor threads.'''
        self.receive_thread = Thread(target=self.receive_stream_packet)
        self.process_thread = Thread(target=self.process_frames)
        self.receive_thread.start()
        self.process_thread.start()

    def stop(self):
        '''Signal both threads to stop, wait for them to finish, then close the stream.'''
        self.stop_event.set()
        for thread in (self.receive_thread, self.process_thread):
            # Safe to call even if start() was never called (or failed part-way).
            if thread is not None:
                thread.join()
        self.input_stream.close()


# Example usage: sample one frame every 10 s from an ITS CCTV stream and record how long it ran.
if __name__ == "__main__":
    capturer = FrameCapturer(
        'http://cctvsec.ktict.co.kr/71187/bWDrL7fpStZDeDZgCybpJH8gagWJOynbaA/l91ExpmUPKzc3bCsHJtIblDkzG3Tff2tHy5NNkb6NtYTbie/jNQ0F+PnejViTbKHkpMWNGpc=',
        cctv_id=101,
        interval=10,
    )
    t1 = time.time()
    try:
        capturer.start()
        time.sleep(600000)
    finally:
        # Deliberately no exit() here: raising SystemExit inside `finally` would
        # replace any in-flight exception (including KeyboardInterrupt), hiding its
        # traceback and turning a failure into exit status 0.
        capturer.stop()
        t2 = time.time()
        with open("result.txt", "w", encoding="utf-8") as file:
            file.write(f'{t2-t1} seconds before terminating')
