from torch.utils.data import Dataset
from torchvision.io import read_image
import numpy as np

# Each sample in the dataset for this model must consist of at least a pair of
# images: a clean image and a dirty image. In other words, "annotations" is not
# really an appropriate name here.
# However, the open problem is that we need to be able to pair more than one
# dirty image with each clean image, and it is not yet clear how to model that.

class ImageDataSet(Dataset):
    """Map-style dataset pairing image files with per-image annotations.

    Sample ``idx`` is ``(read_image(img_dir[idx]), annotations[idx])``,
    optionally passed through ``transform`` / ``target_transform``.

    Args:
        img_dir: Indexable sequence of image file paths (despite the name, this
            is NOT a single directory). Element ``idx`` is passed to
            ``read_image``.
        annotations: Indexable sequence of labels, one per entry of ``img_dir``.
        transform: Optional callable applied to the image tensor returned by
            ``read_image``.
        target_transform: Optional callable applied to the label.

    Raises:
        ValueError: If ``img_dir`` and ``annotations`` differ in length.
    """

    def __init__(self, img_dir, annotations, transform=None, target_transform=None):
        # __len__ counts annotations while __getitem__ indexes img_dir, so a
        # length mismatch would either hide images or fail mid-epoch.
        if len(img_dir) != len(annotations):
            raise ValueError(
                f"img_dir has {len(img_dir)} paths but annotations has "
                f"{len(annotations)} entries; they must be the same length"
            )
        self.annotations = annotations
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of (image, label) samples."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Load image ``idx`` from disk and return ``(image, label)``."""
        img_path = self.img_dir[idx]
        image = read_image(img_path)
        label = self.annotations[idx]
        # Compare against None rather than truthiness: a transform object may
        # define __len__/__bool__ and still be meant to run.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __add__(self, other):
        """Concatenate two ``ImageDataSet`` objects into a new one.

        Paths and annotations are concatenated in order (``self`` first). The
        result keeps ``self``'s transforms; ``other``'s transforms are dropped.
        For any other ``Dataset``, fall back to the base-class behaviour
        (a ``ConcatDataset``).
        """
        if not isinstance(other, ImageDataSet):
            return super().__add__(other)
        return ImageDataSet(
            # list() first so NumPy/pandas inputs are concatenated rather than
            # added element-wise.
            img_dir=list(self.img_dir) + list(other.img_dir),
            annotations=np.array(list(self.annotations) + list(other.annotations)),
            transform=self.transform,
            target_transform=self.target_transform,
        )
