from __future__ import annotations

from collections import OrderedDict
import gc
import logging
import threading
import time
from typing import TYPE_CHECKING

from faster_whisper import WhisperModel

if TYPE_CHECKING:
    from collections.abc import Callable

    from faster_whisper_server.config import (
        WhisperConfig,
    )

# Logger for this module, named after the module itself via ``__name__``.
logger = logging.getLogger(name=__name__)

# TODO: enable concurrent model downloads


class SelfDisposingWhisperModel:
    """Lazily-loaded ``WhisperModel`` wrapper that offloads itself once it becomes idle.

    Use it as a context manager. Entering loads the model on first use and takes a reference;
    exiting releases it. When the reference count drops to zero, ``whisper_config.ttl`` decides
    what happens:

    * ``ttl > 0``  -- the model is unloaded after ``ttl`` seconds unless it is re-entered first
      (the timer runs on a daemon thread so it never blocks interpreter shutdown),
    * ``ttl == 0`` -- the model is unloaded immediately,
    * ``ttl < 0``  -- the model is never unloaded automatically.

    All state transitions are serialized by ``self.rlock``. The lock is reentrant because
    ``__enter__`` calls ``_load``/``_increment_ref`` and ``_decrement_ref`` may call ``unload``
    while already holding it. ``on_unload`` is invoked while ``self.rlock`` is still held, so a
    callback must not wait on another thread that needs this model's lock.
    """

    def __init__(
        self,
        model_id: str,
        whisper_config: WhisperConfig,
        *,
        on_unload: Callable[[str], None] | None = None,
    ) -> None:
        self.model_id = model_id
        self.whisper_config = whisper_config
        # Called with ``model_id`` after the model has been unloaded.
        self.on_unload = on_unload

        # Number of active ``with`` blocks currently using the model.
        self.ref_count: int = 0
        self.rlock = threading.RLock()
        # Pending idle-offload timer; None when no offload is scheduled.
        self.expire_timer: threading.Timer | None = None
        # The underlying model; None while unloaded.
        self.whisper: WhisperModel | None = None

    def unload(self) -> None:
        """Drop the underlying model, cancel any pending offload, and notify ``on_unload``.

        Raises:
            ValueError: If the model is not loaded, or is still referenced by an active ``with`` block.
        """
        with self.rlock:
            if self.whisper is None:
                raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
            if self.ref_count > 0:
                raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
            self._cancel_expire_timer()
            self.whisper = None
            # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
            gc.collect()
            logger.info(f"Model {self.model_id} unloaded")
            if self.on_unload is not None:
                self.on_unload(self.model_id)

    def _load(self) -> None:
        """Instantiate the ``WhisperModel`` from ``whisper_config`` (caller ensures it is not loaded)."""
        with self.rlock:
            assert self.whisper is None
            logger.debug(f"Loading model {self.model_id}")
            start = time.perf_counter()
            self.whisper = WhisperModel(
                self.model_id,
                device=self.whisper_config.inference_device,
                device_index=self.whisper_config.device_index,
                compute_type=self.whisper_config.compute_type,
                cpu_threads=self.whisper_config.cpu_threads,
                num_workers=self.whisper_config.num_workers,
            )
            logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")

    def _cancel_expire_timer(self) -> None:
        """Cancel and forget any pending idle-offload timer (caller holds ``self.rlock``)."""
        if self.expire_timer is not None:
            self.expire_timer.cancel()
            self.expire_timer = None

    def _expire(self) -> None:
        """Idle-timer callback: unload the model if this timer is still current and the model is idle."""
        with self.rlock:
            # ``Timer.cancel()`` has no effect once the timer has fired. A timer that was cancelled or
            # replaced while this callback waited for ``self.rlock`` is stale and must not unload.
            # The callback runs on the timer's own thread, so ``current_thread()`` is that timer.
            if threading.current_thread() is not self.expire_timer:
                return
            self.expire_timer = None
            if self.whisper is None or self.ref_count > 0:
                return
            self.unload()

    def _increment_ref(self) -> None:
        """Take a reference and cancel any pending idle offload."""
        with self.rlock:
            self.ref_count += 1
            if self.expire_timer is not None:
                logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
            self._cancel_expire_timer()
            logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")

    def _decrement_ref(self) -> None:
        """Release a reference; when none remain, apply the ``ttl`` offload policy."""
        with self.rlock:
            self.ref_count -= 1
            logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
            if self.ref_count > 0:
                return
            ttl = self.whisper_config.ttl
            if ttl > 0:
                logger.info(f"Model {self.model_id} is idle, scheduling offload in {ttl}s")
                # Never leave two offload timers racing each other.
                self._cancel_expire_timer()
                timer = threading.Timer(ttl, self._expire)
                # A pending offload must not keep the interpreter alive at shutdown.
                timer.daemon = True
                # Publish before start() so ``_expire`` recognizes itself as the current timer.
                self.expire_timer = timer
                timer.start()
            elif ttl == 0:
                logger.info(f"Model {self.model_id} is idle, unloading immediately")
                self.unload()
            else:
                logger.info(f"Model {self.model_id} is idle, not unloading")

    def __enter__(self) -> WhisperModel:
        """Load the model if needed, take a reference, and return the underlying ``WhisperModel``."""
        with self.rlock:
            if self.whisper is None:
                self._load()
            self._increment_ref()
            assert self.whisper is not None
            return self.whisper

    def __exit__(self, *_args) -> None:  # noqa: ANN002
        """Release the reference taken by ``__enter__``; exceptions from the ``with`` body propagate."""
        self._decrement_ref()


class ModelManager:
    """Registry of ``SelfDisposingWhisperModel`` wrappers keyed by model name.

    ``self._lock`` guards only ``loaded_models``. It is never held while calling into a model,
    because a model's ``unload`` invokes ``_handle_model_unload`` (which takes ``self._lock``)
    while holding the model's own lock. Holding ``self._lock`` across that call would
    self-deadlock (the lock is not reentrant) and invert the lock order with timer threads.
    """

    def __init__(self, whisper_config: WhisperConfig) -> None:
        self.whisper_config = whisper_config
        # Wrappers created by ``load_model`` that have not been unloaded yet.
        self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
        self._lock = threading.Lock()

    def _handle_model_unload(self, model_name: str) -> None:
        """``on_unload`` callback: drop ``model_name`` from the registry (no-op if absent)."""
        with self._lock:
            self.loaded_models.pop(model_name, None)

    def unload_model(self, model_name: str) -> None:
        """Unload the registered model ``model_name`` immediately.

        Raises:
            KeyError: If no model named ``model_name`` is registered.
            ValueError: Propagated from the model's ``unload`` if it is not loaded or still in use.
        """
        with self._lock:
            model = self.loaded_models.get(model_name)
        if model is None:
            raise KeyError(f"Model {model_name} not found")
        # Must run outside ``self._lock``: ``unload`` calls back into ``_handle_model_unload``,
        # which acquires the non-reentrant ``self._lock`` on this same thread.
        model.unload()

    def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
        """Return the registered wrapper for ``model_name``, creating and registering it if needed.

        The wrapper loads the weights lazily on its first ``with`` entry.
        """
        with self._lock:
            model = self.loaded_models.get(model_name)
            if model is not None:
                logger.debug(f"{model_name} model already loaded")
                return model
            model = SelfDisposingWhisperModel(
                model_name,
                self.whisper_config,
                on_unload=self._handle_model_unload,
            )
            self.loaded_models[model_name] = model
            return model
