from torch.utils.data import Dataset
from torchvision.io import read_image
import numpy as np


class ImagePairDataset(Dataset):
    """Dataset of (clean, rainy) image pairs loaded from file paths.

    Item ``idx`` pairs ``clean_img_dir[idx]`` with ``rainy_img_dirs[idx]``.
    A rainy entry may be a single path, or a list/tuple of candidate paths
    from which one is drawn uniformly at random on every access.

    Args:
        clean_img_dir: Sequence of clean-image paths. Despite the name, this
            is indexed per item rather than being a single directory.
        rainy_img_dirs: Sequence aligned with ``clean_img_dir``. Each entry is
            either one path or a list/tuple of candidate paths.
        transform: Optional callable applied to each loaded image.
    """

    def __init__(self, clean_img_dir, rainy_img_dirs, transform=None):
        self.clean_img = clean_img_dir
        self.rainy_img = rainy_img_dirs
        self.transform = transform

    def __len__(self):
        # Length follows the clean list. This assumes rainy_img is aligned
        # with it (at least as long) — TODO confirm callers guarantee this.
        return len(self.clean_img)

    def _pick_rainy_path(self, idx):
        """Return the rainy-image path for item ``idx``.

        A single path is returned as-is. For a list/tuple of candidates, one
        is chosen uniformly at random.

        Raises:
            ValueError: If the candidate list for ``idx`` is empty.
        """
        entry = self.rainy_img[idx]
        if not isinstance(entry, (list, tuple)):
            return entry
        if not entry:
            raise ValueError(f"no rainy image candidates for index {idx}")
        # A fresh, OS-seeded generator per call avoids sharing one RNG state
        # across forked DataLoader workers (same approach as before).
        rng = np.random.default_rng()
        # `integers` excludes `high`, so len(entry) makes every candidate,
        # including the last, reachable.
        return entry[int(rng.integers(len(entry)))]

    def __getitem__(self, idx):
        """Load the clean/rainy pair for ``idx``.

        Returns:
            dict with keys ``"clean_image"`` and ``"rainy_image"`` holding the
            images returned by ``read_image`` (after ``transform`` if set).
        """
        clean_image = read_image(self.clean_img[idx])
        rainy_image = read_image(self._pick_rainy_path(idx))
        if self.transform is not None:
            # NOTE(review): the transform is called separately for each image,
            # so a random transform (e.g. random crop/flip) may sample
            # different parameters for the two halves of the pair — confirm
            # this is intended.
            clean_image = self.transform(clean_image)
            rainy_image = self.transform(rainy_image)

        return {
            "clean_image": clean_image,
            "rainy_image": rainy_image,
        }

    def __add__(self, other):
        """Concatenate with another dataset.

        Two ``ImagePairDataset`` objects are merged into a new one whose path
        lists are concatenated. The result keeps ``self.transform``. Any
        other dataset is handed to ``Dataset.__add__``.
        """
        if not isinstance(other, ImagePairDataset):
            return super().__add__(other)
        # list() normalizes tuples/arrays so `+` concatenates rather than
        # failing (list + tuple) or adding element-wise (arrays).
        return ImagePairDataset(
            clean_img_dir=list(self.clean_img) + list(other.clean_img),
            rainy_img_dirs=list(self.rainy_img) + list(other.rainy_img),
            transform=self.transform,
        )
