L1-SR-GUI/SRGAN-PyTorch-master/dataset.py

import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
import random


class mydata(Dataset):
    """Paired LR/GT dataset: images are scaled to [-1, 1] and returned as CHW float32 arrays."""

    def __init__(self, LR_path, GT_path, in_memory=True, transform=None):
        self.LR_path = LR_path
        self.GT_path = GT_path
        self.in_memory = in_memory
        self.transform = transform
        self.LR_img = sorted(os.listdir(LR_path))
        self.GT_img = sorted(os.listdir(GT_path))

        if in_memory:
            # Preload every image as a uint8 RGB array to avoid disk reads during training.
            self.LR_img = [
                np.array(Image.open(os.path.join(self.LR_path, lr)).convert("RGB")).astype(np.uint8)
                for lr in self.LR_img
            ]
            self.GT_img = [
                np.array(Image.open(os.path.join(self.GT_path, gt)).convert("RGB")).astype(np.uint8)
                for gt in self.GT_img
            ]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        img_item = {}

        if self.in_memory:
            GT = self.GT_img[i].astype(np.float32)
            LR = self.LR_img[i].astype(np.float32)
        else:
            GT = np.array(Image.open(os.path.join(self.GT_path, self.GT_img[i])).convert("RGB"))
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB"))

        # Map [0, 255] to [-1, 1].
        img_item['GT'] = (GT / 127.5) - 1.0
        img_item['LR'] = (LR / 127.5) - 1.0

        if self.transform is not None:
            img_item = self.transform(img_item)

        # HWC -> CHW for PyTorch.
        img_item['GT'] = img_item['GT'].transpose(2, 0, 1).astype(np.float32)
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)

        return img_item


class testOnly_data(Dataset):
    """LR-only dataset for inference; no GT images required."""

    def __init__(self, LR_path, in_memory=True, transform=None):
        self.LR_path = LR_path
        self.LR_img = sorted(os.listdir(LR_path))
        self.in_memory = in_memory

        if in_memory:
            # Convert to RGB so the CHW transpose below always sees 3 channels.
            self.LR_img = [
                np.array(Image.open(os.path.join(self.LR_path, lr)).convert("RGB"))
                for lr in self.LR_img
            ]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        img_item = {}

        if self.in_memory:
            LR = self.LR_img[i]
        else:
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB"))

        img_item['LR'] = (LR / 127.5) - 1.0
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)

        return img_item


class crop(object):
    """Randomly crops a matching LR/GT patch pair; GT coordinates are the LR ones scaled by `scale`."""

    def __init__(self, scale, patch_size):
        self.scale = scale
        self.patch_size = patch_size

    def __call__(self, sample):
        LR_img, GT_img = sample['LR'], sample['GT']
        ih, iw = LR_img.shape[:2]

        # Top-left corner of the LR patch, chosen so the patch stays inside the image.
        ix = random.randrange(0, iw - self.patch_size + 1)
        iy = random.randrange(0, ih - self.patch_size + 1)

        # Corresponding top-left corner in the GT image.
        tx = ix * self.scale
        ty = iy * self.scale

        LR_patch = LR_img[iy: iy + self.patch_size, ix: ix + self.patch_size]
        GT_patch = GT_img[ty: ty + (self.scale * self.patch_size), tx: tx + (self.scale * self.patch_size)]

        return {'LR': LR_patch, 'GT': GT_patch}


class augmentation(object):
    """Random horizontal/vertical flips and an axis swap (image transpose), applied identically to LR and GT."""

    def __call__(self, sample):
        LR_img, GT_img = sample['LR'], sample['GT']
        hor_flip = random.randrange(0, 2)
        ver_flip = random.randrange(0, 2)
        rot = random.randrange(0, 2)

        if hor_flip:
            # .copy() gives contiguous arrays instead of the views returned by np.fliplr/np.flipud.
            LR_img = np.fliplr(LR_img).copy()
            GT_img = np.fliplr(GT_img).copy()

        if ver_flip:
            LR_img = np.flipud(LR_img).copy()
            GT_img = np.flipud(GT_img).copy()

        if rot:
            # Swap the spatial axes (transpose the image); channels stay last.
            LR_img = LR_img.transpose(1, 0, 2)
            GT_img = GT_img.transpose(1, 0, 2)

        return {'LR': LR_img, 'GT': GT_img}
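

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one way these
# pieces can be wired together for training and inference. The directory
# paths, batch size, scale and patch_size below are placeholder assumptions,
# and torchvision.transforms.Compose is used only as a generic composer of
# callables; the repository's own training script may differ.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from torchvision import transforms  # assumes torchvision is available

    # Paired training data: random 24x24 LR crops (96x96 GT at x4 scale)
    # followed by flip/transpose augmentation.
    train_transform = transforms.Compose([crop(scale=4, patch_size=24), augmentation()])
    train_set = mydata(LR_path="data/train_LR", GT_path="data/train_HR",  # hypothetical paths
                       in_memory=False, transform=train_transform)
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)

    batch = next(iter(train_loader))
    print(batch['LR'].shape, batch['GT'].shape)  # e.g. [16, 3, 24, 24] and [16, 3, 96, 96]

    # LR-only data for inference; no cropping or augmentation, so batch_size=1
    # avoids stacking images of different sizes.
    test_set = testOnly_data(LR_path="data/test_LR", in_memory=False)  # hypothetical path
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)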