import time
from collections import deque
from datetime import datetime
from threading import Event, Lock, Thread
from zoneinfo import ZoneInfo

import av
import cv2
import numpy as np
import requests
from requests_toolbelt.multipart.encoder import MultipartEncoder


class FrameCapturer:
    """Sample frames from an HLS stream and POST them to an inference server.

    One thread (``receive_stream_packet``) decodes the stream into a bounded
    buffer. A second thread (``process_frames``) takes the newest buffered
    frame every ``interval`` seconds, sends it to ``endpoint`` and saves a
    copy to disk. ``start()`` launches both threads; ``stop()`` shuts them
    down and closes the stream. Calling the instance returns the most recently
    captured frame as a BGR array, or an empty list before the first capture.
    """

    def __init__(
            self,
            hls_url,
            cctv_name,
            lat,
            lon,
            interval=5,
            buffer_duration=15,
            buffer_size=600,
            time_zone="Asia/Seoul",
            endpoint="http://localhost:12345/cctv/infer"
    ):
        """
        :param hls_url: HLS stream address, opened with PyAV.
        :param cctv_name: CCTV name gathered from the ITS API; sent as the x-cctv-info header.
        :param lat: CCTV latitude; sent (as text) in the x-cctv-latitude header.
        :param lon: CCTV longitude; sent (as text) in the x-cctv-longitude header.
        :param interval: minimum number of seconds between two captured frames.
        :param buffer_duration: warm-up in seconds before the first capture;
            15 seconds is the default for ITS HLS video streaming.
        :param buffer_size: maximum number of decoded frames kept in memory;
            the oldest frames are dropped first.
        :param time_zone: IANA time zone used for the x-time-sent header (default Asia/Seoul).
        :param endpoint: API endpoint the captured frames are POSTed to.
        :raises ValueError: if the stream has no video track or its frame rate is unknown.
        """
        self.hls_url = hls_url
        self.interval = interval
        self.buffer_duration = buffer_duration
        self.buffer_size = buffer_size
        # Bounded buffer: appending past buffer_size silently drops the oldest
        # frame, so memory stays capped even between captures.
        self.frame_buffer = deque(maxlen=buffer_size)
        self.current_frame = []
        # Guards frame_buffer, which the receive and process threads share.
        self.frame_buffer_lock = Lock()
        self.captured_frame_count = 0
        self.last_capture_time = 0
        self.start_time = time.time()
        self.stop_event = Event()
        self.receive_thread = None
        self.process_thread = None

        self.cctvid = cctv_name
        # Validated before opening the network stream so a bad zone name
        # cannot leak an open container.
        self.time_zone = ZoneInfo(time_zone)
        self.endpoint = endpoint

        self.lat = lat
        self.lon = lon

        self.input_stream = av.open(self.hls_url)
        try:
            self.video_stream = next(
                (s for s in self.input_stream.streams if s.type == 'video'), None
            )
            if self.video_stream is None:
                raise ValueError("No video stream found in the HLS source")
            self.fps = self._resolve_fps(self.video_stream)
        except Exception:
            # Do not leak the opened container when construction fails.
            self.input_stream.close()
            raise
        # Decoding is paced at about one frame per 1/fps seconds.
        self.capture_interval = 1 / self.fps

    def __call__(self, *args, **kwargs):
        """Return the most recently captured BGR frame ([] before the first capture)."""
        return self.current_frame

    @staticmethod
    def _resolve_fps(video_stream):
        """Return the stream's frame rate as a float.

        The rate can be fractional (e.g. 30000/1001 for 29.97 fps), so the whole
        ratio is converted rather than only its numerator. ``average_rate`` is
        used when ``guessed_rate`` is missing.

        :raises ValueError: if neither rate is available.
        """
        rate = (getattr(video_stream, 'guessed_rate', None)
                or getattr(video_stream, 'average_rate', None))
        if not rate:
            raise ValueError("Could not determine the frame rate of the video stream")
        return float(rate)

    @staticmethod
    def _format_timestamp(moment):
        """Format ``moment`` as ``YYYY-MM-DDTHH:MM:SSZ``.

        This is the output of the Java pattern ``yyyy-MM-dd'T'HH:mm:ss'Z'``.
        NOTE(review): as in that pattern, the trailing ``Z`` is a literal. The
        value is wall-clock time in the configured time zone, not UTC. Confirm
        that the server expects this.
        """
        return moment.strftime("%Y-%m-%dT%H:%M:%SZ")

    # ``receive_stream_packet`` and ``process_frames`` run concurrently (each on
    # its own Thread), so a photo is sent every ``interval`` seconds regardless
    # of how frames are buffered, as long as the buffer has frames in it.
    # ``start`` launches them and ``stop`` halts them.
    def receive_stream_packet(self):
        """Decode video packets into ``frame_buffer`` until the stream ends or stop is requested."""
        for packet in self.input_stream.demux(self.video_stream):
            if self.stop_event.is_set():
                return
            for frame in packet.decode():
                with self.frame_buffer_lock:
                    self.frame_buffer.append(frame)
                # Paces decoding like the original sleep, but wakes up at once on stop().
                if self.stop_event.wait(self.capture_interval):
                    return

    def process_frames(self):
        """Every ``interval`` seconds, capture the newest buffered frame.

        Nothing is captured until ``buffer_duration`` seconds have passed since
        ``start_time``, which gives the buffer time to fill. The lock is held
        only while reading the newest frame. Conversion, upload and disk I/O
        happen outside it, so a slow server never blocks the receive thread.
        """
        while not self.stop_event.is_set():
            current_time = time.time()
            warmed_up = current_time - self.start_time >= self.buffer_duration
            capture_due = (self.last_capture_time == 0
                           or current_time - self.last_capture_time >= self.interval)
            if warmed_up and capture_due:
                with self.frame_buffer_lock:
                    newest_frame = self.frame_buffer[-1] if self.frame_buffer else None
                if newest_frame is not None:
                    self._capture_frame(newest_frame, current_time)
            # wait() instead of sleep() so stop() need not sit out the poll period.
            self.stop_event.wait(0.1)

    def _capture_frame(self, frame, current_time):
        """Convert ``frame`` to BGR, send it as PNG to the server and save a copy to disk."""
        bgr_frame = cv2.cvtColor(np.array(frame.to_image()), cv2.COLOR_RGB2BGR)
        # Assigned in one step so __call__ never sees a half-converted value.
        self.current_frame = bgr_frame
        frame_name = f"captured_frame_{self.captured_frame_count}.jpg"
        encoded_ok, encoded = cv2.imencode('.png', bgr_frame)
        if encoded_ok:
            self.send_image_to_server(encoded.tobytes(), self.endpoint)
        else:
            print(f"Failed to encode {frame_name} as PNG; not sent")
        # NOTE(review): str(datetime.now()) contains spaces and colons, and
        # colons are invalid in Windows file names. imwrite also returns False
        # when the target directory does not exist.
        saved_path = f'hls_streaming/captured_frame_/{datetime.now()}_{frame_name}'
        if not cv2.imwrite(saved_path, bgr_frame):
            print(f"Failed to write {saved_path}")
        self.last_capture_time = current_time
        print(f"Captured {frame_name} at time: {current_time - self.start_time:.2f}s")
        self.captured_frame_count += 1

    def send_image_to_server(self, image, endpoint, image_type="png", timeout=30):
        """POST one encoded image to ``endpoint`` as multipart/form-data.

        The image goes in the ``file`` field. The CCTV metadata and the send
        time are sent as ``x-*`` headers. Failures are printed, not raised, so
        a bad upload cannot kill the capture thread.

        :param image: encoded image bytes.
        :param endpoint: URL to POST to.
        :param image_type: image subtype used in the file name and part MIME type.
        :param timeout: seconds to wait on connect/read before giving up.
        :return: the ``requests.Response``, or ``None`` if the request failed.
        """
        time_sent = self._format_timestamp(datetime.now(self.time_zone))
        try:
            multipart_data = MultipartEncoder(
                fields={
                    'file': (f'frame_{self.cctvid}.{image_type}',
                             image,
                             f'image/{image_type}')
                }
            )
            header = {
                # The multipart content type carries the part boundary.
                'Content-Type': multipart_data.content_type,
                'x-time-sent': time_sent,
                'x-cctv-info': str(self.cctvid),
                # requests accepts only str/bytes header values.
                'x-cctv-latitude': str(self.lat),
                'x-cctv-longitude': str(self.lon),
            }
            return requests.post(endpoint, headers=header, data=multipart_data, timeout=timeout)
        except requests.RequestException as e:
            print(e)
            print("Can not connect to the analyzer server. Check the endpoint address or connection.\n"
                  f"Can not connect to : {endpoint}")
        except Exception as e:
            # Best effort: any other failure is reported, not propagated.
            print(f"Failed to send image to {endpoint}: {e!r}")
        return None

    def start(self):
        """Launch the receive and process threads.

        :raises RuntimeError: if called more than once (the stream is closed by stop()).
        """
        if self.receive_thread is not None or self.process_thread is not None:
            raise RuntimeError("FrameCapturer can only be started once")
        # The warm-up period counts from when frames actually start arriving.
        self.start_time = time.time()
        self.receive_thread = Thread(target=self.receive_stream_packet, name="frame-receiver")
        self.process_thread = Thread(target=self.process_frames, name="frame-processor")
        self.receive_thread.start()
        self.process_thread.start()

    def stop(self):
        """Signal both threads to finish, wait for them, then close the stream.

        Safe to call even if ``start()`` was never called.
        NOTE(review): the receive thread checks the stop flag between packets,
        so if ``demux`` is blocked on network I/O, this waits for the next
        packet or for the end of the stream.
        """
        self.stop_event.set()
        for thread in (self.receive_thread, self.process_thread):
            if thread is not None:
                thread.join()
        self.input_stream.close()


# Example usage
if __name__ == "__main__":
    # Demo: sample a public ITS CCTV stream and run until Ctrl+C (or ~7 days).
    capturer = FrameCapturer(
        'http://cctvsec.ktict.co.kr/5545/LFkDslDT81tcSYh3G4306+mcGlLb3yShF9rx2vcPfltwUL4+I950kcBlD15uWm6K0cKCtAMlxsIptMkCDo5lGQiLlARP+SyUloz8vIMNB18=',
        cctv_name=101,
        lat=10,
        lon=5,
    )
    t1 = time.time()
    try:
        capturer.start()
        time.sleep(600000)
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to end the demo early. Any other error
        # propagates with its traceback instead of being masked by exit().
        pass
    finally:
        capturer.stop()
        t2 = time.time()
        with open("result.txt", "w", encoding="utf-8") as file:
            file.write(f'{t2-t1} seconds before terminating')
