commit 50e4c082d82281b1fc40b7a097c0551f6644acb1
Author: ℝλℙℍλΣʟ
Date:   Thu Nov 12 23:30:33 2020 +0100

    Initial commit

diff --git a/SRCNN-pytorch-master/README.md b/SRCNN-pytorch-master/README.md
new file mode 100644
index 0000000..7a7621d
--- /dev/null
+++ b/SRCNN-pytorch-master/README.md
@@ -0,0 +1,129 @@
+# SRCNN
+
+This repository is an implementation of the paper ["Image Super-Resolution Using Deep Convolutional Networks"](https://arxiv.org/abs/1501.00092).
+
+## Differences from the original
+
+- Added zero-padding
+- Used Adam instead of SGD
+- Removed the weight initialization
+
+## Requirements
+
+- PyTorch 1.0.0
+- Numpy 1.15.4
+- Pillow 5.4.1
+- h5py 2.8.0
+- tqdm 4.30.0
+
+## Train
+
+The 91-image and Set5 datasets, converted to HDF5, can be downloaded from the links below.
+
+| Dataset | Scale | Type | Link |
+|---------|-------|------|------|
+| 91-image | 2 | Train | [Download](https://www.dropbox.com/s/2hsah93sxgegsry/91-image_x2.h5?dl=0) |
+| 91-image | 3 | Train | [Download](https://www.dropbox.com/s/curldmdf11iqakd/91-image_x3.h5?dl=0) |
+| 91-image | 4 | Train | [Download](https://www.dropbox.com/s/22afykv4amfxeio/91-image_x4.h5?dl=0) |
+| Set5 | 2 | Eval | [Download](https://www.dropbox.com/s/r8qs6tp395hgh8g/Set5_x2.h5?dl=0) |
+| Set5 | 3 | Eval | [Download](https://www.dropbox.com/s/58ywjac4te3kbqq/Set5_x3.h5?dl=0) |
+| Set5 | 4 | Eval | [Download](https://www.dropbox.com/s/0rz86yn3nnrodlb/Set5_x4.h5?dl=0) |
+
+Otherwise, you can use `prepare.py` to create a custom dataset (an example invocation is shown at the end of this README).
+
+```bash
+python train.py --train-file "BLAH_BLAH/91-image_x3.h5" \
+                --eval-file "BLAH_BLAH/Set5_x3.h5" \
+                --outputs-dir "BLAH_BLAH/outputs" \
+                --scale 3 \
+                --lr 1e-4 \
+                --batch-size 16 \
+                --num-epochs 400 \
+                --num-workers 8 \
+                --seed 123
+```
+
+## Test
+
+Pre-trained weights can be downloaded from the links below.
+
+| Model | Scale | Link |
+|-------|-------|------|
+| 9-5-5 | 2 | [Download](https://www.dropbox.com/s/rxluu1y8ptjm4rn/srcnn_x2.pth?dl=0) |
+| 9-5-5 | 3 | [Download](https://www.dropbox.com/s/zn4fdobm2kw0c58/srcnn_x3.pth?dl=0) |
+| 9-5-5 | 4 | [Download](https://www.dropbox.com/s/pd5b2ketm0oamhj/srcnn_x4.pth?dl=0) |
+
+The result is saved next to the input image, with an `_srcnn_x{scale}` suffix added to the file name.
+
+```bash
+python test.py --weights-file "BLAH_BLAH/srcnn_x3.pth" \
+               --image-file "data/butterfly_GT.bmp" \
+               --scale 3
+```
+
+## Results
+
+We used the 9-5-5 network settings for the experiments, i.e. filter sizes f1 = 9, f2 = 5, f3 = 5 with n1 = 64 and n2 = 32 feature maps (see `models.py`).
+
+PSNR was calculated on the Y channel.
+
+### Set5
+
+| Eval. Mat | Scale | SRCNN | SRCNN (Ours) |
+|-----------|-------|-------|--------------|
+| PSNR | 2 | 36.66 | 36.65 |
+| PSNR | 3 | 32.75 | 33.29 |
+| PSNR | 4 | 30.49 | 30.25 |
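+
+For reference, here is a minimal sketch (not part of the repository's scripts) of how Y-channel PSNR can be computed for a pair of images with the helpers from `utils.py`. The file names are illustrative; `butterfly_GT_srcnn_x3.bmp` is the name `test.py` would give its output for the command above.
+
+```python
+import numpy as np
+import PIL.Image as pil_image
+import torch
+
+from utils import convert_rgb_to_y, calc_psnr
+
+# load the ground truth and the SRCNN output produced by test.py
+hr = np.array(pil_image.open('data/butterfly_GT.bmp').convert('RGB')).astype(np.float32)
+sr = np.array(pil_image.open('data/butterfly_GT_srcnn_x3.bmp').convert('RGB')).astype(np.float32)
+
+# test.py crops its input to a multiple of --scale, so align the shapes before comparing
+hr = hr[:sr.shape[0], :sr.shape[1]]
+
+# convert to the Y (luminance) channel and rescale to [0, 1], as calc_psnr expects
+hr_y = torch.from_numpy(convert_rgb_to_y(hr) / 255.)
+sr_y = torch.from_numpy(convert_rgb_to_y(sr) / 255.)
+
+print('PSNR: {:.2f}'.format(calc_psnr(hr_y, sr_y)))
+```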
+
+The original README closes with three side-by-side qualitative comparisons (Original vs. BICUBIC x3 vs. SRCNN x3), with SRCNN PSNRs of 27.53 dB, 29.30 dB and 28.58 dB; the embedded comparison images are not reproduced here.
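+
+## Preparing a custom dataset
+
+The Train section mentions `prepare.py` but does not show how to call it. Based on its argument parser (`--images-dir`, `--output-path`, `--patch-size`, `--stride`, `--scale`, `--eval`), an invocation along these lines should produce the training and evaluation HDF5 files; the directory names are placeholders.
+
+```bash
+# training set: 33x33 sub-images extracted with a stride of 14 (the script's defaults)
+python prepare.py --images-dir "BLAH_BLAH/91-image" \
+                  --output-path "BLAH_BLAH/91-image_x3.h5" \
+                  --scale 3
+
+# evaluation set: whole images, stored as one HDF5 dataset per image
+python prepare.py --images-dir "BLAH_BLAH/Set5" \
+                  --output-path "BLAH_BLAH/Set5_x3.h5" \
+                  --scale 3 \
+                  --eval
+```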
diff --git a/SRCNN-pytorch-master/__pycache__/models.cpython-38.pyc b/SRCNN-pytorch-master/__pycache__/models.cpython-38.pyc new file mode 100644 index 0000000..746dd92 Binary files /dev/null and b/SRCNN-pytorch-master/__pycache__/models.cpython-38.pyc differ diff --git a/SRCNN-pytorch-master/__pycache__/utils.cpython-38.pyc b/SRCNN-pytorch-master/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000..908f52c Binary files /dev/null and b/SRCNN-pytorch-master/__pycache__/utils.cpython-38.pyc differ diff --git a/SRCNN-pytorch-master/datasets.py b/SRCNN-pytorch-master/datasets.py new file mode 100644 index 0000000..1f1bea7 --- /dev/null +++ b/SRCNN-pytorch-master/datasets.py @@ -0,0 +1,31 @@ +import h5py +import numpy as np +from torch.utils.data import Dataset + + +class TrainDataset(Dataset): + def __init__(self, h5_file): + super(TrainDataset, self).__init__() + self.h5_file = h5_file + + def __getitem__(self, idx): + with h5py.File(self.h5_file, 'r') as f: + return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0) + + def __len__(self): + with h5py.File(self.h5_file, 'r') as f: + return len(f['lr']) + + +class EvalDataset(Dataset): + def __init__(self, h5_file): + super(EvalDataset, self).__init__() + self.h5_file = h5_file + + def __getitem__(self, idx): + with h5py.File(self.h5_file, 'r') as f: + return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0) + + def __len__(self): + with h5py.File(self.h5_file, 'r') as f: + return len(f['lr']) diff --git a/SRCNN-pytorch-master/information.txt b/SRCNN-pytorch-master/information.txt new file mode 100644 index 0000000..b793941 --- /dev/null +++ b/SRCNN-pytorch-master/information.txt @@ -0,0 +1,2 @@ +Programme écrit par : yjn870 +Code disponible sur github au lien : https://github.com/yjn870/SRCNN-pytorch \ No newline at end of file diff --git a/SRCNN-pytorch-master/models.py b/SRCNN-pytorch-master/models.py new file mode 100644 index 0000000..0465ddc --- /dev/null +++ b/SRCNN-pytorch-master/models.py @@ -0,0 +1,16 @@ +from torch import nn + + +class SRCNN(nn.Module): + def __init__(self, num_channels=1): + super(SRCNN, self).__init__() + self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2) + self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2) + self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.relu(self.conv1(x)) + x = self.relu(self.conv2(x)) + x = self.conv3(x) + return x diff --git a/SRCNN-pytorch-master/prepare.py b/SRCNN-pytorch-master/prepare.py new file mode 100644 index 0000000..c88dce7 --- /dev/null +++ b/SRCNN-pytorch-master/prepare.py @@ -0,0 +1,78 @@ +import argparse +import glob +import h5py +import numpy as np +import PIL.Image as pil_image +from utils import convert_rgb_to_y + + +def train(args): + h5_file = h5py.File(args.output_path, 'w') + + lr_patches = [] + hr_patches = [] + + for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))): + hr = pil_image.open(image_path).convert('RGB') + hr_width = (hr.width // args.scale) * args.scale + hr_height = (hr.height // args.scale) * args.scale + hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC) + lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC) + lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC) + hr = np.array(hr).astype(np.float32) + lr = 
np.array(lr).astype(np.float32) + hr = convert_rgb_to_y(hr) + lr = convert_rgb_to_y(lr) + + for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride): + for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride): + lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size]) + hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size]) + + lr_patches = np.array(lr_patches) + hr_patches = np.array(hr_patches) + + h5_file.create_dataset('lr', data=lr_patches) + h5_file.create_dataset('hr', data=hr_patches) + + h5_file.close() + + +def eval(args): + h5_file = h5py.File(args.output_path, 'w') + + lr_group = h5_file.create_group('lr') + hr_group = h5_file.create_group('hr') + + for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))): + hr = pil_image.open(image_path).convert('RGB') + hr_width = (hr.width // args.scale) * args.scale + hr_height = (hr.height // args.scale) * args.scale + hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC) + lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC) + lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC) + hr = np.array(hr).astype(np.float32) + lr = np.array(lr).astype(np.float32) + hr = convert_rgb_to_y(hr) + lr = convert_rgb_to_y(lr) + + lr_group.create_dataset(str(i), data=lr) + hr_group.create_dataset(str(i), data=hr) + + h5_file.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--images-dir', type=str, required=True) + parser.add_argument('--output-path', type=str, required=True) + parser.add_argument('--patch-size', type=int, default=33) + parser.add_argument('--stride', type=int, default=14) + parser.add_argument('--scale', type=int, default=2) + parser.add_argument('--eval', action='store_true') + args = parser.parse_args() + + if not args.eval: + train(args) + else: + eval(args) diff --git a/SRCNN-pytorch-master/test.py b/SRCNN-pytorch-master/test.py new file mode 100644 index 0000000..5541c43 --- /dev/null +++ b/SRCNN-pytorch-master/test.py @@ -0,0 +1,67 @@ +import argparse + +import torch +import torch.backends.cudnn as cudnn +import numpy as np +import PIL.Image as pil_image + +from models import SRCNN +from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--weights-file', type=str, required=True) + parser.add_argument('--image-file', type=str, required=True) + parser.add_argument('--scale', type=int, default=3) + args = parser.parse_args() + + cudnn.benchmark = True + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + + model = SRCNN().to(device) + + state_dict = model.state_dict() + for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items(): + if n in state_dict.keys(): + state_dict[n].copy_(p) + else: + raise KeyError(n) + + model.eval() + + image = pil_image.open(args.image_file).convert('RGB') + + image_width = (image.width // args.scale) * args.scale + image_height = (image.height // args.scale) * args.scale + image = image.resize((image_width, image_height), resample=pil_image.BICUBIC) + image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC) + image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC) + + + image = np.array(image).astype(np.float32) + ycbcr = 
convert_rgb_to_ycbcr(image) + + y = ycbcr[..., 0] + y /= 255. + y = torch.from_numpy(y).to(device) + y = y.unsqueeze(0).unsqueeze(0) + + with torch.no_grad(): + preds = model(y).clamp(0.0, 1.0) + + psnr = calc_psnr(y, preds) + print('PSNR: {:.2f}'.format(psnr)) + + preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0) + + output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0]) + output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8) + output = pil_image.fromarray(output) + + l = args.image_file.split(".") + l[-2] += '_srcnn_x{}'.format(args.scale) + args.image_file = ".".join(l) + + output.save(args.image_file) + diff --git a/SRCNN-pytorch-master/train.py b/SRCNN-pytorch-master/train.py new file mode 100644 index 0000000..31d5fd7 --- /dev/null +++ b/SRCNN-pytorch-master/train.py @@ -0,0 +1,112 @@ +import argparse +import os +import copy + +import torch +from torch import nn +import torch.optim as optim +import torch.backends.cudnn as cudnn +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from models import SRCNN +from datasets import TrainDataset, EvalDataset +from utils import AverageMeter, calc_psnr + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--train-file', type=str, required=True) + parser.add_argument('--eval-file', type=str, required=True) + parser.add_argument('--outputs-dir', type=str, required=True) + parser.add_argument('--scale', type=int, default=3) + parser.add_argument('--lr', type=float, default=1e-4) + parser.add_argument('--batch-size', type=int, default=16) + parser.add_argument('--num-epochs', type=int, default=400) + parser.add_argument('--num-workers', type=int, default=8) + parser.add_argument('--seed', type=int, default=123) + args = parser.parse_args() + + args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale)) + + if not os.path.exists(args.outputs_dir): + os.makedirs(args.outputs_dir) + + cudnn.benchmark = True + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + + torch.manual_seed(args.seed) + + model = SRCNN().to(device) + criterion = nn.MSELoss() + optimizer = optim.Adam([ + {'params': model.conv1.parameters()}, + {'params': model.conv2.parameters()}, + {'params': model.conv3.parameters(), 'lr': args.lr * 0.1} + ], lr=args.lr) + + train_dataset = TrainDataset(args.train_file) + train_dataloader = DataLoader(dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True) + eval_dataset = EvalDataset(args.eval_file) + eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1) + + best_weights = copy.deepcopy(model.state_dict()) + best_epoch = 0 + best_psnr = 0.0 + + for epoch in range(args.num_epochs): + model.train() + epoch_losses = AverageMeter() + + with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t: + t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1)) + + for data in train_dataloader: + inputs, labels = data + + inputs = inputs.to(device) + labels = labels.to(device) + + preds = model(inputs) + + loss = criterion(preds, labels) + + epoch_losses.update(loss.item(), len(inputs)) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg)) + t.update(len(inputs)) + + torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch))) + + model.eval() + epoch_psnr = AverageMeter() + + 
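+        # validation pass for this epoch: no-grad forward, clamp predictions to [0, 1],
+        # and average the PSNR over the whole eval set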
for data in eval_dataloader: + inputs, labels = data + + inputs = inputs.to(device) + labels = labels.to(device) + + with torch.no_grad(): + preds = model(inputs).clamp(0.0, 1.0) + + epoch_psnr.update(calc_psnr(preds, labels), len(inputs)) + + print('eval psnr: {:.2f}'.format(epoch_psnr.avg)) + + if epoch_psnr.avg > best_psnr: + best_epoch = epoch + best_psnr = epoch_psnr.avg + best_weights = copy.deepcopy(model.state_dict()) + + print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr)) + torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth')) diff --git a/SRCNN-pytorch-master/utils.py b/SRCNN-pytorch-master/utils.py new file mode 100644 index 0000000..3a4540a --- /dev/null +++ b/SRCNN-pytorch-master/utils.py @@ -0,0 +1,68 @@ +import torch +import numpy as np + + +def convert_rgb_to_y(img): + if type(img) == np.ndarray: + return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256. + elif type(img) == torch.Tensor: + if len(img.shape) == 4: + img = img.squeeze(0) + return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256. + else: + raise Exception('Unknown Type', type(img)) + + +def convert_rgb_to_ycbcr(img): + if type(img) == np.ndarray: + y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256. + cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256. + cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256. + return np.array([y, cb, cr]).transpose([1, 2, 0]) + elif type(img) == torch.Tensor: + if len(img.shape) == 4: + img = img.squeeze(0) + y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256. + cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256. + cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256. + return torch.cat([y, cb, cr], 0).permute(1, 2, 0) + else: + raise Exception('Unknown Type', type(img)) + + +def convert_ycbcr_to_rgb(img): + if type(img) == np.ndarray: + r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921 + g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576 + b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836 + return np.array([r, g, b]).transpose([1, 2, 0]) + elif type(img) == torch.Tensor: + if len(img.shape) == 4: + img = img.squeeze(0) + r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921 + g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576 + b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836 + return torch.cat([r, g, b], 0).permute(1, 2, 0) + else: + raise Exception('Unknown Type', type(img)) + + +def calc_psnr(img1, img2): + return 10. * torch.log10(1. 
/ torch.mean((img1 - img2) ** 2)) + + +class AverageMeter(object): + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/SRGAN-PyTorch-master/__pycache__/dataset.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000..e71b28e Binary files /dev/null and b/SRGAN-PyTorch-master/__pycache__/dataset.cpython-38.pyc differ diff --git a/SRGAN-PyTorch-master/__pycache__/losses.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/losses.cpython-38.pyc new file mode 100644 index 0000000..ec84dd5 Binary files /dev/null and b/SRGAN-PyTorch-master/__pycache__/losses.cpython-38.pyc differ diff --git a/SRGAN-PyTorch-master/__pycache__/mode.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/mode.cpython-38.pyc new file mode 100644 index 0000000..7246c1c Binary files /dev/null and b/SRGAN-PyTorch-master/__pycache__/mode.cpython-38.pyc differ diff --git a/SRGAN-PyTorch-master/__pycache__/ops.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/ops.cpython-38.pyc new file mode 100644 index 0000000..37efec4 Binary files /dev/null and b/SRGAN-PyTorch-master/__pycache__/ops.cpython-38.pyc differ diff --git a/SRGAN-PyTorch-master/__pycache__/srgan_model.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/srgan_model.cpython-38.pyc new file mode 100644 index 0000000..22a0be1 Binary files /dev/null and b/SRGAN-PyTorch-master/__pycache__/srgan_model.cpython-38.pyc differ diff --git a/SRGAN-PyTorch-master/__pycache__/vgg19.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/vgg19.cpython-38.pyc new file mode 100644 index 0000000..b7f718c Binary files /dev/null and b/SRGAN-PyTorch-master/__pycache__/vgg19.cpython-38.pyc differ diff --git a/SRGAN-PyTorch-master/dataset.py b/SRGAN-PyTorch-master/dataset.py new file mode 100644 index 0000000..9b23cdc --- /dev/null +++ b/SRGAN-PyTorch-master/dataset.py @@ -0,0 +1,135 @@ +import torch +from torch.utils.data import Dataset +import os +from PIL import Image +import numpy as np +import random + + +class mydata(Dataset): + def __init__(self, LR_path, GT_path, in_memory = True, transform = None): + + self.LR_path = LR_path + self.GT_path = GT_path + self.in_memory = in_memory + self.transform = transform + + self.LR_img = sorted(os.listdir(LR_path)) + self.GT_img = sorted(os.listdir(GT_path)) + + if in_memory: + self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr)).convert("RGB")).astype(np.uint8) for lr in self.LR_img] + self.GT_img = [np.array(Image.open(os.path.join(self.GT_path, gt)).convert("RGB")).astype(np.uint8) for gt in self.GT_img] + + def __len__(self): + + return len(self.LR_img) + + def __getitem__(self, i): + + img_item = {} + + if self.in_memory: + GT = self.GT_img[i].astype(np.float32) + LR = self.LR_img[i].astype(np.float32) + + else: + GT = np.array(Image.open(os.path.join(self.GT_path, self.GT_img[i])).convert("RGB")) + LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB")) + + img_item['GT'] = (GT / 127.5) - 1.0 + img_item['LR'] = (LR / 127.5) - 1.0 + + if self.transform is not None: + img_item = self.transform(img_item) + + img_item['GT'] = img_item['GT'].transpose(2, 0, 1).astype(np.float32) + img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32) + + return img_item + + +class testOnly_data(Dataset): + def __init__(self, LR_path, in_memory = True, transform 
= None): + + self.LR_path = LR_path + self.LR_img = sorted(os.listdir(LR_path)) + self.in_memory = in_memory + + if in_memory: + self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr))) for lr in self.LR_img] + + def __len__(self): + + return len(self.LR_img) + + def __getitem__(self, i): + + img_item = {} + + if self.in_memory: + LR = self.LR_img[i] + + else: + LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i]))) + + img_item['LR'] = (LR / 127.5) - 1.0 + img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32) + + return img_item + + +class crop(object): + def __init__(self, scale, patch_size): + + self.scale = scale + self.patch_size = patch_size + + def __call__(self, sample): + LR_img, GT_img = sample['LR'], sample['GT'] + ih, iw = LR_img.shape[:2] + + ix = random.randrange(0, iw - self.patch_size +1) + iy = random.randrange(0, ih - self.patch_size +1) + + tx = ix * self.scale + ty = iy * self.scale + + LR_patch = LR_img[iy : iy + self.patch_size, ix : ix + self.patch_size] + GT_patch = GT_img[ty : ty + (self.scale * self.patch_size), tx : tx + (self.scale * self.patch_size)] + + return {'LR' : LR_patch, 'GT' : GT_patch} + +class augmentation(object): + + def __call__(self, sample): + LR_img, GT_img = sample['LR'], sample['GT'] + + hor_flip = random.randrange(0,2) + ver_flip = random.randrange(0,2) + rot = random.randrange(0,2) + + if hor_flip: + temp_LR = np.fliplr(LR_img) + LR_img = temp_LR.copy() + temp_GT = np.fliplr(GT_img) + GT_img = temp_GT.copy() + + del temp_LR, temp_GT + + if ver_flip: + temp_LR = np.flipud(LR_img) + LR_img = temp_LR.copy() + temp_GT = np.flipud(GT_img) + GT_img = temp_GT.copy() + + del temp_LR, temp_GT + + if rot: + LR_img = LR_img.transpose(1, 0, 2) + GT_img = GT_img.transpose(1, 0, 2) + + + return {'LR' : LR_img, 'GT' : GT_img} + + diff --git a/SRGAN-PyTorch-master/information.txt b/SRGAN-PyTorch-master/information.txt new file mode 100644 index 0000000..e60ac7a --- /dev/null +++ b/SRGAN-PyTorch-master/information.txt @@ -0,0 +1,2 @@ +Programme écrit par dongheehand en Python +Disponible sur github au lien https://github.com/dongheehand/SRGAN-PyTorch \ No newline at end of file diff --git a/SRGAN-PyTorch-master/losses.py b/SRGAN-PyTorch-master/losses.py new file mode 100644 index 0000000..5b26ce4 --- /dev/null +++ b/SRGAN-PyTorch-master/losses.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +from torchvision import transforms + + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range = 1, + norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(norm_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(norm_mean) / std + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + for p in self.parameters(): + p.requires_grad = False + + +class perceptual_loss(nn.Module): + + def __init__(self, vgg): + super(perceptual_loss, self).__init__() + self.normalization_mean = [0.485, 0.456, 0.406] + self.normalization_std = [0.229, 0.224, 0.225] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.transform = MeanShift(norm_mean = self.normalization_mean, norm_std = self.normalization_std).to(self.device) + self.vgg = vgg + self.criterion = nn.MSELoss() + def forward(self, HR, SR, layer = 'relu5_4'): + ## HR and SR should be normalized [0,1] + hr = self.transform(HR) + sr = 
self.transform(SR) + + hr_feat = getattr(self.vgg(hr), layer) + sr_feat = getattr(self.vgg(sr), layer) + + return self.criterion(hr_feat, sr_feat), hr_feat, sr_feat + +class TVLoss(nn.Module): + def __init__(self, tv_loss_weight=1): + super(TVLoss, self).__init__() + self.tv_loss_weight = tv_loss_weight + + def forward(self, x): + batch_size = x.size()[0] + h_x = x.size()[2] + w_x = x.size()[3] + count_h = self.tensor_size(x[:, :, 1:, :]) + count_w = self.tensor_size(x[:, :, :, 1:]) + h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() + w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() + + return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size + + @staticmethod + def tensor_size(t): + return t.size()[1] * t.size()[2] * t.size()[3] + diff --git a/SRGAN-PyTorch-master/main.py b/SRGAN-PyTorch-master/main.py new file mode 100644 index 0000000..15670e8 --- /dev/null +++ b/SRGAN-PyTorch-master/main.py @@ -0,0 +1,39 @@ +from mode import * +import argparse + + +parser = argparse.ArgumentParser() + +def str2bool(v): + return v.lower() in ('true') + +parser.add_argument("--LR_path", type = str, default = '../dataSet/DIV2K/DIV2K_train_LR_bicubic/X4') +parser.add_argument("--GT_path", type = str, default = '../dataSet/DIV2K/DIV2K_train_HR/') +parser.add_argument("--res_num", type = int, default = 16) +parser.add_argument("--num_workers", type = int, default = 0) +parser.add_argument("--batch_size", type = int, default = 16) +parser.add_argument("--L2_coeff", type = float, default = 1.0) +parser.add_argument("--adv_coeff", type = float, default = 1e-3) +parser.add_argument("--tv_loss_coeff", type = float, default = 0.0) +parser.add_argument("--pre_train_epoch", type = int, default = 8000) +parser.add_argument("--fine_train_epoch", type = int, default = 4000) +parser.add_argument("--scale", type = int, default = 4) +parser.add_argument("--patch_size", type = int, default = 24) +parser.add_argument("--feat_layer", type = str, default = 'relu5_4') +parser.add_argument("--vgg_rescale_coeff", type = float, default = 0.006) +parser.add_argument("--fine_tuning", type = str2bool, default = False) +parser.add_argument("--in_memory", type = str2bool, default = True) +parser.add_argument("--generator_path", type = str) +parser.add_argument("--mode", type = str, default = 'train') + +args = parser.parse_args() + +if args.mode == 'train': + train(args) + +elif args.mode == 'test': + test(args) + +elif args.mode == 'test_only': + test_only(args) + diff --git a/SRGAN-PyTorch-master/mode.py b/SRGAN-PyTorch-master/mode.py new file mode 100644 index 0000000..ce15605 --- /dev/null +++ b/SRGAN-PyTorch-master/mode.py @@ -0,0 +1,208 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import transforms +from losses import TVLoss, perceptual_loss +from dataset import * +from srgan_model import Generator, Discriminator +from vgg19 import vgg19 +import numpy as np +from PIL import Image +from skimage.color import rgb2ycbcr +from skimage.measure import compare_psnr + + +def train(args): + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + transform = transforms.Compose([crop(args.scale, args.patch_size), augmentation()]) + dataset = mydata(GT_path = args.GT_path, LR_path = args.LR_path, in_memory = args.in_memory, transform = transform) + loader = DataLoader(dataset, batch_size = args.batch_size, shuffle = True, num_workers = args.num_workers) + + generator = 
Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num, scale=args.scale) + + + if args.fine_tuning: + generator.load_state_dict(torch.load(args.generator_path)) + print("pre-trained model is loaded") + print("path : %s"%(args.generator_path)) + + generator = generator.to(device) + generator.train() + + l2_loss = nn.MSELoss() + g_optim = optim.Adam(generator.parameters(), lr = 1e-4) + + pre_epoch = 0 + fine_epoch = 0 + + #### Train using L2_loss + while pre_epoch < args.pre_train_epoch: + for i, tr_data in enumerate(loader): + gt = tr_data['GT'].to(device) + lr = tr_data['LR'].to(device) + + output, _ = generator(lr) + loss = l2_loss(gt, output) + + g_optim.zero_grad() + loss.backward() + g_optim.step() + + pre_epoch += 1 + + if pre_epoch % 2 == 0: + print(pre_epoch) + print(loss.item()) + print('=========') + + if pre_epoch % 800 ==0: + torch.save(generator.state_dict(), './model/pre_trained_model_%03d.pt'%pre_epoch) + + + #### Train using perceptual & adversarial loss + vgg_net = vgg19().to(device) + vgg_net = vgg_net.eval() + + discriminator = Discriminator(patch_size = args.patch_size * args.scale) + discriminator = discriminator.to(device) + discriminator.train() + + d_optim = optim.Adam(discriminator.parameters(), lr = 1e-4) + scheduler = optim.lr_scheduler.StepLR(g_optim, step_size = 2000, gamma = 0.1) + + VGG_loss = perceptual_loss(vgg_net) + cross_ent = nn.BCELoss() + tv_loss = TVLoss() + real_label = torch.ones((args.batch_size, 1)).to(device) + fake_label = torch.zeros((args.batch_size, 1)).to(device) + + while fine_epoch < args.fine_train_epoch: + + scheduler.step() + + for i, tr_data in enumerate(loader): + gt = tr_data['GT'].to(device) + lr = tr_data['LR'].to(device) + + ## Training Discriminator + output, _ = generator(lr) + fake_prob = discriminator(output) + real_prob = discriminator(gt) + + d_loss_real = cross_ent(real_prob, real_label) + d_loss_fake = cross_ent(fake_prob, fake_label) + + d_loss = d_loss_real + d_loss_fake + + g_optim.zero_grad() + d_optim.zero_grad() + d_loss.backward() + d_optim.step() + + ## Training Generator + output, _ = generator(lr) + fake_prob = discriminator(output) + + _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer = args.feat_layer) + + L2_loss = l2_loss(output, gt) + percep_loss = args.vgg_rescale_coeff * _percep_loss + adversarial_loss = args.adv_coeff * cross_ent(fake_prob, real_label) + total_variance_loss = args.tv_loss_coeff * tv_loss(args.vgg_rescale_coeff * (hr_feat - sr_feat)**2) + + g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss + + g_optim.zero_grad() + d_optim.zero_grad() + g_loss.backward() + g_optim.step() + + + fine_epoch += 1 + + if fine_epoch % 2 == 0: + print(fine_epoch) + print(g_loss.item()) + print(d_loss.item()) + print('=========') + + if fine_epoch % 500 ==0: + torch.save(generator.state_dict(), './model/SRGAN_gene_%03d.pt'%fine_epoch) + torch.save(discriminator.state_dict(), './model/SRGAN_discrim_%03d.pt'%fine_epoch) + + +# In[ ]: + +def test(args): + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataset = mydata(GT_path = args.GT_path, LR_path = args.LR_path, in_memory = False, transform = None) + loader = DataLoader(dataset, batch_size = 1, shuffle = False, num_workers = args.num_workers) + + generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num) + generator.load_state_dict(torch.load(args.generator_path)) + generator = generator.to(device) + generator.eval() + 
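+    # evaluation loop: super-resolve each LR image, crop a `scale`-pixel border,
+    # compute PSNR on the Y channel, log it to result.txt and save the output PNG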
+ f = open('./result.txt', 'w') + psnr_list = [] + + with torch.no_grad(): + for i, te_data in enumerate(loader): + gt = te_data['GT'].to(device) + lr = te_data['LR'].to(device) + + bs, c, h, w = lr.size() + gt = gt[:, :, : h * args.scale, : w *args.scale] + + output, _ = generator(lr) + + output = output[0].cpu().numpy() + output = np.clip(output, -1.0, 1.0) + gt = gt[0].cpu().numpy() + + output = (output + 1.0) / 2.0 + gt = (gt + 1.0) / 2.0 + + output = output.transpose(1,2,0) + gt = gt.transpose(1,2,0) + + y_output = rgb2ycbcr(output)[args.scale:-args.scale, args.scale:-args.scale, :1] + y_gt = rgb2ycbcr(gt)[args.scale:-args.scale, args.scale:-args.scale, :1] + + psnr = compare_psnr(y_output / 255.0, y_gt / 255.0, data_range = 1.0) + psnr_list.append(psnr) + f.write('psnr : %04f \n' % psnr) + + result = Image.fromarray((output * 255.0).astype(np.uint8)) + result.save('./result/res_%04d.png'%i) + + f.write('avg psnr : %04f' % np.mean(psnr_list)) + + +def test_only(args): + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataset = testOnly_data(LR_path = args.LR_path, in_memory = False, transform = None) + loader = DataLoader(dataset, batch_size = 1, shuffle = False, num_workers = args.num_workers) + + generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num) + generator.load_state_dict(torch.load(args.generator_path)) + generator = generator.to(device) + generator.eval() + + with torch.no_grad(): + for i, te_data in enumerate(loader): + lr = te_data['LR'].to(device) + output, _ = generator(lr) + output = output[0].cpu().numpy() + output = (output + 1.0) / 2.0 + output = output.transpose(1,2,0) + result = Image.fromarray((output * 255.0).astype(np.uint8)) + result.save('./result/res_%04d.png'%i) + + + diff --git a/SRGAN-PyTorch-master/ops.py b/SRGAN-PyTorch-master/ops.py new file mode 100644 index 0000000..24ac7d9 --- /dev/null +++ b/SRGAN-PyTorch-master/ops.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class _conv(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias): + super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels, + kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True) + + self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02) + self.bias.data = torch.zeros((out_channels)) + + for p in self.parameters(): + p.requires_grad = True + + +class conv(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True): + super(conv, self).__init__() + m = [] + m.append(_conv(in_channels = in_channel, out_channels = out_channel, + kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True)) + + if BN: + m.append(nn.BatchNorm2d(num_features = out_channel)) + + if act is not None: + m.append(act) + + self.body = nn.Sequential(*m) + + def forward(self, x): + out = self.body(x) + return out + +class ResBlock(nn.Module): + def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True): + super(ResBlock, self).__init__() + m = [] + m.append(conv(channels, channels, kernel_size, BN = True, act = act)) + m.append(conv(channels, channels, kernel_size, BN = True, act = None)) + self.body = nn.Sequential(*m) + + def forward(self, x): + res = self.body(x) + res += x + return res + +class BasicBlock(nn.Module): + def __init__(self, 
in_channels, out_channels, kernel_size, num_res_block, act = nn.ReLU(inplace = True)): + super(BasicBlock, self).__init__() + m = [] + + self.conv = conv(in_channels, out_channels, kernel_size, BN = False, act = act) + for i in range(num_res_block): + m.append(ResBlock(out_channels, kernel_size, act)) + + m.append(conv(out_channels, out_channels, kernel_size, BN = True, act = None)) + + self.body = nn.Sequential(*m) + + def forward(self, x): + res = self.conv(x) + out = self.body(res) + out += res + + return out + +class Upsampler(nn.Module): + def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)): + super(Upsampler, self).__init__() + m = [] + m.append(conv(channel, channel * scale * scale, kernel_size)) + m.append(nn.PixelShuffle(scale)) + + if act is not None: + m.append(act) + + self.body = nn.Sequential(*m) + + def forward(self, x): + out = self.body(x) + return out + +class discrim_block(nn.Module): + def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)): + super(discrim_block, self).__init__() + m = [] + m.append(conv(in_feats, out_feats, kernel_size, BN = True, act = act)) + m.append(conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2)) + self.body = nn.Sequential(*m) + + def forward(self, x): + out = self.body(x) + return out + diff --git a/SRGAN-PyTorch-master/readme.md b/SRGAN-PyTorch-master/readme.md new file mode 100644 index 0000000..61ac2f0 --- /dev/null +++ b/SRGAN-PyTorch-master/readme.md @@ -0,0 +1,68 @@ +# Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network + +## Overview + +An unofficial implementation of SRGAN described in the paper using PyTorch. +* [ Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802) + +Published in CVPR 2017 + +## Requirement +- Python 3.6.5 +- PyTorch 1.1.0 +- Pillow 5.1.0 +- numpy 1.14.5 +- scikit-image 0.15.0 + +## Datasets +- [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) + +## Pre-trained model +- [SRResNet](https://drive.google.com/open?id=15F2zOrOg2hIjdI0WsrOwF1y8REOkmmm0) + + +- [SRGAN](https://drive.google.com/open?id=1-HmcV5X94u411HRa-KEMcGhAO1OXAjAc) + +## Train & Test +Train + +``` +python main.py --LR_path ./LR_imgs_dir --GT_path ./GT_imgs_dir +``` + +Test + +``` +python main.py --mode test --LR_path ./LR_imgs_dir --GT_path ./GT_imgs_dir --generator_path ./model/SRGAN.pt +``` + +Inference your own images + +``` +python main.py --mode test_only --LR_path ./LR_imgs_dir --generator_path ./model/SRGAN.pt +``` + +## Experimental Results +Experimental results on benchmarks. + +### Quantitative Results + +| Method| Set5| Set14| B100 | +|-------|-----|------|------| +|Bicubic|28.43|25.99|25.94| +|SRResNet(paper)|32.05|28.49|27.58| +|SRResNet(my model)|31.96|28.48|27.49| +|SRGAN(paper)|29.40|26.02|25.16| +|SRGAN(my model)|29.93|26.95|26.10| + +### Qualitative Results + +| Bicubic | SRResNet | SRGAN | +| --- | --- | --- | +| | | | +| | | | +| | | | + +## Comments +If you have any questions or comments on my codes, please email to me. 
[son1113@snu.ac.kr](mailto:son1113@snu.ac.kr) + diff --git a/SRGAN-PyTorch-master/srgan_model.py b/SRGAN-PyTorch-master/srgan_model.py new file mode 100644 index 0000000..644d1e1 --- /dev/null +++ b/SRGAN-PyTorch-master/srgan_model.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +from ops import * + + +class Generator(nn.Module): + + def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, num_block = 16, act = nn.PReLU(), scale=4): + super(Generator, self).__init__() + + self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = act) + + resblocks = [ResBlock(channels = n_feats, kernel_size = 3, act = act) for _ in range(num_block)] + self.body = nn.Sequential(*resblocks) + + self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = True, act = None) + + if(scale == 4): + upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = 2, act = act) for _ in range(2)] + else: + upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = scale, act = act)] + + self.tail = nn.Sequential(*upsample_blocks) + + self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = 3, BN = False, act = nn.Tanh()) + + def forward(self, x): + + x = self.conv01(x) + _skip_connection = x + + x = self.body(x) + x = self.conv02(x) + feat = x + _skip_connection + + x = self.tail(feat) + x = self.last_conv(x) + + return x, feat + +class Discriminator(nn.Module): + + def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, act = nn.LeakyReLU(inplace = True), num_of_block = 3, patch_size = 96): + super(Discriminator, self).__init__() + self.act = act + + self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act) + self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act, stride = 2) + + body = [discrim_block(in_feats = n_feats * (2 ** i), out_feats = n_feats * (2 ** (i + 1)), kernel_size = 3, act = self.act) for i in range(num_of_block)] + self.body = nn.Sequential(*body) + + self.linear_size = ((patch_size // (2 ** (num_of_block + 1))) ** 2) * (n_feats * (2 ** num_of_block)) + + tail = [] + + tail.append(nn.Linear(self.linear_size, 1024)) + tail.append(self.act) + tail.append(nn.Linear(1024, 1)) + tail.append(nn.Sigmoid()) + + self.tail = nn.Sequential(*tail) + + + def forward(self, x): + + x = self.conv01(x) + x = self.conv02(x) + x = self.body(x) + x = x.view(-1, self.linear_size) + x = self.tail(x) + + return x + diff --git a/SRGAN-PyTorch-master/vgg19.py b/SRGAN-PyTorch-master/vgg19.py new file mode 100644 index 0000000..b3f265f --- /dev/null +++ b/SRGAN-PyTorch-master/vgg19.py @@ -0,0 +1,80 @@ +from torchvision import models +from collections import namedtuple +import torch +import torch.nn as nn + + +class vgg19(nn.Module): + + def __init__(self, pre_trained = True, require_grad = False): + super(vgg19, self).__init__() + self.vgg_feature = models.vgg19(pretrained = pre_trained).features + self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature] + self.vgg_layer = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', + 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', + 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'] + + 
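+        # with require_grad=False (the default), all VGG-19 weights are frozen so the
+        # network serves only as a fixed feature extractor for the perceptual loss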
if not require_grad: + for parameter in self.parameters(): + parameter.requires_grad = False + + def forward(self, x): + + conv1_1 = self.seq_list[0](x) + relu1_1 = self.seq_list[1](conv1_1) + conv1_2 = self.seq_list[2](relu1_1) + relu1_2 = self.seq_list[3](conv1_2) + pool1 = self.seq_list[4](relu1_2) + + conv2_1 = self.seq_list[5](pool1) + relu2_1 = self.seq_list[6](conv2_1) + conv2_2 = self.seq_list[7](relu2_1) + relu2_2 = self.seq_list[8](conv2_2) + pool2 = self.seq_list[9](relu2_2) + + conv3_1 = self.seq_list[10](pool2) + relu3_1 = self.seq_list[11](conv3_1) + conv3_2 = self.seq_list[12](relu3_1) + relu3_2 = self.seq_list[13](conv3_2) + conv3_3 = self.seq_list[14](relu3_2) + relu3_3 = self.seq_list[15](conv3_3) + conv3_4 = self.seq_list[16](relu3_3) + relu3_4 = self.seq_list[17](conv3_4) + pool3 = self.seq_list[18](relu3_4) + + conv4_1 = self.seq_list[19](pool3) + relu4_1 = self.seq_list[20](conv4_1) + conv4_2 = self.seq_list[21](relu4_1) + relu4_2 = self.seq_list[22](conv4_2) + conv4_3 = self.seq_list[23](relu4_2) + relu4_3 = self.seq_list[24](conv4_3) + conv4_4 = self.seq_list[25](relu4_3) + relu4_4 = self.seq_list[26](conv4_4) + pool4 = self.seq_list[27](relu4_4) + + conv5_1 = self.seq_list[28](pool4) + relu5_1 = self.seq_list[29](conv5_1) + conv5_2 = self.seq_list[30](relu5_1) + relu5_2 = self.seq_list[31](conv5_2) + conv5_3 = self.seq_list[32](relu5_2) + relu5_3 = self.seq_list[33](conv5_3) + conv5_4 = self.seq_list[34](relu5_3) + relu5_4 = self.seq_list[35](conv5_4) + pool5 = self.seq_list[36](relu5_4) + + vgg_output = namedtuple("vgg_output", self.vgg_layer) + + vgg_list = [conv1_1, relu1_1, conv1_2, relu1_2, pool1, + conv2_1, relu2_1, conv2_2, relu2_2, pool2, + conv3_1, relu3_1, conv3_2, relu3_2, conv3_3, relu3_3, conv3_4, relu3_4, pool3, + conv4_1, relu4_1, conv4_2, relu4_2, conv4_3, relu4_3, conv4_4, relu4_4, pool4, + conv5_1, relu5_1, conv5_2, relu5_2, conv5_3, relu5_3, conv5_4, relu5_4, pool5] + + out = vgg_output(*vgg_list) + + + return out + + diff --git a/Super-résolution.pyw b/Super-résolution.pyw new file mode 100644 index 0000000..ff98704 --- /dev/null +++ b/Super-résolution.pyw @@ -0,0 +1,211 @@ +from tkinter import * +from tkinter.filedialog import askopenfilename +from tkinter.messagebox import showerror +import os, subprocess, shutil, glob, time +from PIL import Image, ImageTk +from threading import Thread + +class SRGUI(): + def __init__(self): + self.root = Tk() + + self.input_img_path = StringVar() + + self.canvas = {} + self.image = {} + self.imagetk = {} + self.id_canvas = {} + self.label = {} + + self.default_diff = 20 + self.diff = self.default_diff + + ##### ENTREE + self.frame_input = Frame(self.root) + self.frame_input.grid(row=1, column=1, rowspan=2) + + Label(self.frame_input, text="Entrée", font=("Purisa", 18)).grid(row=1, column=1, columnspan=2) + self.canvas["input"] = Canvas(self.frame_input, relief=SOLID, borderwidth=2, width=400, height=400) + self.canvas["input"].grid(row=2, column=1, columnspan=2) + Entry(self.frame_input, textvariable=self.input_img_path, width=60).grid(row=3, column=1, sticky="NEWS") + + Button(self.frame_input, text="...", relief=RIDGE, command=self.select_img).grid(row=3, column=2, sticky="NEWS") + Button(self.frame_input, text="Calculer", relief=RIDGE, command=self.calcul).grid(row=4, column=1, columnspan=2, sticky="NEWS") + + def zoom(event): + for type in ["input", "bil", "bic", "SRCNN", "SRGAN", "original"]: + if type in self.id_canvas: + x1 = ((event.x - self.diff) * self.image[type].width) // 400 + x2 = ((event.x + 
self.diff) * self.image[type].width) // 400 + y1 = ((event.y - self.diff) * self.image[type].height) // 400 + y2 = ((event.y + self.diff) * self.image[type].height) // 400 + + _img = self.image[type].crop((x1, y1, x2, y2)) + self.imagetk[type] = ImageTk.PhotoImage(_img.resize((400, 400))) + self.canvas[type].itemconfig(self.id_canvas[type], image = self.imagetk[type]) + + def change_diff(event): + if event.delta > 0: self.diff -= 5 + else: self.diff += 5 + zoom(event) + + def reset(event): + for type in ["input", "bil", "bic", "SRCNN", "SRGAN", "original"]: + if type in self.id_canvas: + self.imagetk[type] = ImageTk.PhotoImage(self.image[type].resize((400, 400))) + self.canvas[type].itemconfig(self.id_canvas[type], image = self.imagetk[type]) + + def activate_zoom(event): + self.canvas["input"].bind("", zoom) + def desactivate_zoom(event): + self.canvas["input"].unbind("") + + self.canvas["input"].bind("", activate_zoom) + self.canvas["input"].bind("", desactivate_zoom) + + self.canvas["input"].bind("", reset) + self.canvas["input"].bind("", change_diff) + + ##### BILINEAIRE + self.frame_bil = Frame(self.root) + self.frame_bil.grid(row=1, column=3) + + Label(self.frame_bil, text="Bilinéaire", font=("Purisa", 18)).grid(row=1, column=1) + self.canvas["bil"] = Canvas(self.frame_bil, relief=SOLID, borderwidth=2, width=400, height=400) + self.canvas["bil"].grid(row=2, column=1) + + self.label["bil"] = Label(self.frame_bil, fg="gray") + self.label["bil"].grid(row=3, column=1) + + ##### BICUBIQUE + self.frame_bic = Frame(self.root) + self.frame_bic.grid(row=2, column=3) + + Label(self.frame_bic, text="Bicubique", font=("Purisa", 18)).grid(row=1, column=1) + self.canvas["bic"] = Canvas(self.frame_bic, relief=SOLID, borderwidth=2, width=400, height=400) + self.canvas["bic"].grid(row=2, column=1) + + self.label["bic"] = Label(self.frame_bic, fg="gray") + self.label["bic"].grid(row=3, column=1) + + ##### SRCNN + self.frame_SRCNN = Frame(self.root) + self.frame_SRCNN.grid(row=1, column=4) + + Label(self.frame_SRCNN, text="SRCNN", font=("Purisa", 18)).grid(row=1, column=1) + self.canvas["SRCNN"] = Canvas(self.frame_SRCNN, relief=SOLID, borderwidth=2, width=400, height=400) + self.canvas["SRCNN"].grid(row=2, column=1) + + self.label["SRCNN"] = Label(self.frame_SRCNN, fg="gray") + self.label["SRCNN"].grid(row=3, column=1) + + ##### SRGAN + self.frame_SRGAN = Frame(self.root) + self.frame_SRGAN.grid(row=2, column=4) + + Label(self.frame_SRGAN, text="SRGAN", font=("Purisa", 18)).grid(row=1, column=1) + self.canvas["SRGAN"] = Canvas(self.frame_SRGAN, relief=SOLID, borderwidth=2, width=400, height=400) + self.canvas["SRGAN"].grid(row=2, column=1) + + self.label["SRGAN"] = Label(self.frame_SRGAN, fg="gray") + self.label["SRGAN"].grid(row=3, column=1) + + ##### ORIGINAL + self.frame_original = Frame(self.root) + + Label(self.frame_original, text="Original", font=("Purisa", 18)).grid(row=1, column=1) + self.canvas["original"] = Canvas(self.frame_original, relief=SOLID, borderwidth=2, width=400, height=400) + self.canvas["original"].grid(row=2, column=1) + + + def select_img(self): + self.input_img_path.set(askopenfilename(filetypes = [('Images', '*.png *.jpg *.jpeg *.bmp *.gif')])) + if os.path.exists(self.input_img_path.get()): + self.image["input"] = Image.open(self.input_img_path.get()) + self.imagetk["input"] = ImageTk.PhotoImage(self.image["input"].resize((400, 400))) + self.id_canvas["input"] = self.canvas["input"].create_image(204, 204, image = self.imagetk["input"]) + + _l = 
self.input_img_path.get().split(".") + _l[-2] += "_original" + original_img_path = ".".join(_l) + if os.path.exists(original_img_path): + self.image["original"] = Image.open(original_img_path) + self.imagetk["original"] = ImageTk.PhotoImage(self.image["original"].resize((400, 400))) + self.id_canvas["original"] = self.canvas["original"].create_image(204, 204, image=self.imagetk["original"]) + + self.frame_original.grid(row=1, column=5, rowspan=2) + else: self.frame_original.grid_forget() + + + else: showerror("Erreur", "Ce fichier n'existe pas.") + + def calcul(self): + def thread_func(): + type2func = {"bil": self.Bilinear, "bic": self.Bicubic, "SRCNN": self.SRCNN, "SRGAN": self.SRGAN} + for type in type2func: + try: + t1 = time.time() + type2func[type]() + t2 = time.time() + except: self.label["bil"].config(text="Erreur", fg="red") + else: self.label["bil"].config(text="Temps de calcul : %0.2fs" % (t2 - t1), fg="gray") + + thfunc = Thread(target=thread_func) + thfunc.setDaemon(True) + thfunc.start() + + def Bilinear(self): + self.image["bil"] = Image.open(self.input_img_path.get()).resize((400 * 4, 400 * 4), Image.BILINEAR) + self.imagetk["bil"] = ImageTk.PhotoImage(self.image["bil"].resize((400, 400))) + self.id_canvas["bil"] = self.canvas["bil"].create_image(204, 204, image=self.imagetk["bil"]) + + def Bicubic(self): + self.image["bic"] = Image.open(self.input_img_path.get()).resize((400 * 4, 400 * 4), Image.BICUBIC) + self.imagetk["bic"] = ImageTk.PhotoImage(self.image["bic"].resize((400, 400))) + self.id_canvas["bic"] = self.canvas["bic"].create_image(204, 204, image=self.imagetk["bic"]) + + def SRCNN(self): + SRCNN_img_path = "./SRCNN-pytorch-master/data/" + os.path.basename(self.input_img_path.get()) + self.image["input"].resize((self.image["input"].width * 4, self.image["input"].height * 4), Image.BICUBIC).save(SRCNN_img_path) + + subprocess.call( + '"D:/ProgramData/Anaconda3/python.exe" \ + "./SRCNN-pytorch-master/test.py" \ + --weights-file "./SRCNN-pytorch-master/weights/srcnn_x4.pth" \ + --image-file "%s" \ + --scale 4' % SRCNN_img_path) + + l = SRCNN_img_path.split(".") + l[-2] += "_srcnn_x4" + self.SRCNN_img_path = ".".join(l) + + self.image["SRCNN"] = Image.open(self.SRCNN_img_path) + self.imagetk["SRCNN"] = ImageTk.PhotoImage(self.image["SRCNN"].resize((400, 400))) + self.id_canvas["SRCNN"] = self.canvas["SRCNN"].create_image(204, 204, image=self.imagetk["SRCNN"]) + + def SRGAN(self): + SRGAN_img_path = "./SRGAN-PyTorch-master/LR_imgs_dir/" + os.path.basename(self.input_img_path.get()) + shutil.copy(self.input_img_path.get(), SRGAN_img_path) + + workingdir = os.getcwd() + os.chdir("./SRGAN-PyTorch-master/") + + subprocess.call( + '"D:/ProgramData/Anaconda3/python.exe" \ + main.py \ + --mode test_only \ + --LR_path LR_imgs_dir/ \ + --generator_path model/SRGAN.pt"') + + os.chdir(workingdir) + + os.remove(SRGAN_img_path) + + self.SRGAN_img_path = glob.glob("./SRGAN-PyTorch-master/result/res_*.png")[-1] + + self.image["SRGAN"] = Image.open(self.SRGAN_img_path) + self.imagetk["SRGAN"] = ImageTk.PhotoImage(self.image["SRGAN"].resize((400, 400))) + self.id_canvas["SRGAN"] = self.canvas["SRGAN"].create_image(204, 204, image=self.imagetk["SRGAN"]) + +App = SRGUI() +mainloop() \ No newline at end of file