from torch.utils.data import Dataset
from torchvision.io import read_image

from typing import Union, List
import numpy as np


class ImagePairDataset(Dataset):
    """Dataset of (clean, rainy) image pairs loaded from file paths.

    Sample ``idx`` pairs ``clean_img_dir[idx]`` with ``rainy_img_dirs[idx]``.
    Each rainy entry is either a single path or a list/tuple of candidate
    paths; in the latter case one candidate is picked uniformly at random
    every time the sample is fetched.

    Args:
        clean_img_dir: Paths of the clean images. Despite the name, this is
            one path per sample, not a directory.
        rainy_img_dirs: Sequence parallel to ``clean_img_dir`` whose entries
            are a rainy-image path or a list of candidate rainy-image paths.
            Only the first ``len(clean_img_dir)`` entries are ever used.
        transform: Optional callable applied separately to each loaded image.
    """

    def __init__(
        self,
        clean_img_dir: List[str],
        rainy_img_dirs: Union[List[str], List[List[str]]],
        transform=None,
    ):
        self.clean_img = clean_img_dir
        self.dirty_img = rainy_img_dirs
        self.transform = transform
        # NOTE(review): this generator is unseeded and created once here. If the
        # dataset is used by a DataLoader with num_workers > 0, each worker
        # likely starts from a copy of the same generator state and may draw the
        # same candidate indices — consider reseeding per worker (worker_init_fn).
        self.rng = np.random.default_rng()

    def __len__(self):
        """Return the number of samples, i.e. the number of clean images."""
        return len(self.clean_img)

    def _choose_rainy_path(self, rainy_img_paths):
        """Return ``rainy_img_paths`` itself, or a uniformly random element of it
        when it is a list/tuple of candidates.

        Raises:
            ValueError: if an empty list/tuple of candidates is given.
        """
        if not isinstance(rainy_img_paths, (list, tuple)):
            return rainy_img_paths
        if not rainy_img_paths:
            raise ValueError("no rainy image paths given for this sample")
        # ``Generator.integers`` excludes its upper bound by default, so drawing
        # from [0, len) makes every candidate — including the last — reachable.
        index = int(self.rng.integers(len(rainy_img_paths)))
        return rainy_img_paths[index]

    def __getitem__(self, idx):
        """Load the clean image at ``idx`` and one matching rainy image.

        Returns:
            dict with keys ``"clean_image"`` and ``"rainy_image"``, each the
            result of ``read_image`` on the respective path, passed through
            ``transform`` when one is set.
        """
        clean_img_path = self.clean_img[idx]
        rainy_img_path = self._choose_rainy_path(self.dirty_img[idx])

        clean_image = read_image(clean_img_path)
        rainy_image = read_image(rainy_img_path)

        if self.transform is not None:
            # NOTE(review): the transform is called once per image; if it is
            # random (e.g. a random crop/flip), the two images of the pair may
            # be transformed differently — confirm that is intended.
            clean_image = self.transform(clean_image)
            rainy_image = self.transform(rainy_image)

        return {
            "clean_image": clean_image,
            "rainy_image": rainy_image,
        }

    def __add__(self, other):
        """Concatenate two ``ImagePairDataset`` objects into a new one.

        Samples of ``self`` come first, followed by those of ``other``. The new
        dataset uses ``self.transform`` if it is truthy, otherwise
        ``other.transform`` (so ``other``'s transform is discarded when both
        are set). For any other operand type ``NotImplemented`` is returned,
        which makes Python raise ``TypeError``.
        """
        if not isinstance(other, ImagePairDataset):
            # Unlike an ``assert``, this check survives ``python -O`` and follows
            # the binary-operator protocol.
            return NotImplemented
        return ImagePairDataset(
            clean_img_dir=list(self.clean_img) + list(other.clean_img),
            rainy_img_dirs=list(self.dirty_img) + list(other.dirty_img),
            transform=self.transform or other.transform,
        )