From 50e4c082d82281b1fc40b7a097c0551f6644acb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=84=9D=CE=BB=E2=84=99=E2=84=8D=CE=BB=CE=A3=CA=9F?= Date: Thu, 12 Nov 2020 23:30:33 +0100 Subject: [PATCH] Initial commit --- SRCNN-pytorch-master/README.md | 129 +++++++++++ .../__pycache__/models.cpython-38.pyc | Bin 0 -> 956 bytes .../__pycache__/utils.cpython-38.pyc | Bin 0 -> 3389 bytes SRCNN-pytorch-master/datasets.py | 31 +++ SRCNN-pytorch-master/information.txt | 2 + SRCNN-pytorch-master/models.py | 16 ++ SRCNN-pytorch-master/prepare.py | 78 +++++++ SRCNN-pytorch-master/test.py | 67 ++++++ SRCNN-pytorch-master/train.py | 112 ++++++++++ SRCNN-pytorch-master/utils.py | 68 ++++++ .../__pycache__/dataset.cpython-38.pyc | Bin 0 -> 4284 bytes .../__pycache__/losses.cpython-38.pyc | Bin 0 -> 2947 bytes .../__pycache__/mode.cpython-38.pyc | Bin 0 -> 5095 bytes .../__pycache__/ops.cpython-38.pyc | Bin 0 -> 4029 bytes .../__pycache__/srgan_model.cpython-38.pyc | Bin 0 -> 2903 bytes .../__pycache__/vgg19.cpython-38.pyc | Bin 0 -> 2784 bytes SRGAN-PyTorch-master/dataset.py | 135 +++++++++++ SRGAN-PyTorch-master/information.txt | 2 + SRGAN-PyTorch-master/losses.py | 60 +++++ SRGAN-PyTorch-master/main.py | 39 ++++ SRGAN-PyTorch-master/mode.py | 208 +++++++++++++++++ SRGAN-PyTorch-master/ops.py | 97 ++++++++ SRGAN-PyTorch-master/readme.md | 68 ++++++ SRGAN-PyTorch-master/srgan_model.py | 74 ++++++ SRGAN-PyTorch-master/vgg19.py | 80 +++++++ Super-résolution.pyw | 211 ++++++++++++++++++ 26 files changed, 1477 insertions(+) create mode 100644 SRCNN-pytorch-master/README.md create mode 100644 SRCNN-pytorch-master/__pycache__/models.cpython-38.pyc create mode 100644 SRCNN-pytorch-master/__pycache__/utils.cpython-38.pyc create mode 100644 SRCNN-pytorch-master/datasets.py create mode 100644 SRCNN-pytorch-master/information.txt create mode 100644 SRCNN-pytorch-master/models.py create mode 100644 SRCNN-pytorch-master/prepare.py create mode 100644 SRCNN-pytorch-master/test.py create mode 100644 SRCNN-pytorch-master/train.py create mode 100644 SRCNN-pytorch-master/utils.py create mode 100644 SRGAN-PyTorch-master/__pycache__/dataset.cpython-38.pyc create mode 100644 SRGAN-PyTorch-master/__pycache__/losses.cpython-38.pyc create mode 100644 SRGAN-PyTorch-master/__pycache__/mode.cpython-38.pyc create mode 100644 SRGAN-PyTorch-master/__pycache__/ops.cpython-38.pyc create mode 100644 SRGAN-PyTorch-master/__pycache__/srgan_model.cpython-38.pyc create mode 100644 SRGAN-PyTorch-master/__pycache__/vgg19.cpython-38.pyc create mode 100644 SRGAN-PyTorch-master/dataset.py create mode 100644 SRGAN-PyTorch-master/information.txt create mode 100644 SRGAN-PyTorch-master/losses.py create mode 100644 SRGAN-PyTorch-master/main.py create mode 100644 SRGAN-PyTorch-master/mode.py create mode 100644 SRGAN-PyTorch-master/ops.py create mode 100644 SRGAN-PyTorch-master/readme.md create mode 100644 SRGAN-PyTorch-master/srgan_model.py create mode 100644 SRGAN-PyTorch-master/vgg19.py create mode 100644 Super-résolution.pyw diff --git a/SRCNN-pytorch-master/README.md b/SRCNN-pytorch-master/README.md new file mode 100644 index 0000000..7a7621d --- /dev/null +++ b/SRCNN-pytorch-master/README.md @@ -0,0 +1,129 @@ +# SRCNN + +This repository is implementation of the ["Image Super-Resolution Using Deep Convolutional Networks"](https://arxiv.org/abs/1501.00092). + +
+
+Qualitative comparison (image grid): Original vs. BICUBIC x3 vs. SRCNN x3, at 27.53 dB, 29.30 dB and 28.58 dB for the three example images.
+
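+## Prepare a custom dataset
+
+`prepare.py` builds the HDF5 files consumed by `train.py` and the evaluation loop. A minimal sketch of both calls, using the arguments defined in `prepare.py`; the directory paths are placeholders in the same style as above:
+
+```bash
+# Training set: 33x33 patches with stride 14 (the script's defaults)
+python prepare.py --images-dir "BLAH_BLAH/91-image" \
+                  --output-path "BLAH_BLAH/91-image_x3.h5" \
+                  --scale 3
+
+# Evaluation set: whole images, one HDF5 dataset per image (--eval)
+python prepare.py --images-dir "BLAH_BLAH/Set5" \
+                  --output-path "BLAH_BLAH/Set5_x3.h5" \
+                  --scale 3 \
+                  --eval
+```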
diff --git a/SRCNN-pytorch-master/__pycache__/models.cpython-38.pyc b/SRCNN-pytorch-master/__pycache__/models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..746dd92f558e00f073e7f34ca82f6fa495fd0483 GIT binary patch literal 956 zcmZWo&2AGh5FYPOHVH}jM+j62aX{KbD=1PA2%$=;#DzvEB?qDv)@tpx+wM%%9)!#|MUl{Pc@P;5{R%NuoD8n!K7C~AJvpA7q)A)KcluA*~rp|dJUNI9%t23?X>guovDTbV;rCah+txA;~ ztmUz;q>A!VV5Ifosk)2oRHC+D4_jf6>>RIHzxBU`{nsz@>5$(#!UV2pu?f4z`4ivT z%FdGP(itrbJHWc3ryir9x=-+c;2}cjs3*7?_Cq}*orUM53HD0$JyxQHg03R~yN9#M z#_f5Y3wy|v?P!f!QpCCB+%!4Iaz`1RTbv({;%siQwX0R)Z!aD^ky(aayp8@8t%_2x qjV5kQP-iQF)nFyYw@KYN!;nx?6VCZM&Vo|-f*(rW#;wG{=;wKPqw$I;DZXS==2-R@;} zFZRL(#8j|?Snz|Nvc1=3ukM;s*Tn4O>sQC8dmmTVB;?@+LaOQ-{-}R4llF}5ylUTbqFyz# zDL>}H+n4OBpR}3IxnB3bY}%xSQitc(QFVwtZ*&VIc{a(YdL4>DqYg+RhtsVz$ zZVwPbd+15%Jv2>EN_+rTHiQqJy%j!hf?wLh=OX2!wz0kM`JJ4r=9c9x2M>>#zxkMX5+ z(~sa2X{jsll=HJhL zX`SpD`Q-ge$5pHJx%ckia6z@Ee*5m;pDw%{e)GaUm8W12#}vNO&AV`%LDCJxj|YHH zd0e;MssorsEb(nfdXa2L(g(zs4c+sTRg+h00LdM&l#TIzm;35MA{_1AY{XpXj^ z*)5>?sDP#v)G4sESsg{CTtA1Xk{tqmlaIFAk3Le`y|gYn^wAd5H~45P>FA>^q_=$3 z4m(M}PIxJVT~7pd+j15Z&}&DRc~C+Fgbn=o?OziY>Z$;p_1T~QeShxf8P)pqo4GG8 z^}V7BR9YWapGyBAmsIQ8(YN2TZvY-Y`2C}=U)=qgYQ1#e!l9}0**4JeK?vIxRJczq zpH6QHlsX~=D;pHJU9hnQ6qTN)Y6~dgOyau{D40H)4acoIcMnV$^~e%jP=LY(sX@Wr z@(!|0hoJZ4r6QIK^U8f{IQZikar~?-vVk9m8ohVSk5x?F&Pw?h6#Kn`?u#BHW$0x? ztGYJd2bx`B7v=ZE*nqyFVYXADyT~<39xga=Sp+-bQYbUHF2Xwobq=qK6t9b5Y1+>m zg6gkZ=C}ztXhVX)#EJSug1`g^VNaF`g)pBWD>MhUlsGqC(_5i_%r(nrqP8oVR1AKm=PX}#YZX6k zI5pej4}pM)WjqUO2hk|}uY@MVGvBwtG;uT$od$Zg*>H^2M0n*f7(iVD;~KD7qFAH) z$pLVi2CJ|lH6$^EH~}uqhv3V=aQ+}{hA19^7vFhdd1W(@BA^ VJQ*1GG)AC{kteitDcj+R{SW&yy-)xE literal 0 HcmV?d00001 diff --git a/SRCNN-pytorch-master/datasets.py b/SRCNN-pytorch-master/datasets.py new file mode 100644 index 0000000..1f1bea7 --- /dev/null +++ b/SRCNN-pytorch-master/datasets.py @@ -0,0 +1,31 @@ +import h5py +import numpy as np +from torch.utils.data import Dataset + + +class TrainDataset(Dataset): + def __init__(self, h5_file): + super(TrainDataset, self).__init__() + self.h5_file = h5_file + + def __getitem__(self, idx): + with h5py.File(self.h5_file, 'r') as f: + return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0) + + def __len__(self): + with h5py.File(self.h5_file, 'r') as f: + return len(f['lr']) + + +class EvalDataset(Dataset): + def __init__(self, h5_file): + super(EvalDataset, self).__init__() + self.h5_file = h5_file + + def __getitem__(self, idx): + with h5py.File(self.h5_file, 'r') as f: + return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0) + + def __len__(self): + with h5py.File(self.h5_file, 'r') as f: + return len(f['lr']) diff --git a/SRCNN-pytorch-master/information.txt b/SRCNN-pytorch-master/information.txt new file mode 100644 index 0000000..b793941 --- /dev/null +++ b/SRCNN-pytorch-master/information.txt @@ -0,0 +1,2 @@ +Programme écrit par : yjn870 +Code disponible sur github au lien : https://github.com/yjn870/SRCNN-pytorch \ No newline at end of file diff --git a/SRCNN-pytorch-master/models.py b/SRCNN-pytorch-master/models.py new file mode 100644 index 0000000..0465ddc --- /dev/null +++ b/SRCNN-pytorch-master/models.py @@ -0,0 +1,16 @@ +from torch import nn + + +class SRCNN(nn.Module): + def __init__(self, num_channels=1): + super(SRCNN, self).__init__() + self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2) + self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2) + self.conv3 = nn.Conv2d(32, num_channels, 
kernel_size=5, padding=5 // 2) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.relu(self.conv1(x)) + x = self.relu(self.conv2(x)) + x = self.conv3(x) + return x diff --git a/SRCNN-pytorch-master/prepare.py b/SRCNN-pytorch-master/prepare.py new file mode 100644 index 0000000..c88dce7 --- /dev/null +++ b/SRCNN-pytorch-master/prepare.py @@ -0,0 +1,78 @@ +import argparse +import glob +import h5py +import numpy as np +import PIL.Image as pil_image +from utils import convert_rgb_to_y + + +def train(args): + h5_file = h5py.File(args.output_path, 'w') + + lr_patches = [] + hr_patches = [] + + for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))): + hr = pil_image.open(image_path).convert('RGB') + hr_width = (hr.width // args.scale) * args.scale + hr_height = (hr.height // args.scale) * args.scale + hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC) + lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC) + lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC) + hr = np.array(hr).astype(np.float32) + lr = np.array(lr).astype(np.float32) + hr = convert_rgb_to_y(hr) + lr = convert_rgb_to_y(lr) + + for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride): + for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride): + lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size]) + hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size]) + + lr_patches = np.array(lr_patches) + hr_patches = np.array(hr_patches) + + h5_file.create_dataset('lr', data=lr_patches) + h5_file.create_dataset('hr', data=hr_patches) + + h5_file.close() + + +def eval(args): + h5_file = h5py.File(args.output_path, 'w') + + lr_group = h5_file.create_group('lr') + hr_group = h5_file.create_group('hr') + + for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))): + hr = pil_image.open(image_path).convert('RGB') + hr_width = (hr.width // args.scale) * args.scale + hr_height = (hr.height // args.scale) * args.scale + hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC) + lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC) + lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC) + hr = np.array(hr).astype(np.float32) + lr = np.array(lr).astype(np.float32) + hr = convert_rgb_to_y(hr) + lr = convert_rgb_to_y(lr) + + lr_group.create_dataset(str(i), data=lr) + hr_group.create_dataset(str(i), data=hr) + + h5_file.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--images-dir', type=str, required=True) + parser.add_argument('--output-path', type=str, required=True) + parser.add_argument('--patch-size', type=int, default=33) + parser.add_argument('--stride', type=int, default=14) + parser.add_argument('--scale', type=int, default=2) + parser.add_argument('--eval', action='store_true') + args = parser.parse_args() + + if not args.eval: + train(args) + else: + eval(args) diff --git a/SRCNN-pytorch-master/test.py b/SRCNN-pytorch-master/test.py new file mode 100644 index 0000000..5541c43 --- /dev/null +++ b/SRCNN-pytorch-master/test.py @@ -0,0 +1,67 @@ +import argparse + +import torch +import torch.backends.cudnn as cudnn +import numpy as np +import PIL.Image as pil_image + +from models import SRCNN +from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr + + +if __name__ == '__main__': + parser = 
argparse.ArgumentParser() + parser.add_argument('--weights-file', type=str, required=True) + parser.add_argument('--image-file', type=str, required=True) + parser.add_argument('--scale', type=int, default=3) + args = parser.parse_args() + + cudnn.benchmark = True + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + + model = SRCNN().to(device) + + state_dict = model.state_dict() + for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items(): + if n in state_dict.keys(): + state_dict[n].copy_(p) + else: + raise KeyError(n) + + model.eval() + + image = pil_image.open(args.image_file).convert('RGB') + + image_width = (image.width // args.scale) * args.scale + image_height = (image.height // args.scale) * args.scale + image = image.resize((image_width, image_height), resample=pil_image.BICUBIC) + image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC) + image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC) + + + image = np.array(image).astype(np.float32) + ycbcr = convert_rgb_to_ycbcr(image) + + y = ycbcr[..., 0] + y /= 255. + y = torch.from_numpy(y).to(device) + y = y.unsqueeze(0).unsqueeze(0) + + with torch.no_grad(): + preds = model(y).clamp(0.0, 1.0) + + psnr = calc_psnr(y, preds) + print('PSNR: {:.2f}'.format(psnr)) + + preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0) + + output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0]) + output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8) + output = pil_image.fromarray(output) + + l = args.image_file.split(".") + l[-2] += '_srcnn_x{}'.format(args.scale) + args.image_file = ".".join(l) + + output.save(args.image_file) + diff --git a/SRCNN-pytorch-master/train.py b/SRCNN-pytorch-master/train.py new file mode 100644 index 0000000..31d5fd7 --- /dev/null +++ b/SRCNN-pytorch-master/train.py @@ -0,0 +1,112 @@ +import argparse +import os +import copy + +import torch +from torch import nn +import torch.optim as optim +import torch.backends.cudnn as cudnn +from torch.utils.data.dataloader import DataLoader +from tqdm import tqdm + +from models import SRCNN +from datasets import TrainDataset, EvalDataset +from utils import AverageMeter, calc_psnr + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--train-file', type=str, required=True) + parser.add_argument('--eval-file', type=str, required=True) + parser.add_argument('--outputs-dir', type=str, required=True) + parser.add_argument('--scale', type=int, default=3) + parser.add_argument('--lr', type=float, default=1e-4) + parser.add_argument('--batch-size', type=int, default=16) + parser.add_argument('--num-epochs', type=int, default=400) + parser.add_argument('--num-workers', type=int, default=8) + parser.add_argument('--seed', type=int, default=123) + args = parser.parse_args() + + args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale)) + + if not os.path.exists(args.outputs_dir): + os.makedirs(args.outputs_dir) + + cudnn.benchmark = True + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + + torch.manual_seed(args.seed) + + model = SRCNN().to(device) + criterion = nn.MSELoss() + optimizer = optim.Adam([ + {'params': model.conv1.parameters()}, + {'params': model.conv2.parameters()}, + {'params': model.conv3.parameters(), 'lr': args.lr * 0.1} + ], lr=args.lr) + + train_dataset = TrainDataset(args.train_file) + 
train_dataloader = DataLoader(dataset=train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + pin_memory=True, + drop_last=True) + eval_dataset = EvalDataset(args.eval_file) + eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1) + + best_weights = copy.deepcopy(model.state_dict()) + best_epoch = 0 + best_psnr = 0.0 + + for epoch in range(args.num_epochs): + model.train() + epoch_losses = AverageMeter() + + with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t: + t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1)) + + for data in train_dataloader: + inputs, labels = data + + inputs = inputs.to(device) + labels = labels.to(device) + + preds = model(inputs) + + loss = criterion(preds, labels) + + epoch_losses.update(loss.item(), len(inputs)) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg)) + t.update(len(inputs)) + + torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch))) + + model.eval() + epoch_psnr = AverageMeter() + + for data in eval_dataloader: + inputs, labels = data + + inputs = inputs.to(device) + labels = labels.to(device) + + with torch.no_grad(): + preds = model(inputs).clamp(0.0, 1.0) + + epoch_psnr.update(calc_psnr(preds, labels), len(inputs)) + + print('eval psnr: {:.2f}'.format(epoch_psnr.avg)) + + if epoch_psnr.avg > best_psnr: + best_epoch = epoch + best_psnr = epoch_psnr.avg + best_weights = copy.deepcopy(model.state_dict()) + + print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr)) + torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth')) diff --git a/SRCNN-pytorch-master/utils.py b/SRCNN-pytorch-master/utils.py new file mode 100644 index 0000000..3a4540a --- /dev/null +++ b/SRCNN-pytorch-master/utils.py @@ -0,0 +1,68 @@ +import torch +import numpy as np + + +def convert_rgb_to_y(img): + if type(img) == np.ndarray: + return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256. + elif type(img) == torch.Tensor: + if len(img.shape) == 4: + img = img.squeeze(0) + return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256. + else: + raise Exception('Unknown Type', type(img)) + + +def convert_rgb_to_ycbcr(img): + if type(img) == np.ndarray: + y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256. + cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256. + cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256. + return np.array([y, cb, cr]).transpose([1, 2, 0]) + elif type(img) == torch.Tensor: + if len(img.shape) == 4: + img = img.squeeze(0) + y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256. + cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256. + cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256. + return torch.cat([y, cb, cr], 0).permute(1, 2, 0) + else: + raise Exception('Unknown Type', type(img)) + + +def convert_ycbcr_to_rgb(img): + if type(img) == np.ndarray: + r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921 + g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576 + b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. 
- 276.836 + return np.array([r, g, b]).transpose([1, 2, 0]) + elif type(img) == torch.Tensor: + if len(img.shape) == 4: + img = img.squeeze(0) + r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921 + g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576 + b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836 + return torch.cat([r, g, b], 0).permute(1, 2, 0) + else: + raise Exception('Unknown Type', type(img)) + + +def calc_psnr(img1, img2): + return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2)) + + +class AverageMeter(object): + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count diff --git a/SRGAN-PyTorch-master/__pycache__/dataset.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e71b28ebc588e88a4e99a8c532abfa9e5afc333a GIT binary patch literal 4284 zcmcIn&2JmW6`z^?8XdJ*ZdcAD$qmWTMp^Hzc))#q+}cEp-ar;eC?Z=_kQzx zZyzr$RT-Y6S5)}=3S-_&y8jPq72G z8*uGyH&ru_6{^k$y=FUZ(a8K`p|ORM{{SLp4IZ(Eh$@j4*-tq3>O}4n*03Wl^6~FP z0{W}k?F}PnKG0tU zu`yR-j7@a02~v)YVCgZ`oMKDkielaBDgU$h*xqMnTw8lP?*}z4`UCAWrECr*9Z=iZ zAnt3C<=P%JRY%+1Ea_|S;a+H(UMuVWK9)+m&0Gx!v38D=zIyjeXt8lui&R3hoyY0X z8T(HI`nKOYe4NKJKiqq;d3W=0JI+5++2HVTKY^A!QD2@N{`$)^)ybkP&Du$P_(*1L z+3fY=!~MOTU+vy`G(5;;t8=G^^tqe`k;?=-d98Hg zMi?gjM1|qTy)MMfMwNt=pe8;1NuLB+R8~lk&t#269YT9M2gaUqnYyl! zz~y6skXLC~w3WO@ovS2nka!t_TC-yY6}dzMYY@7U^uu1<%j8g3lx+6%qfGX+n`ctR zky%`OboMBb+J!HZUR%4^1O5X?lKIY+$ycfQZ4zdAvk}^Z_NNx5)8w>MCR}cQVuh~= zo3HYdD;KQ4U}=IfLHGemPFfV;+r@n$NbUkfRvP!^cQG(`{+SW97ls%QLo!BAdYCv_ zo~>twTBe(Z@Q_2u{ARhYQAyZHU_mDD(y|3#V`HJLu6@c!7J8h89=rcYSz}juV_#Le z)o~3o>eKhk-2ySu-5Ob!<}FXPEzGkET=G{zlrMumBKsaI?76msHtSxA*hdxlw}OqY z6u7_GFJ@w&D@2)*e`-y4JkcN`=Q*RKkBF_ZD_xxZ-zbdzGu$dG?5kgDv2&os!@Z|u zGfBRAN+NivgupuW8f+4xwRe^IywSa{2_qo$+Vc3op%?U$Nl)nPOgbnCj<=3p+{I&sKl~=`0ep8%WodII@S|AQ~ z&s@T9zuAk!P*=mSmqo`Z)$3vS@o_UPN96Zt2j3@gm&EyAW*d@3G=!?0tm^aXvTvKd z7McqWUH>hV|1AnFsW?}^>8HbR5{ziQ=|goYq&j~Oq6|jLaQ`gy*d{nRg^-&I0mq(% z7Tbgt>jj~OykSDi`tw7}G?*A81^zeSy^>^Jg6_HNJ)^gKbS65SVevN>%}BXLgY=U^hJ z-KSLzZ;6xZ7gxK?YGk4rt98+svD)gCokmvwR+$Zq9p!exU>K=PM~RaR<5>|08wtZk z7O_QSflI2CtVJ%Erlv%gImA>V`umP+n0wkO(naUd-H zc8BH}qIru{OuV&0y+0)J3|26=^L-5F6hTWis9U$d1Ym%*i#A*!0{J7-`K*gg-v2Sm z^kP?Vv8nIU#U@{E(L~~-S$aun9Ybx4_!t4HZ~!KRvBPQ}Qb z^x6x(Ug4r`Ut=S`@FulNAyB85b%T8_FoWip`VrO`RST=A;D`69{~9y1{DJ^H*JI|^ zifU1d+`mwe127FtOc+_j_~dn!Q09 zmrSyb0WFda${$cZX^3Q~h57_#*zn5)xz)kPj!ErWoeQp^=mU4AVfb4pDzhg9u&ZK~ zuL=rY>-;S-^Hw=?7m@O9%$DzvAZV8Hi5xbk{+vxVF|dbX-0ySWk<`;dS$JmnB}Tf8 zU;6CO_2zNA7xz_DC0Ty~HpD%YSsB>;9hJ+lp|26T2Y`k>S9j*>Zehz0iK9%G14xYR zQ$Wu2SYRi(-1JzVdY}4*BmaaEL^O=d&vlB+b6xq@LX2H-812A928Aqc*1~$o!kz&d z!(j$FL}c=35P>IuPIZb*CcJ6)C`|^b#I4To31V8ukt6~`K5p{pU(j4bf9t3<{7xoA zda2>#i^tToB{S9VkkAamhkKLC&cU2R%~dAzCMY>UngJ+vjvpdYl_TlkJ~E)i`Z8a( zPOi_p4cC4OB>N(eo{OPvtk$M=2_?Re)DD)lqv$(Kub97$W0j=&1|=wxrj?%<)}s#| zmS2%h|F}08mKlMuAjK@hrP|Ft?8Yr+fN6BkoqGmk?cXc&v-{)~b6I?xzAV1A^5)9T HmFxcn>R3{% literal 0 HcmV?d00001 diff --git a/SRGAN-PyTorch-master/__pycache__/losses.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/losses.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec84dd5e247011602b434b4c1ab0b57c8d2074fc GIT binary patch literal 2947 zcmaJ@&2J>d6|buPoblLRoFy!~Ngz_PVF+tJ0w^JXLA!w*vd(JFMxtfZ>g{UJwCU-Y 
zt?C(Lt9@D%4iUny2Jy?OPj>Q&Wy zzxQ6>SzK&0w5R@|;y;Fr{hbD9n~T94=sE(CO!9(7%|9QANO0ms;SAi!1JZvyYECkB%bsk@gWHy2_Xq&2uV1Xv?SkP-PSSmv-K`F?Mfx3-WsZbPDysQ(QT;d zAAn>m;*v!|1|P94ZqI2&&PU8;kqeEXY2D0{a-*MbRmU9b51pg-U;jLMt^M`d(dUy- z|8)7;zpWj8a{Uj*(Z9Z0!{qlP`OB|=`{mlv?(x;Jwyhoi2QBsJCJFz+^*SlamH@2_uBiu^lLv}y7zm@ zV1!`(n69YEaIo#d`((?PA{~&vuNzljXgBWn7t5iE;1di8j8>CV85v3#6oz>o}Ptd6C>LGSeO< zDj8%|rnFg9*}ZYDG9C9+BD;ZcbyjReE!Ddl!}q<+v`V-?9>9Gjw5+ZqR^`3YIK=YN zAvdi!PK!kAI6h?md>2RFcxCgB#sQnRufP1<%bPbc{eCqZZQd#K38Z=T+2Q8R&kn2p zP!5Y>FV8k_sUdDK7-X9px8Htq{o1YFcj&p-28qT+H;bXxnO+_3s_#HMxmz!RF!#H> z&7bC%`DM}Jdn>0XTRkyb*R=*$gc#K#vIx@k$gm}>0K>N9crcXXg7{7x-y0`IGh@#( zi}SaoSy9BX`U&Z<#V@8OnGqTo8H2{x(e+y(NDr(=^4WJS zhqQu(1e;Dp?ZU#&eX;M)$COZL9AMU|{W_QlqFc@eJLKyC_OsAuf?if7Ri&Dj{0yRQ zXrC5|U*C?z#%<#j$!?~M=%Z_8g1(BkvZOMBZkRKWrstag34bS1lI&v<%|x5G#RYyr z>^*k=`SZ11u3n_0ebvREdYZ^jiI6STt3;k5^04S3|38YZ$xxx_gccDg1`o6ov+5Th zK8YI{+Q>nUKS05;af6bR0pc<^VE1|C+ZYO{tRZ#&j!<9|!vDk%0f@T31oB|~Sj;a|886x^|wlC(q0PoCs2RtE**Qm#meJJMC{f6IiA7_NRhpcLL*tPRjekFjt|EufV zEUV#$OQ$E3o=f;HnTe@8WBb8F^qo8PyuRkbA9hE!BBn)2&Y1w-9ksBAH{McAel6un^!0dR?#)3EoQSyhhN3-7QZF-mgYE^ zM^xGr5n740=mbs+gCBu`^X4#^a{vJz-ADm~i}ZMRy{Ua`*kB)zD18GwzCL(${&H_( z;uq-J78bF$GQZ9l3t>A2A@a+&1-a>HB%?gV*9A(uBAwqJ_Y^;*NK=vw`H~wh2(GNO z-8dt(pFFE&sh-CYc#A#>Cb>rbR&-6{Umuo}7i5$D{hPLmH}%8@=Qr8rG`XAb)u!0L QPBv`{HhIEXSzfvLUr|1x8~^|S literal 0 HcmV?d00001 diff --git a/SRGAN-PyTorch-master/__pycache__/mode.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/mode.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7246c1c53c7695b198d6667fbe4e560816250f2d GIT binary patch literal 5095 zcmb7I&2QYs6(_meCAr^UT7BAM(qsgcxiT-ihBGAWjP-xvI5_{c{6a3wW?N12*IeP= z=)270)Dzt8MQeQ?Mr|$0-TB}~KZ=sdz~dcnAU55=2_V2k^(@A%>t4^}uITfmbj^=C z+zb=45A+IsASFfcN(pM=PI7) zUl~Oa%19N;J#CcPQ+S=JJ&kF5>QEK>C@XTYa?ikH?rOr?Rk}2yi993)JViXESYw$7 z`YXRw7UBXc&njb`Mf58By;7WwOS>xOm4}r$7Z>C5t~RX3#^@-m8&^li zWNVC$Q+xmKvB|10DMJ%&t);NihK3b;Vmda{ec+7~dvX`pTV3wPWQ}wt{}Q`D3HwjM z{$^~pls$y@{V&Do-HCOzxHdY2oz_Gf`=LD1!nKHdRwsaMDw^#k5$z&~g2 zJrf(O#irTJgUqlgo{ej}>Tn7kJO{lkv9zlUr(wr==tcmq;e5w>XCS@s4e4xbLG~Ox zo8kYAo2efr9)75Pbmn{dxDF4i!@0N~7tzkgt$043fzKDf&Be3e4#f-cp=o7!c>fd{ zy zq^S&#gL^^R0Bg`cf&Po}v>acIkIpEp)=-)_S&W^GkH;sbmG~r{W3cqpH!MB%blg1D z?!OybGIA`mlgP>N&k;G}8xWs*5;^hVC;T{yEVeL?H;peHh@(TXmBx{j)n#lQp2)E7 z5>t?#j@(T2&L(q{OlPq9NKcIGcbtJM))V8#a>`hK&k4P-&$kn6KWE)}lu7cduIQ{g zk-y{FMzp@Uwib9vp|=@2ANTo&$D_wcDazw3Anj~o`eD~u^IVY_J=sP{0h4-OAT=Zw zy|^0mI~z$h>bQZoGe6+oIjrdSJhmA2nHMbj(PBUelG9>cHJ;^|AuAo`|9$TQLeOAI~%y$9{2wZLw4pS zOWt0-{^m`mimA-7@I#RC<1|ooH@l(N6Rz<4y(AZIGgKp9rSTe(IuY^|pCZyCLd795 zxEDD%9J{cFlW@dl&+m1U3Y~)!L7nF?zatWz1WB1TCF@igJlO&Fdm_<99~+>fOtfCl zHr~7awk&tq{y_Mlt-s0K(1x?Q8+rnfV^{W1#qkDx?A!8S0%`VaYsce#r^{Vto2zbT z<71bzM5oh8bYFNOcz4UQ^G}T97HvfLw%mZ5VC`bSov5?!u}%2R&fUhlH`Yni}&1iFKOKobcOirzM4cYU?V!QO6rJBL{6Qt%!MbdJ`HTg-&{t#_>lo?G(wNeL<^^<;BcI zH;tForYdcH3Kr7+rwelxPjmq(0R`kcU?GYiLAV^t04M8!d73Z)0<(dZ$64+H)o)u7tvt5FB%BZ#n2>0n{fPD2hLm&KX8{F*+i z06OY%UP8bMfL}#I#i>1I)B;?ph-rYM3L*Wy3;?3_)To7#DxgMf9y6)Dk9o*Pv+z!x zmD9Hl^eizKS9VphfSv*@I>^J`>?hh<26ojj6VTw`=t2I#%(|G5YpemlQzy!*tO@AX zh;yTbxIr)*kn<3miUC8ZCk~H}!16{@87+$Gdq?SPhD|`bW2_Y$>74t&mhVhs|HtFp zo|;NOR9;kmtSl=(fo}mSW^T&*#LI}FAMy$)KT4!YWSYnf$RnN45|p2z~oPI2ix zvgl?Ymc&nyv_9rXXlRiLRr4+${J#1FKY@myBytQSExLc|Q_;;+0(9_w)%A8t^w~o| z;lc}Ri~0XfO|;V@D6cP|()YSM6?dz97> zK11YLBIk&lC-M?VqW1@058x*)UHo||Bk+_KY4TN)MUfk9qec!;zXN|TMqE_Td4V() zX;#i1Kov`}A9K_@PG8W8wZ{9Q%elLa;Ye%`yd=Ac()%*Mgkf6`J-64M;xCgDxv5u( zdzHu!KoS+ziuUgWekAy7Br>S96G(WSW>G1zwbdw5JBhl^DLL>rh+L)tbK6O)jID1w 
zU195_E6Jq?Df4Yg-o)fY7hWWO1uh~5NRLQEbxLkB$>miYX-_6Lt*SnYG)L*N3Rw{e ztE`<>&uT5qEQ3}x{sA;RmA=SN`{@gxQU51>jZ;?!sVldiy6~MUQy0?KICWY3sVk3U zB)@p~Q&(v}b(Ll6QYeLe`g2^UYMQD(8X!?+hq*Y%@<>hTd;{slkm<#wuT)mR2$I~W zj`Tubkwqkp5)ugbQ3FY%%nC>vrfA|j*<=-5(iSf1jJ%|j=4M$n&Wi6 zJ97Z>5%&1D%gNuOfkh&8>%Ml`_%)Kf4bq0!^Nrd@*ZVSIo3(y^r zNrb;c+>eM*wDX%p-Y4=dky}J4&fA6WjBpv#e3@oaC_jl^ewVse_K6Hpi|lv$*ud{G z5G{gGoK~NTQ;JI}-bJ)6l@p+7FGzRTrbO#Yy&iuKQoE3LC3fQ%Nm5NEo5ByGC3(xp zJ>!X6egu%70(!u~DDol+MX7;&cgYCD2WsMUJsEx8AuSJ!Y4p z4WIrywA2BT;WC-nm5B_=A=>1g%q|iS$Xg{hbigrJppi@On6F`7K<76oG$L{*p_TNn Q3iLrP8A~$e&&>1x0~-7Qxc~qF literal 0 HcmV?d00001 diff --git a/SRGAN-PyTorch-master/__pycache__/ops.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/ops.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37efec45ac7da4bb1d873544135a40a1f278fef7 GIT binary patch literal 4029 zcmcJSTW=dh6vt;~uf9da4OIwt1gHiYN>o!)&qP;R;mw6Btf7?UJ_nq&?PV)U>T0V}z2P-F7IsPB`P%Bf_ni?Ya-Gwg(=Y zJK%A^Nyu}Ry89xUYqQ9r#Cc{9KFp^X*0ksM4$GMbZz zY?tLUXStQzxy9+SUhZR;lFZt-p*rx6!dt6jA5d-|u$=6=foI+vKo7fq;N@h$bU?0= zZ^{om&T=m(-)?aNQ*oM47>v~#yQ-3;L9`R5Y21~n+8<_zPnEq`fIE=MC|0h_M8ad` z4MNV7bX(cCl2C4cxA^n*Uq?S)Zd8;bhl5xM>{$6hkfcc#1j@SXLd}YhH+Yt_YPq=6>gaYFx zN@#QkbTFo>6a-N>lrjj$K1oaS8j=wO*kMtV%}7jlH3++`Ww!! zn=@`drrPUpvR4vcKuIWAP#ylrkMiaUNSQqnQUG2D}bc_wLh$1-%6kc#f1Qh8Ms+!7#RV zZIsBzAe0>=m=?DINrxkvr0r@XJwVb^{(3C0cKgv@3y!##q=Rl4#SzNEq~KXVJTY)` z9n3K_o!lwV1YPBzoiMhB?WmuF(i%TN=;n@w(9;lF&DNNSv(UL{fNA&!yARNT`z|oj zhQc?n^(GqJCDTXRUlb1RWFL(&pmqNC>~v3Lf0u*+F`a(amv zip0C{O1uZAVJH$S&{h^4r^qh=4Vw(nYq!MyC%$H=jH_=6L*8v4hvZ%#cU+;n0DC!G$Gfz#x2X5Y1qNnG>VWiES2u zD2j_{-bZsNJyRYpgJ0^@1Uv$fY*e3_UcsceboAHyfaxwUuDJUJZA(>wp8uS-(&j*h zy+Jn?Ed$%riiPlD$mF+Rj=?soSk@xRu`&ad)+}cq66}FVp|(z;P4wsXG=&=V(kavg zyU3KE!f^8>OI2?s_u}rx&al&gSsxQHuMjYk5FHtX@GUK4Cn?xpbYOJWCcpOSQbnLAf_kJ()7rn( zq8d+RB$A#vmnRT{;*W8VEiivW2ucgh?Md!J4nUTS>Yyg z{LwKd7p3{>t#l&}_wGMyj>8y!IT;MN@&tMs3j+QV!(S%-M`8Eg8Ezh*+nsQg+5_zGx=DJ_#e-uVdb#_2a)9C7PX3og;x~lw7&Hs?k@UkyYf$>4_ Q>sGy5KV7fY7wf0~0{w#tegFUf literal 0 HcmV?d00001 diff --git a/SRGAN-PyTorch-master/__pycache__/srgan_model.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/srgan_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22a0be127c0f780a2c1e53d60dbbf024f2e90569 GIT binary patch literal 2903 zcmaKu&2HO95P)}=KO$wxNo>bS(>6_$Hi+EDa8lHVzzyhvv@|$c&hF06%=IO^QKg1$5LgLVXI4qz6zFMgD^b=U1WxgdXs*Kx?5mW-} zkT{^ziBo|EY-v9YlaPB^$_E^?*RowUq&e z@}zIbYLvMCrk5mPEGN>PZ2Ykua!|S=+6yInv?4Xn&wL8gh7W@VW@8#&aRo#hqWOMm zpJoI+Bj6W3)_!5HbXbvbOlCuqR=JqqHBj(L)KbS<4aK84^|$*>viSu`-9A%&vXZ#1 z(91-h{PhEjUB1`c5Fr=M)yInqi_PUwY-ee=xsgP#K^tZ7`pxI>`q^e0q;cAg!sZK} zwz=2ogw3_prw>)_ z7}>iyIwbS18zoWZy7NomuosK_&=^iqQLb~$%6qA@K&evo2{cryW`?N>C?ty=>t(Vu zWoCf+?d!5qW=tha8Uoe~%e+Z@qbs~l7Ze;%EJyqNB8piQmq4f`hjH-BsQw5~!3hxa}$){xK*Wn}w|IOx<#xi%&RaDvn$Xvp_6QeU8XgV7QMozSdHyXy}x}vrg!*YelQP@;T2Cn z#cKF$gJVAV0xIfdy8SSYA?3;w zCBDWXD*2c;1hKj+7P~sF4wMg?GtY?C_dtD?)uFbfV>viRbQ&eE0fvXjYb&tN@VYAN z%aQPT)QJ*gwz?|Z0HfZL)9WRy*jKq5!ff}gi?Y!^kUe?{=g2baW7~lPYCvBxFUcb?4479 zudb-PJmH9Wb?F->$J)&&P$={w;0nf?BI8jI{x`Ai!cax5IYlfe<2l6$AQor^6r-B~ zf`PwKc}E=6)ye~`8Ooi8?pRQ66we(Si;ZfYsszlb2gUB5;sO+p3t&Rh2#SKD2VR?h zVGdm!<9hDQ8To%?6VN8ua(9V#&%*jlxh^QaipMa6~y zdI+z;(~6+Oze$AL#k0|jK*#UEK#nJYe}h6X&3+bo+q-AE#=phx@p@yL;ohT1CC#qS zs?*S@NCQ~oZ%Rx3%qK}=}-3i95^>mU$C7VM6tP2Knh Dl}CUS literal 0 HcmV?d00001 diff --git a/SRGAN-PyTorch-master/__pycache__/vgg19.cpython-38.pyc b/SRGAN-PyTorch-master/__pycache__/vgg19.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..b7f718cb7a855ceec80cb644534b24be2a17a728 GIT binary patch literal 2784 zcmbuB$!^%| z5WK(HfKRYTpFyOMMkwk=$uKB3L$hv`ERL}=e3WK969esbJ^V1fsLjqAUq$rs(^-`P|(vtRd zx@1X5y4Rv)zZAA8IXH)xI2Wg<+jkz{t6KcyuFz0x*eMtN42eZ-XmMq}GG7Zr#MYUi z4QYI5oY`9Vjx?opEwngu4#gGfe3vdrx-um_nU)!ul{uN01vwxG<&YefBXU%Z$#FR$ zC*_o!mNRlz&dGVXAQ$D5T$U?xRj$c(xgj@C*Dzj6o6%2Nyz*X)L-90oY009t{Pg*^ zQs5$P9MP?5sUTLBS|gBE1CINfiNL{rj6E8MoWkhS!le<5SGb7R#t-JXeJtLHqM4YD zMv^%W&f7sFu2n+HY~4z%APkg?U$Xr61){@`%TJ;}Mde4I-@bLbd>BMu#m!dvNuzdw zY%PBKru^XToA_B%HpAv=Ehs-!%~MsW*Msu$qx+v8-F*1`aZ^>F-Kw+_N)wGu{m5%hnMq^ixv#kRklq*M^L zIp}6~`k9@?X*HW6D(JvYzm$V+rdPV-_e#NZV7Fh&K{wMY-9>4yX19BuT^~WWX0M;w z>z=uHV6Q(M2UNO8qmXS5x|!Z^d;Q_|{N8XB@P&Y>biY3w2i;6>xc$DneHw*Mz(KdB zH{5=Ixcz&@OyWfCR-hE!`NZqrZ$;B6@lJwD+*UzU%HnbOJxaVNIQQu>Bxy>AmFI!N ze5~Ms2Y9+sl2@2hHT=%UB#=crv7;b7DP=Kb{ob5Sa=jWZNz#A*`i+sKeZLx3qR98V z(=ZD!Y8R5m9zBEzqhJgg!$!t1FpTF zFpYplK%-2fpi$5m(->$BG|n^*8V5}~2F)YhoXbv>bG!L2wEif&B7C?(ki=ai&64MfB3AD_#3|a=QFs*=AK&wovpj8kW zGoI!eXbrR;8*&{P^UA#3(5AMutsU*^l=gI5XLMHQbY2(qfF9ICdRUL>Q9Y)|^@N_( zQ+is@=vh6d=kCnR6q1UEEUA2SNTXYH}10;haLnOl_BP63FV8Qr|6{itc}sI%%q}D@u~% z=tqb!3us>a_fOH^kNry~Xa4^NXuF>EBHwL)J+DRc@x7u;ivfLp`F@h|eSDI&LyB|0 zf8MTy-6PzURX`V-?=|ge5QhE6%^wj5ILh~eml&F;yvloKp}?79I?2ZLad1(KYRyJ> z<;hl?VHi|nBqE+OiNiIC$N#n(4Mig*?gRG41HE_a`NVwxy@RgvyY&7<^oHcn?>!@D F{Q=EKe@FlT literal 0 HcmV?d00001 diff --git a/SRGAN-PyTorch-master/dataset.py b/SRGAN-PyTorch-master/dataset.py new file mode 100644 index 0000000..9b23cdc --- /dev/null +++ b/SRGAN-PyTorch-master/dataset.py @@ -0,0 +1,135 @@ +import torch +from torch.utils.data import Dataset +import os +from PIL import Image +import numpy as np +import random + + +class mydata(Dataset): + def __init__(self, LR_path, GT_path, in_memory = True, transform = None): + + self.LR_path = LR_path + self.GT_path = GT_path + self.in_memory = in_memory + self.transform = transform + + self.LR_img = sorted(os.listdir(LR_path)) + self.GT_img = sorted(os.listdir(GT_path)) + + if in_memory: + self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr)).convert("RGB")).astype(np.uint8) for lr in self.LR_img] + self.GT_img = [np.array(Image.open(os.path.join(self.GT_path, gt)).convert("RGB")).astype(np.uint8) for gt in self.GT_img] + + def __len__(self): + + return len(self.LR_img) + + def __getitem__(self, i): + + img_item = {} + + if self.in_memory: + GT = self.GT_img[i].astype(np.float32) + LR = self.LR_img[i].astype(np.float32) + + else: + GT = np.array(Image.open(os.path.join(self.GT_path, self.GT_img[i])).convert("RGB")) + LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB")) + + img_item['GT'] = (GT / 127.5) - 1.0 + img_item['LR'] = (LR / 127.5) - 1.0 + + if self.transform is not None: + img_item = self.transform(img_item) + + img_item['GT'] = img_item['GT'].transpose(2, 0, 1).astype(np.float32) + img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32) + + return img_item + + +class testOnly_data(Dataset): + def __init__(self, LR_path, in_memory = True, transform = None): + + self.LR_path = LR_path + self.LR_img = sorted(os.listdir(LR_path)) + self.in_memory = in_memory + + if in_memory: + self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr))) for lr in self.LR_img] + + def __len__(self): + + return len(self.LR_img) + + def __getitem__(self, i): + + img_item = {} + + if self.in_memory: + LR = self.LR_img[i] + + else: + LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i]))) + + img_item['LR'] = (LR / 127.5) - 1.0 + 
img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32) + + return img_item + + +class crop(object): + def __init__(self, scale, patch_size): + + self.scale = scale + self.patch_size = patch_size + + def __call__(self, sample): + LR_img, GT_img = sample['LR'], sample['GT'] + ih, iw = LR_img.shape[:2] + + ix = random.randrange(0, iw - self.patch_size +1) + iy = random.randrange(0, ih - self.patch_size +1) + + tx = ix * self.scale + ty = iy * self.scale + + LR_patch = LR_img[iy : iy + self.patch_size, ix : ix + self.patch_size] + GT_patch = GT_img[ty : ty + (self.scale * self.patch_size), tx : tx + (self.scale * self.patch_size)] + + return {'LR' : LR_patch, 'GT' : GT_patch} + +class augmentation(object): + + def __call__(self, sample): + LR_img, GT_img = sample['LR'], sample['GT'] + + hor_flip = random.randrange(0,2) + ver_flip = random.randrange(0,2) + rot = random.randrange(0,2) + + if hor_flip: + temp_LR = np.fliplr(LR_img) + LR_img = temp_LR.copy() + temp_GT = np.fliplr(GT_img) + GT_img = temp_GT.copy() + + del temp_LR, temp_GT + + if ver_flip: + temp_LR = np.flipud(LR_img) + LR_img = temp_LR.copy() + temp_GT = np.flipud(GT_img) + GT_img = temp_GT.copy() + + del temp_LR, temp_GT + + if rot: + LR_img = LR_img.transpose(1, 0, 2) + GT_img = GT_img.transpose(1, 0, 2) + + + return {'LR' : LR_img, 'GT' : GT_img} + + diff --git a/SRGAN-PyTorch-master/information.txt b/SRGAN-PyTorch-master/information.txt new file mode 100644 index 0000000..e60ac7a --- /dev/null +++ b/SRGAN-PyTorch-master/information.txt @@ -0,0 +1,2 @@ +Programme écrit par dongheehand en Python +Disponible sur github au lien https://github.com/dongheehand/SRGAN-PyTorch \ No newline at end of file diff --git a/SRGAN-PyTorch-master/losses.py b/SRGAN-PyTorch-master/losses.py new file mode 100644 index 0000000..5b26ce4 --- /dev/null +++ b/SRGAN-PyTorch-master/losses.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn +from torchvision import transforms + + +class MeanShift(nn.Conv2d): + def __init__( + self, rgb_range = 1, + norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225), sign=-1): + + super(MeanShift, self).__init__(3, 3, kernel_size=1) + std = torch.Tensor(norm_std) + self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) + self.bias.data = sign * rgb_range * torch.Tensor(norm_mean) / std + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + for p in self.parameters(): + p.requires_grad = False + + +class perceptual_loss(nn.Module): + + def __init__(self, vgg): + super(perceptual_loss, self).__init__() + self.normalization_mean = [0.485, 0.456, 0.406] + self.normalization_std = [0.229, 0.224, 0.225] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.transform = MeanShift(norm_mean = self.normalization_mean, norm_std = self.normalization_std).to(self.device) + self.vgg = vgg + self.criterion = nn.MSELoss() + def forward(self, HR, SR, layer = 'relu5_4'): + ## HR and SR should be normalized [0,1] + hr = self.transform(HR) + sr = self.transform(SR) + + hr_feat = getattr(self.vgg(hr), layer) + sr_feat = getattr(self.vgg(sr), layer) + + return self.criterion(hr_feat, sr_feat), hr_feat, sr_feat + +class TVLoss(nn.Module): + def __init__(self, tv_loss_weight=1): + super(TVLoss, self).__init__() + self.tv_loss_weight = tv_loss_weight + + def forward(self, x): + batch_size = x.size()[0] + h_x = x.size()[2] + w_x = x.size()[3] + count_h = self.tensor_size(x[:, :, 1:, :]) + count_w = self.tensor_size(x[:, :, :, 1:]) + h_tv = 
torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() + w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() + + return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size + + @staticmethod + def tensor_size(t): + return t.size()[1] * t.size()[2] * t.size()[3] + diff --git a/SRGAN-PyTorch-master/main.py b/SRGAN-PyTorch-master/main.py new file mode 100644 index 0000000..15670e8 --- /dev/null +++ b/SRGAN-PyTorch-master/main.py @@ -0,0 +1,39 @@ +from mode import * +import argparse + + +parser = argparse.ArgumentParser() + +def str2bool(v): + return v.lower() in ('true') + +parser.add_argument("--LR_path", type = str, default = '../dataSet/DIV2K/DIV2K_train_LR_bicubic/X4') +parser.add_argument("--GT_path", type = str, default = '../dataSet/DIV2K/DIV2K_train_HR/') +parser.add_argument("--res_num", type = int, default = 16) +parser.add_argument("--num_workers", type = int, default = 0) +parser.add_argument("--batch_size", type = int, default = 16) +parser.add_argument("--L2_coeff", type = float, default = 1.0) +parser.add_argument("--adv_coeff", type = float, default = 1e-3) +parser.add_argument("--tv_loss_coeff", type = float, default = 0.0) +parser.add_argument("--pre_train_epoch", type = int, default = 8000) +parser.add_argument("--fine_train_epoch", type = int, default = 4000) +parser.add_argument("--scale", type = int, default = 4) +parser.add_argument("--patch_size", type = int, default = 24) +parser.add_argument("--feat_layer", type = str, default = 'relu5_4') +parser.add_argument("--vgg_rescale_coeff", type = float, default = 0.006) +parser.add_argument("--fine_tuning", type = str2bool, default = False) +parser.add_argument("--in_memory", type = str2bool, default = True) +parser.add_argument("--generator_path", type = str) +parser.add_argument("--mode", type = str, default = 'train') + +args = parser.parse_args() + +if args.mode == 'train': + train(args) + +elif args.mode == 'test': + test(args) + +elif args.mode == 'test_only': + test_only(args) + diff --git a/SRGAN-PyTorch-master/mode.py b/SRGAN-PyTorch-master/mode.py new file mode 100644 index 0000000..ce15605 --- /dev/null +++ b/SRGAN-PyTorch-master/mode.py @@ -0,0 +1,208 @@ +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader +from torchvision import transforms +from losses import TVLoss, perceptual_loss +from dataset import * +from srgan_model import Generator, Discriminator +from vgg19 import vgg19 +import numpy as np +from PIL import Image +from skimage.color import rgb2ycbcr +from skimage.measure import compare_psnr + + +def train(args): + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + transform = transforms.Compose([crop(args.scale, args.patch_size), augmentation()]) + dataset = mydata(GT_path = args.GT_path, LR_path = args.LR_path, in_memory = args.in_memory, transform = transform) + loader = DataLoader(dataset, batch_size = args.batch_size, shuffle = True, num_workers = args.num_workers) + + generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num, scale=args.scale) + + + if args.fine_tuning: + generator.load_state_dict(torch.load(args.generator_path)) + print("pre-trained model is loaded") + print("path : %s"%(args.generator_path)) + + generator = generator.to(device) + generator.train() + + l2_loss = nn.MSELoss() + g_optim = optim.Adam(generator.parameters(), lr = 1e-4) + + pre_epoch = 0 + fine_epoch = 0 + + #### Train using L2_loss + while pre_epoch < 
args.pre_train_epoch: + for i, tr_data in enumerate(loader): + gt = tr_data['GT'].to(device) + lr = tr_data['LR'].to(device) + + output, _ = generator(lr) + loss = l2_loss(gt, output) + + g_optim.zero_grad() + loss.backward() + g_optim.step() + + pre_epoch += 1 + + if pre_epoch % 2 == 0: + print(pre_epoch) + print(loss.item()) + print('=========') + + if pre_epoch % 800 ==0: + torch.save(generator.state_dict(), './model/pre_trained_model_%03d.pt'%pre_epoch) + + + #### Train using perceptual & adversarial loss + vgg_net = vgg19().to(device) + vgg_net = vgg_net.eval() + + discriminator = Discriminator(patch_size = args.patch_size * args.scale) + discriminator = discriminator.to(device) + discriminator.train() + + d_optim = optim.Adam(discriminator.parameters(), lr = 1e-4) + scheduler = optim.lr_scheduler.StepLR(g_optim, step_size = 2000, gamma = 0.1) + + VGG_loss = perceptual_loss(vgg_net) + cross_ent = nn.BCELoss() + tv_loss = TVLoss() + real_label = torch.ones((args.batch_size, 1)).to(device) + fake_label = torch.zeros((args.batch_size, 1)).to(device) + + while fine_epoch < args.fine_train_epoch: + + scheduler.step() + + for i, tr_data in enumerate(loader): + gt = tr_data['GT'].to(device) + lr = tr_data['LR'].to(device) + + ## Training Discriminator + output, _ = generator(lr) + fake_prob = discriminator(output) + real_prob = discriminator(gt) + + d_loss_real = cross_ent(real_prob, real_label) + d_loss_fake = cross_ent(fake_prob, fake_label) + + d_loss = d_loss_real + d_loss_fake + + g_optim.zero_grad() + d_optim.zero_grad() + d_loss.backward() + d_optim.step() + + ## Training Generator + output, _ = generator(lr) + fake_prob = discriminator(output) + + _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer = args.feat_layer) + + L2_loss = l2_loss(output, gt) + percep_loss = args.vgg_rescale_coeff * _percep_loss + adversarial_loss = args.adv_coeff * cross_ent(fake_prob, real_label) + total_variance_loss = args.tv_loss_coeff * tv_loss(args.vgg_rescale_coeff * (hr_feat - sr_feat)**2) + + g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss + + g_optim.zero_grad() + d_optim.zero_grad() + g_loss.backward() + g_optim.step() + + + fine_epoch += 1 + + if fine_epoch % 2 == 0: + print(fine_epoch) + print(g_loss.item()) + print(d_loss.item()) + print('=========') + + if fine_epoch % 500 ==0: + torch.save(generator.state_dict(), './model/SRGAN_gene_%03d.pt'%fine_epoch) + torch.save(discriminator.state_dict(), './model/SRGAN_discrim_%03d.pt'%fine_epoch) + + +# In[ ]: + +def test(args): + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataset = mydata(GT_path = args.GT_path, LR_path = args.LR_path, in_memory = False, transform = None) + loader = DataLoader(dataset, batch_size = 1, shuffle = False, num_workers = args.num_workers) + + generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num) + generator.load_state_dict(torch.load(args.generator_path)) + generator = generator.to(device) + generator.eval() + + f = open('./result.txt', 'w') + psnr_list = [] + + with torch.no_grad(): + for i, te_data in enumerate(loader): + gt = te_data['GT'].to(device) + lr = te_data['LR'].to(device) + + bs, c, h, w = lr.size() + gt = gt[:, :, : h * args.scale, : w *args.scale] + + output, _ = generator(lr) + + output = output[0].cpu().numpy() + output = np.clip(output, -1.0, 1.0) + gt = gt[0].cpu().numpy() + + output = (output + 1.0) / 2.0 + gt = (gt + 1.0) / 2.0 + + output = output.transpose(1,2,0) + gt = 
gt.transpose(1,2,0) + + y_output = rgb2ycbcr(output)[args.scale:-args.scale, args.scale:-args.scale, :1] + y_gt = rgb2ycbcr(gt)[args.scale:-args.scale, args.scale:-args.scale, :1] + + psnr = compare_psnr(y_output / 255.0, y_gt / 255.0, data_range = 1.0) + psnr_list.append(psnr) + f.write('psnr : %04f \n' % psnr) + + result = Image.fromarray((output * 255.0).astype(np.uint8)) + result.save('./result/res_%04d.png'%i) + + f.write('avg psnr : %04f' % np.mean(psnr_list)) + + +def test_only(args): + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dataset = testOnly_data(LR_path = args.LR_path, in_memory = False, transform = None) + loader = DataLoader(dataset, batch_size = 1, shuffle = False, num_workers = args.num_workers) + + generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num) + generator.load_state_dict(torch.load(args.generator_path)) + generator = generator.to(device) + generator.eval() + + with torch.no_grad(): + for i, te_data in enumerate(loader): + lr = te_data['LR'].to(device) + output, _ = generator(lr) + output = output[0].cpu().numpy() + output = (output + 1.0) / 2.0 + output = output.transpose(1,2,0) + result = Image.fromarray((output * 255.0).astype(np.uint8)) + result.save('./result/res_%04d.png'%i) + + + diff --git a/SRGAN-PyTorch-master/ops.py b/SRGAN-PyTorch-master/ops.py new file mode 100644 index 0000000..24ac7d9 --- /dev/null +++ b/SRGAN-PyTorch-master/ops.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class _conv(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias): + super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels, + kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True) + + self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02) + self.bias.data = torch.zeros((out_channels)) + + for p in self.parameters(): + p.requires_grad = True + + +class conv(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True): + super(conv, self).__init__() + m = [] + m.append(_conv(in_channels = in_channel, out_channels = out_channel, + kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True)) + + if BN: + m.append(nn.BatchNorm2d(num_features = out_channel)) + + if act is not None: + m.append(act) + + self.body = nn.Sequential(*m) + + def forward(self, x): + out = self.body(x) + return out + +class ResBlock(nn.Module): + def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True): + super(ResBlock, self).__init__() + m = [] + m.append(conv(channels, channels, kernel_size, BN = True, act = act)) + m.append(conv(channels, channels, kernel_size, BN = True, act = None)) + self.body = nn.Sequential(*m) + + def forward(self, x): + res = self.body(x) + res += x + return res + +class BasicBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, num_res_block, act = nn.ReLU(inplace = True)): + super(BasicBlock, self).__init__() + m = [] + + self.conv = conv(in_channels, out_channels, kernel_size, BN = False, act = act) + for i in range(num_res_block): + m.append(ResBlock(out_channels, kernel_size, act)) + + m.append(conv(out_channels, out_channels, kernel_size, BN = True, act = None)) + + self.body = nn.Sequential(*m) + + def forward(self, x): + res = self.conv(x) + out = self.body(res) + out 
+= res + + return out + +class Upsampler(nn.Module): + def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)): + super(Upsampler, self).__init__() + m = [] + m.append(conv(channel, channel * scale * scale, kernel_size)) + m.append(nn.PixelShuffle(scale)) + + if act is not None: + m.append(act) + + self.body = nn.Sequential(*m) + + def forward(self, x): + out = self.body(x) + return out + +class discrim_block(nn.Module): + def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)): + super(discrim_block, self).__init__() + m = [] + m.append(conv(in_feats, out_feats, kernel_size, BN = True, act = act)) + m.append(conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2)) + self.body = nn.Sequential(*m) + + def forward(self, x): + out = self.body(x) + return out + diff --git a/SRGAN-PyTorch-master/readme.md b/SRGAN-PyTorch-master/readme.md new file mode 100644 index 0000000..61ac2f0 --- /dev/null +++ b/SRGAN-PyTorch-master/readme.md @@ -0,0 +1,68 @@ +# Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network + +## Overview + +An unofficial implementation of SRGAN described in the paper using PyTorch. +* [ Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802) + +Published in CVPR 2017 + +## Requirement +- Python 3.6.5 +- PyTorch 1.1.0 +- Pillow 5.1.0 +- numpy 1.14.5 +- scikit-image 0.15.0 + +## Datasets +- [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) + +## Pre-trained model +- [SRResNet](https://drive.google.com/open?id=15F2zOrOg2hIjdI0WsrOwF1y8REOkmmm0) + + +- [SRGAN](https://drive.google.com/open?id=1-HmcV5X94u411HRa-KEMcGhAO1OXAjAc) + +## Train & Test +Train + +``` +python main.py --LR_path ./LR_imgs_dir --GT_path ./GT_imgs_dir +``` + +Test + +``` +python main.py --mode test --LR_path ./LR_imgs_dir --GT_path ./GT_imgs_dir --generator_path ./model/SRGAN.pt +``` + +Inference your own images + +``` +python main.py --mode test_only --LR_path ./LR_imgs_dir --generator_path ./model/SRGAN.pt +``` + +## Experimental Results +Experimental results on benchmarks. + +### Quantitative Results + +| Method| Set5| Set14| B100 | +|-------|-----|------|------| +|Bicubic|28.43|25.99|25.94| +|SRResNet(paper)|32.05|28.49|27.58| +|SRResNet(my model)|31.96|28.48|27.49| +|SRGAN(paper)|29.40|26.02|25.16| +|SRGAN(my model)|29.93|26.95|26.10| + +### Qualitative Results + +| Bicubic | SRResNet | SRGAN | +| --- | --- | --- | +| | | | +| | | | +| | | | + +## Comments +If you have any questions or comments on my codes, please email to me. 
[son1113@snu.ac.kr](mailto:son1113@snu.ac.kr) + diff --git a/SRGAN-PyTorch-master/srgan_model.py b/SRGAN-PyTorch-master/srgan_model.py new file mode 100644 index 0000000..644d1e1 --- /dev/null +++ b/SRGAN-PyTorch-master/srgan_model.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn +from ops import * + + +class Generator(nn.Module): + + def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, num_block = 16, act = nn.PReLU(), scale=4): + super(Generator, self).__init__() + + self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = act) + + resblocks = [ResBlock(channels = n_feats, kernel_size = 3, act = act) for _ in range(num_block)] + self.body = nn.Sequential(*resblocks) + + self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = True, act = None) + + if(scale == 4): + upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = 2, act = act) for _ in range(2)] + else: + upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = scale, act = act)] + + self.tail = nn.Sequential(*upsample_blocks) + + self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = 3, BN = False, act = nn.Tanh()) + + def forward(self, x): + + x = self.conv01(x) + _skip_connection = x + + x = self.body(x) + x = self.conv02(x) + feat = x + _skip_connection + + x = self.tail(feat) + x = self.last_conv(x) + + return x, feat + +class Discriminator(nn.Module): + + def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, act = nn.LeakyReLU(inplace = True), num_of_block = 3, patch_size = 96): + super(Discriminator, self).__init__() + self.act = act + + self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act) + self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act, stride = 2) + + body = [discrim_block(in_feats = n_feats * (2 ** i), out_feats = n_feats * (2 ** (i + 1)), kernel_size = 3, act = self.act) for i in range(num_of_block)] + self.body = nn.Sequential(*body) + + self.linear_size = ((patch_size // (2 ** (num_of_block + 1))) ** 2) * (n_feats * (2 ** num_of_block)) + + tail = [] + + tail.append(nn.Linear(self.linear_size, 1024)) + tail.append(self.act) + tail.append(nn.Linear(1024, 1)) + tail.append(nn.Sigmoid()) + + self.tail = nn.Sequential(*tail) + + + def forward(self, x): + + x = self.conv01(x) + x = self.conv02(x) + x = self.body(x) + x = x.view(-1, self.linear_size) + x = self.tail(x) + + return x + diff --git a/SRGAN-PyTorch-master/vgg19.py b/SRGAN-PyTorch-master/vgg19.py new file mode 100644 index 0000000..b3f265f --- /dev/null +++ b/SRGAN-PyTorch-master/vgg19.py @@ -0,0 +1,80 @@ +from torchvision import models +from collections import namedtuple +import torch +import torch.nn as nn + + +class vgg19(nn.Module): + + def __init__(self, pre_trained = True, require_grad = False): + super(vgg19, self).__init__() + self.vgg_feature = models.vgg19(pretrained = pre_trained).features + self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature] + self.vgg_layer = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', + 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', + 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'] + + 
if not require_grad: + for parameter in self.parameters(): + parameter.requires_grad = False + + def forward(self, x): + + conv1_1 = self.seq_list[0](x) + relu1_1 = self.seq_list[1](conv1_1) + conv1_2 = self.seq_list[2](relu1_1) + relu1_2 = self.seq_list[3](conv1_2) + pool1 = self.seq_list[4](relu1_2) + + conv2_1 = self.seq_list[5](pool1) + relu2_1 = self.seq_list[6](conv2_1) + conv2_2 = self.seq_list[7](relu2_1) + relu2_2 = self.seq_list[8](conv2_2) + pool2 = self.seq_list[9](relu2_2) + + conv3_1 = self.seq_list[10](pool2) + relu3_1 = self.seq_list[11](conv3_1) + conv3_2 = self.seq_list[12](relu3_1) + relu3_2 = self.seq_list[13](conv3_2) + conv3_3 = self.seq_list[14](relu3_2) + relu3_3 = self.seq_list[15](conv3_3) + conv3_4 = self.seq_list[16](relu3_3) + relu3_4 = self.seq_list[17](conv3_4) + pool3 = self.seq_list[18](relu3_4) + + conv4_1 = self.seq_list[19](pool3) + relu4_1 = self.seq_list[20](conv4_1) + conv4_2 = self.seq_list[21](relu4_1) + relu4_2 = self.seq_list[22](conv4_2) + conv4_3 = self.seq_list[23](relu4_2) + relu4_3 = self.seq_list[24](conv4_3) + conv4_4 = self.seq_list[25](relu4_3) + relu4_4 = self.seq_list[26](conv4_4) + pool4 = self.seq_list[27](relu4_4) + + conv5_1 = self.seq_list[28](pool4) + relu5_1 = self.seq_list[29](conv5_1) + conv5_2 = self.seq_list[30](relu5_1) + relu5_2 = self.seq_list[31](conv5_2) + conv5_3 = self.seq_list[32](relu5_2) + relu5_3 = self.seq_list[33](conv5_3) + conv5_4 = self.seq_list[34](relu5_3) + relu5_4 = self.seq_list[35](conv5_4) + pool5 = self.seq_list[36](relu5_4) + + vgg_output = namedtuple("vgg_output", self.vgg_layer) + + vgg_list = [conv1_1, relu1_1, conv1_2, relu1_2, pool1, + conv2_1, relu2_1, conv2_2, relu2_2, pool2, + conv3_1, relu3_1, conv3_2, relu3_2, conv3_3, relu3_3, conv3_4, relu3_4, pool3, + conv4_1, relu4_1, conv4_2, relu4_2, conv4_3, relu4_3, conv4_4, relu4_4, pool4, + conv5_1, relu5_1, conv5_2, relu5_2, conv5_3, relu5_3, conv5_4, relu5_4, pool5] + + out = vgg_output(*vgg_list) + + + return out + + diff --git a/Super-résolution.pyw b/Super-résolution.pyw new file mode 100644 index 0000000..ff98704 --- /dev/null +++ b/Super-résolution.pyw @@ -0,0 +1,211 @@ +from tkinter import * +from tkinter.filedialog import askopenfilename +from tkinter.messagebox import showerror +import os, subprocess, shutil, glob, time +from PIL import Image, ImageTk +from threading import Thread + +class SRGUI(): + def __init__(self): + self.root = Tk() + + self.input_img_path = StringVar() + + self.canvas = {} + self.image = {} + self.imagetk = {} + self.id_canvas = {} + self.label = {} + + self.default_diff = 20 + self.diff = self.default_diff + + ##### ENTREE + self.frame_input = Frame(self.root) + self.frame_input.grid(row=1, column=1, rowspan=2) + + Label(self.frame_input, text="Entrée", font=("Purisa", 18)).grid(row=1, column=1, columnspan=2) + self.canvas["input"] = Canvas(self.frame_input, relief=SOLID, borderwidth=2, width=400, height=400) + self.canvas["input"].grid(row=2, column=1, columnspan=2) + Entry(self.frame_input, textvariable=self.input_img_path, width=60).grid(row=3, column=1, sticky="NEWS") + + Button(self.frame_input, text="...", relief=RIDGE, command=self.select_img).grid(row=3, column=2, sticky="NEWS") + Button(self.frame_input, text="Calculer", relief=RIDGE, command=self.calcul).grid(row=4, column=1, columnspan=2, sticky="NEWS") + + def zoom(event): + for type in ["input", "bil", "bic", "SRCNN", "SRGAN", "original"]: + if type in self.id_canvas: + x1 = ((event.x - self.diff) * self.image[type].width) // 400 + x2 = ((event.x + 
self.diff) * self.image[type].width) // 400
+                    y1 = ((event.y - self.diff) * self.image[type].height) // 400
+                    y2 = ((event.y + self.diff) * self.image[type].height) // 400
+
+                    _img = self.image[type].crop((x1, y1, x2, y2))
+                    self.imagetk[type] = ImageTk.PhotoImage(_img.resize((400, 400)))
+                    self.canvas[type].itemconfig(self.id_canvas[type], image = self.imagetk[type])
+
+        def change_diff(event):
+            # Mouse wheel up shrinks the cropped area (stronger zoom), wheel down widens it.
+            if event.delta > 0: self.diff -= 5
+            else: self.diff += 5
+            zoom(event)
+
+        def reset(event):
+            # Restore every view to the full, un-zoomed image.
+            for type in ["input", "bil", "bic", "SRCNN", "SRGAN", "original"]:
+                if type in self.id_canvas:
+                    self.imagetk[type] = ImageTk.PhotoImage(self.image[type].resize((400, 400)))
+                    self.canvas[type].itemconfig(self.id_canvas[type], image = self.imagetk[type])
+
+        def activate_zoom(event):
+            self.canvas["input"].bind("<Motion>", zoom)
+        def desactivate_zoom(event):
+            self.canvas["input"].unbind("<Motion>")
+
+        # The original event sequences were lost from this file; the bindings below are an
+        # assumed reconstruction: hold the left mouse button over the input canvas to zoom,
+        # right-click to reset, and use the mouse wheel to change the zoom factor.
+        self.canvas["input"].bind("<ButtonPress-1>", activate_zoom)
+        self.canvas["input"].bind("<ButtonRelease-1>", desactivate_zoom)
+
+        self.canvas["input"].bind("<Button-3>", reset)
+        self.canvas["input"].bind("<MouseWheel>", change_diff)
+
+        ##### BILINEAR
+        self.frame_bil = Frame(self.root)
+        self.frame_bil.grid(row=1, column=3)
+
+        Label(self.frame_bil, text="Bilinear", font=("Purisa", 18)).grid(row=1, column=1)
+        self.canvas["bil"] = Canvas(self.frame_bil, relief=SOLID, borderwidth=2, width=400, height=400)
+        self.canvas["bil"].grid(row=2, column=1)
+
+        self.label["bil"] = Label(self.frame_bil, fg="gray")
+        self.label["bil"].grid(row=3, column=1)
+
+        ##### BICUBIC
+        self.frame_bic = Frame(self.root)
+        self.frame_bic.grid(row=2, column=3)
+
+        Label(self.frame_bic, text="Bicubic", font=("Purisa", 18)).grid(row=1, column=1)
+        self.canvas["bic"] = Canvas(self.frame_bic, relief=SOLID, borderwidth=2, width=400, height=400)
+        self.canvas["bic"].grid(row=2, column=1)
+
+        self.label["bic"] = Label(self.frame_bic, fg="gray")
+        self.label["bic"].grid(row=3, column=1)
+
+        ##### SRCNN
+        self.frame_SRCNN = Frame(self.root)
+        self.frame_SRCNN.grid(row=1, column=4)
+
+        Label(self.frame_SRCNN, text="SRCNN", font=("Purisa", 18)).grid(row=1, column=1)
+        self.canvas["SRCNN"] = Canvas(self.frame_SRCNN, relief=SOLID, borderwidth=2, width=400, height=400)
+        self.canvas["SRCNN"].grid(row=2, column=1)
+
+        self.label["SRCNN"] = Label(self.frame_SRCNN, fg="gray")
+        self.label["SRCNN"].grid(row=3, column=1)
+
+        ##### SRGAN
+        self.frame_SRGAN = Frame(self.root)
+        self.frame_SRGAN.grid(row=2, column=4)
+
+        Label(self.frame_SRGAN, text="SRGAN", font=("Purisa", 18)).grid(row=1, column=1)
+        self.canvas["SRGAN"] = Canvas(self.frame_SRGAN, relief=SOLID, borderwidth=2, width=400, height=400)
+        self.canvas["SRGAN"].grid(row=2, column=1)
+
+        self.label["SRGAN"] = Label(self.frame_SRGAN, fg="gray")
+        self.label["SRGAN"].grid(row=3, column=1)
+
+        ##### ORIGINAL
+        # Only shown when a matching *_original.* ground-truth image exists next to the input.
+        self.frame_original = Frame(self.root)
+
+        Label(self.frame_original, text="Original", font=("Purisa", 18)).grid(row=1, column=1)
+        self.canvas["original"] = Canvas(self.frame_original, relief=SOLID, borderwidth=2, width=400, height=400)
+        self.canvas["original"].grid(row=2, column=1)
+
+
+    def select_img(self):
+        self.input_img_path.set(askopenfilename(filetypes = [('Images', '*.png *.jpg *.jpeg *.bmp *.gif')]))
+        if os.path.exists(self.input_img_path.get()):
+            self.image["input"] = Image.open(self.input_img_path.get())
+            self.imagetk["input"] = ImageTk.PhotoImage(self.image["input"].resize((400, 400)))
+            self.id_canvas["input"] = self.canvas["input"].create_image(204, 204, image = self.imagetk["input"])
+
+            # Look for a companion ground-truth file with an "_original" suffix.
+            _l = 
self.input_img_path.get().split(".")
+            _l[-2] += "_original"
+            original_img_path = ".".join(_l)
+            if os.path.exists(original_img_path):
+                self.image["original"] = Image.open(original_img_path)
+                self.imagetk["original"] = ImageTk.PhotoImage(self.image["original"].resize((400, 400)))
+                self.id_canvas["original"] = self.canvas["original"].create_image(204, 204, image=self.imagetk["original"])
+
+                self.frame_original.grid(row=1, column=5, rowspan=2)
+            else: self.frame_original.grid_forget()
+
+        else: showerror("Error", "This file does not exist.")
+
+    def calcul(self):
+        # Run the four upscaling methods in a background thread so the GUI stays responsive.
+        def thread_func():
+            type2func = {"bil": self.Bilinear, "bic": self.Bicubic, "SRCNN": self.SRCNN, "SRGAN": self.SRGAN}
+            for type in type2func:
+                try:
+                    t1 = time.time()
+                    type2func[type]()
+                    t2 = time.time()
+                except Exception: self.label[type].config(text="Error", fg="red")
+                else: self.label[type].config(text="Computation time: %0.2fs" % (t2 - t1), fg="gray")
+
+        thfunc = Thread(target=thread_func)
+        thfunc.daemon = True
+        thfunc.start()
+
+    def Bilinear(self):
+        self.image["bil"] = Image.open(self.input_img_path.get()).resize((400 * 4, 400 * 4), Image.BILINEAR)
+        self.imagetk["bil"] = ImageTk.PhotoImage(self.image["bil"].resize((400, 400)))
+        self.id_canvas["bil"] = self.canvas["bil"].create_image(204, 204, image=self.imagetk["bil"])
+
+    def Bicubic(self):
+        self.image["bic"] = Image.open(self.input_img_path.get()).resize((400 * 4, 400 * 4), Image.BICUBIC)
+        self.imagetk["bic"] = ImageTk.PhotoImage(self.image["bic"].resize((400, 400)))
+        self.id_canvas["bic"] = self.canvas["bic"].create_image(204, 204, image=self.imagetk["bic"])
+
+    def SRCNN(self):
+        # Bicubic x4 upscaling first; test.py then refines the enlarged image with SRCNN.
+        SRCNN_img_path = "./SRCNN-pytorch-master/data/" + os.path.basename(self.input_img_path.get())
+        self.image["input"].resize((self.image["input"].width * 4, self.image["input"].height * 4), Image.BICUBIC).save(SRCNN_img_path)
+
+        subprocess.call(
+            '"D:/ProgramData/Anaconda3/python.exe" \
+            "./SRCNN-pytorch-master/test.py" \
+            --weights-file "./SRCNN-pytorch-master/weights/srcnn_x4.pth" \
+            --image-file "%s" \
+            --scale 4' % SRCNN_img_path)
+
+        # test.py saves its output next to the query image with a "_srcnn_x4" suffix.
+        l = SRCNN_img_path.split(".")
+        l[-2] += "_srcnn_x4"
+        self.SRCNN_img_path = ".".join(l)
+
+        self.image["SRCNN"] = Image.open(self.SRCNN_img_path)
+        self.imagetk["SRCNN"] = ImageTk.PhotoImage(self.image["SRCNN"].resize((400, 400)))
+        self.id_canvas["SRCNN"] = self.canvas["SRCNN"].create_image(204, 204, image=self.imagetk["SRCNN"])
+
+    def SRGAN(self):
+        SRGAN_img_path = "./SRGAN-PyTorch-master/LR_imgs_dir/" + os.path.basename(self.input_img_path.get())
+        shutil.copy(self.input_img_path.get(), SRGAN_img_path)
+
+        # main.py must be run from its own directory; it writes its output to result/res_*.png.
+        workingdir = os.getcwd()
+        os.chdir("./SRGAN-PyTorch-master/")
+
+        subprocess.call(
+            '"D:/ProgramData/Anaconda3/python.exe" \
+            main.py \
+            --mode test_only \
+            --LR_path LR_imgs_dir/ \
+            --generator_path model/SRGAN.pt')
+
+        os.chdir(workingdir)
+
+        os.remove(SRGAN_img_path)
+
+        self.SRGAN_img_path = glob.glob("./SRGAN-PyTorch-master/result/res_*.png")[-1]
+
+        self.image["SRGAN"] = Image.open(self.SRGAN_img_path)
+        self.imagetk["SRGAN"] = ImageTk.PhotoImage(self.image["SRGAN"].resize((400, 400)))
+        self.id_canvas["SRGAN"] = self.canvas["SRGAN"].create_image(204, 204, image=self.imagetk["SRGAN"])
+
+App = SRGUI()
+mainloop()
\ No newline at end of file
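
The GUI above launches the model scripts through a hard-coded Windows Anaconda interpreter path. As a minimal sketch only, assuming the same relative project layout as in the patch, the SRCNN call could instead reuse whatever interpreter runs the GUI via `sys.executable`, with arguments passed as a list so no manual quoting is needed:

```python
# Sketch only: portable alternative to the hard-coded
# "D:/ProgramData/Anaconda3/python.exe" call used in SRGUI.SRCNN above.
import subprocess
import sys

def run_srcnn(image_path):
    # sys.executable is the interpreter currently running the GUI, so no
    # machine-specific path is needed; check=True raises if test.py fails.
    subprocess.run(
        [sys.executable, "./SRCNN-pytorch-master/test.py",
         "--weights-file", "./SRCNN-pytorch-master/weights/srcnn_x4.pth",
         "--image-file", image_path,
         "--scale", "4"],
        check=True,
    )
```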