Initial commit
commit 50e4c082d8
26 changed files with 1477 additions and 0 deletions
129  SRCNN-pytorch-master/README.md  Normal file
@@ -0,0 +1,129 @@
# SRCNN

This repository is an implementation of ["Image Super-Resolution Using Deep Convolutional Networks"](https://arxiv.org/abs/1501.00092).

<center><img src="./thumbnails/fig1.png"></center>

## Differences from the original

- Added zero-padding
- Used Adam instead of SGD
- Removed the weight initialization

## Requirements

- PyTorch 1.0.0
- Numpy 1.15.4
- Pillow 5.4.1
- h5py 2.8.0
- tqdm 4.30.0

## Train

The 91-image and Set5 datasets, converted to HDF5, can be downloaded from the links below.

| Dataset | Scale | Type | Link |
|---------|-------|------|------|
| 91-image | 2 | Train | [Download](https://www.dropbox.com/s/2hsah93sxgegsry/91-image_x2.h5?dl=0) |
| 91-image | 3 | Train | [Download](https://www.dropbox.com/s/curldmdf11iqakd/91-image_x3.h5?dl=0) |
| 91-image | 4 | Train | [Download](https://www.dropbox.com/s/22afykv4amfxeio/91-image_x4.h5?dl=0) |
| Set5 | 2 | Eval | [Download](https://www.dropbox.com/s/r8qs6tp395hgh8g/Set5_x2.h5?dl=0) |
| Set5 | 3 | Eval | [Download](https://www.dropbox.com/s/58ywjac4te3kbqq/Set5_x3.h5?dl=0) |
| Set5 | 4 | Eval | [Download](https://www.dropbox.com/s/0rz86yn3nnrodlb/Set5_x4.h5?dl=0) |
Otherwise, you can use `prepare.py` to create a custom dataset, as in the sketch below.
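
For example, the following invocations (paths are placeholders, in the README's `BLAH_BLAH` style) would build a training file and a matching evaluation file; the flags mirror the argument parser in `prepare.py`:

```bash
# 33x33 training patches, extracted with a stride of 14
python prepare.py --images-dir "BLAH_BLAH/my_images" \
                  --output-path "BLAH_BLAH/my_train_x3.h5" \
                  --patch-size 33 \
                  --stride 14 \
                  --scale 3

# full-size evaluation images, stored as one HDF5 entry per image
python prepare.py --images-dir "BLAH_BLAH/my_eval_images" \
                  --output-path "BLAH_BLAH/my_eval_x3.h5" \
                  --scale 3 \
                  --eval
```
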
```bash
python train.py --train-file "BLAH_BLAH/91-image_x3.h5" \
                --eval-file "BLAH_BLAH/Set5_x3.h5" \
                --outputs-dir "BLAH_BLAH/outputs" \
                --scale 3 \
                --lr 1e-4 \
                --batch-size 16 \
                --num-epochs 400 \
                --num-workers 8 \
                --seed 123
```

## Test

Pre-trained weights can be downloaded from the links below.

| Model | Scale | Link |
|-------|-------|------|
| 9-5-5 | 2 | [Download](https://www.dropbox.com/s/rxluu1y8ptjm4rn/srcnn_x2.pth?dl=0) |
| 9-5-5 | 3 | [Download](https://www.dropbox.com/s/zn4fdobm2kw0c58/srcnn_x3.pth?dl=0) |
| 9-5-5 | 4 | [Download](https://www.dropbox.com/s/pd5b2ketm0oamhj/srcnn_x4.pth?dl=0) |

The output is saved in the same directory as the input image, with a `_srcnn_x{scale}` suffix.

```bash
python test.py --weights-file "BLAH_BLAH/srcnn_x3.pth" \
               --image-file "data/butterfly_GT.bmp" \
               --scale 3
```

## Results

We used the same network settings as the paper, i.e., <a href="https://www.codecogs.com/eqnedit.php?latex={&space;f&space;}_{&space;1&space;}=9,{&space;f&space;}_{&space;2&space;}=5,{&space;f&space;}_{&space;3&space;}=5,{&space;n&space;}_{&space;1&space;}=64,{&space;n&space;}_{&space;2&space;}=32,{&space;n&space;}_{&space;3&space;}=1" target="_blank"><img src="https://latex.codecogs.com/gif.latex?{&space;f&space;}_{&space;1&space;}=9,{&space;f&space;}_{&space;2&space;}=5,{&space;f&space;}_{&space;3&space;}=5,{&space;n&space;}_{&space;1&space;}=64,{&space;n&space;}_{&space;2&space;}=32,{&space;n&space;}_{&space;3&space;}=1" title="{ f }_{ 1 }=9,{ f }_{ 2 }=5,{ f }_{ 3 }=5,{ n }_{ 1 }=64,{ n }_{ 2 }=32,{ n }_{ 3 }=1" /></a>.

PSNR was calculated on the Y channel.

### Set5

| Eval. Mat | Scale | SRCNN | SRCNN (Ours) |
|-----------|-------|-------|--------------|
| PSNR | 2 | 36.66 | 36.65 |
| PSNR | 3 | 32.75 | 33.29 |
| PSNR | 4 | 30.49 | 30.25 |

<table>
    <tr>
        <td><center>Original</center></td>
        <td><center>BICUBIC x3</center></td>
        <td><center>SRCNN x3 (27.53 dB)</center></td>
    </tr>
    <tr>
        <td><center><img src="./data/butterfly_GT.bmp"></center></td>
        <td><center><img src="./data/butterfly_GT_bicubic_x3.bmp"></center></td>
        <td><center><img src="./data/butterfly_GT_srcnn_x3.bmp"></center></td>
    </tr>
    <tr>
        <td><center>Original</center></td>
        <td><center>BICUBIC x3</center></td>
        <td><center>SRCNN x3 (29.30 dB)</center></td>
    </tr>
    <tr>
        <td><center><img src="./data/zebra.bmp"></center></td>
        <td><center><img src="./data/zebra_bicubic_x3.bmp"></center></td>
        <td><center><img src="./data/zebra_srcnn_x3.bmp"></center></td>
    </tr>
    <tr>
        <td><center>Original</center></td>
        <td><center>BICUBIC x3</center></td>
        <td><center>SRCNN x3 (28.58 dB)</center></td>
    </tr>
    <tr>
        <td><center><img src="./data/ppt3.bmp"></center></td>
        <td><center><img src="./data/ppt3_bicubic_x3.bmp"></center></td>
        <td><center><img src="./data/ppt3_srcnn_x3.bmp"></center></td>
    </tr>
</table>
BIN  SRCNN-pytorch-master/__pycache__/models.cpython-38.pyc  Normal file
Binary file not shown.
BIN  SRCNN-pytorch-master/__pycache__/utils.cpython-38.pyc  Normal file
Binary file not shown.
31  SRCNN-pytorch-master/datasets.py  Normal file
@@ -0,0 +1,31 @@
import h5py
import numpy as np
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        # Training patches live in two flat datasets ('lr' and 'hr'); pixels are
        # rescaled from [0, 255] to [0, 1] and a channel axis is prepended.
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


class EvalDataset(Dataset):
    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        # Evaluation images have varying sizes, so each is stored as its own
        # dataset under the 'lr' and 'hr' groups, keyed by its index as a string.
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])
2  SRCNN-pytorch-master/information.txt  Normal file
@@ -0,0 +1,2 @@
Program written by: yjn870
Code available on GitHub at: https://github.com/yjn870/SRCNN-pytorch
16  SRCNN-pytorch-master/models.py  Normal file
@@ -0,0 +1,16 @@
from torch import nn


class SRCNN(nn.Module):
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        # kernel_size // 2 zero-padding keeps the spatial size unchanged
        # (the zero-padding added on top of the original paper, see README).
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)  # no activation on the reconstruction layer
        return x
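
A quick sanity check (not part of the commit): because of the `kernel_size // 2` padding, the model preserves spatial dimensions, so it can be applied directly to a bicubically upscaled luminance image:

```python
import torch
from models import SRCNN

model = SRCNN(num_channels=1)
x = torch.rand(1, 1, 120, 80)   # one single-channel (luminance) image
with torch.no_grad():
    out = model(x)
print(out.shape)                # torch.Size([1, 1, 120, 80]) -- same size as the input
```
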
78  SRCNN-pytorch-master/prepare.py  Normal file
@@ -0,0 +1,78 @@
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y


def train(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_patches = []
    hr_patches = []

    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
        hr = pil_image.open(image_path).convert('RGB')
        # Crop so both dimensions are divisible by the scale factor.
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        # Degrade: bicubic downscale, then bicubic upscale back to HR size,
        # so LR and HR share dimensions (SRCNN operates after upsampling).
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):
            for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):
                lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
                hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)

    h5_file.create_dataset('lr', data=lr_patches)
    h5_file.create_dataset('hr', data=hr_patches)

    h5_file.close()


def eval(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_group = h5_file.create_group('lr')
    hr_group = h5_file.create_group('hr')

    for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        # Evaluation images keep their full size, one dataset per image.
        lr_group.create_dataset(str(i), data=lr)
        hr_group.create_dataset(str(i), data=hr)

    h5_file.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--images-dir', type=str, required=True)
    parser.add_argument('--output-path', type=str, required=True)
    parser.add_argument('--patch-size', type=int, default=33)
    parser.add_argument('--stride', type=int, default=14)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()

    if not args.eval:
        train(args)
    else:
        eval(args)
67  SRCNN-pytorch-master/test.py  Normal file
@@ -0,0 +1,67 @@
import argparse

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import SRCNN
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)

    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    image = pil_image.open(args.image_file).convert('RGB')

    # Degrade the input: crop to a multiple of the scale, bicubic downscale,
    # then bicubic upscale back, so the network input matches training.
    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)

    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    y = ycbcr[..., 0]
    y /= 255.
    y = torch.from_numpy(y).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # Note: this PSNR compares the prediction against the bicubic input,
    # not against the original image.
    psnr = calc_psnr(y, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Recombine the predicted Y channel with the bicubic Cb/Cr channels.
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)

    l = args.image_file.split(".")
    l[-2] += '_srcnn_x{}'.format(args.scale)
    args.image_file = ".".join(l)

    output.save(args.image_file)
112  SRCNN-pytorch-master/train.py  Normal file
@@ -0,0 +1,112 @@
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import SRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epochs', type=int, default=400)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = SRCNN().to(device)
    criterion = nn.MSELoss()
    # The last layer learns at a 10x smaller rate, as in the original paper.
    optimizer = optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data

                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)

                loss = criterion(preds, labels)

                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data

            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        # Keep a copy of the weights with the best validation PSNR.
        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))
68  SRCNN-pytorch-master/utils.py  Normal file
@@ -0,0 +1,68 @@
import torch
import numpy as np


def convert_rgb_to_y(img):
    # ITU-R BT.601 luma, for pixel values in [0, 255].
    if type(img) == np.ndarray:
        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
    else:
        raise Exception('Unknown Type', type(img))


def convert_rgb_to_ycbcr(img):
    if type(img) == np.ndarray:
        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))


def convert_ycbcr_to_rgb(img):
    if type(img) == np.ndarray:
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        return torch.cat([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))


def calc_psnr(img1, img2):
    # Assumes pixel values in [0, 1], i.e. MAX = 1.
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
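
A small sanity check (not part of the commit): `calc_psnr` implements PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1, so it expects inputs scaled to [0, 1]:

```python
import torch
from utils import calc_psnr

a = torch.zeros(1, 1, 8, 8)
b = torch.full((1, 1, 8, 8), 0.1)   # constant error of 0.1 -> MSE = 0.01
print(calc_psnr(a, b))              # tensor(20.) -- 10 * log10(1 / 0.01)
```
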
BIN  SRGAN-PyTorch-master/__pycache__/dataset.cpython-38.pyc  Normal file
Binary file not shown.
BIN  SRGAN-PyTorch-master/__pycache__/losses.cpython-38.pyc  Normal file
Binary file not shown.
BIN  SRGAN-PyTorch-master/__pycache__/mode.cpython-38.pyc  Normal file
Binary file not shown.
BIN  SRGAN-PyTorch-master/__pycache__/ops.cpython-38.pyc  Normal file
Binary file not shown.
BIN  SRGAN-PyTorch-master/__pycache__/srgan_model.cpython-38.pyc  Normal file
Binary file not shown.
BIN  SRGAN-PyTorch-master/__pycache__/vgg19.cpython-38.pyc  Normal file
Binary file not shown.
135  SRGAN-PyTorch-master/dataset.py  Normal file
@@ -0,0 +1,135 @@
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
import random


class mydata(Dataset):
    def __init__(self, LR_path, GT_path, in_memory = True, transform = None):

        self.LR_path = LR_path
        self.GT_path = GT_path
        self.in_memory = in_memory
        self.transform = transform

        self.LR_img = sorted(os.listdir(LR_path))
        self.GT_img = sorted(os.listdir(GT_path))

        # Optionally preload all images as uint8 arrays to avoid per-item disk I/O.
        if in_memory:
            self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr)).convert("RGB")).astype(np.uint8) for lr in self.LR_img]
            self.GT_img = [np.array(Image.open(os.path.join(self.GT_path, gt)).convert("RGB")).astype(np.uint8) for gt in self.GT_img]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        img_item = {}

        if self.in_memory:
            GT = self.GT_img[i].astype(np.float32)
            LR = self.LR_img[i].astype(np.float32)
        else:
            GT = np.array(Image.open(os.path.join(self.GT_path, self.GT_img[i])).convert("RGB"))
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB"))

        # Rescale pixels from [0, 255] to [-1, 1].
        img_item['GT'] = (GT / 127.5) - 1.0
        img_item['LR'] = (LR / 127.5) - 1.0

        if self.transform is not None:
            img_item = self.transform(img_item)

        # HWC -> CHW for PyTorch.
        img_item['GT'] = img_item['GT'].transpose(2, 0, 1).astype(np.float32)
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)

        return img_item


class testOnly_data(Dataset):
    def __init__(self, LR_path, in_memory = True, transform = None):  # transform is unused here

        self.LR_path = LR_path
        self.LR_img = sorted(os.listdir(LR_path))
        self.in_memory = in_memory

        if in_memory:
            self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr))) for lr in self.LR_img]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        img_item = {}

        if self.in_memory:
            LR = self.LR_img[i]
        else:
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])))

        img_item['LR'] = (LR / 127.5) - 1.0
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)

        return img_item


class crop(object):
    def __init__(self, scale, patch_size):

        self.scale = scale
        self.patch_size = patch_size

    def __call__(self, sample):
        LR_img, GT_img = sample['LR'], sample['GT']
        ih, iw = LR_img.shape[:2]

        # Pick a random LR patch and the aligned (scale x larger) GT patch.
        ix = random.randrange(0, iw - self.patch_size + 1)
        iy = random.randrange(0, ih - self.patch_size + 1)

        tx = ix * self.scale
        ty = iy * self.scale

        LR_patch = LR_img[iy : iy + self.patch_size, ix : ix + self.patch_size]
        GT_patch = GT_img[ty : ty + (self.scale * self.patch_size), tx : tx + (self.scale * self.patch_size)]

        return {'LR' : LR_patch, 'GT' : GT_patch}


class augmentation(object):

    def __call__(self, sample):
        LR_img, GT_img = sample['LR'], sample['GT']

        # Random horizontal/vertical flips and a 90-degree transpose,
        # applied identically to the LR and GT patches.
        hor_flip = random.randrange(0, 2)
        ver_flip = random.randrange(0, 2)
        rot = random.randrange(0, 2)

        if hor_flip:
            temp_LR = np.fliplr(LR_img)
            LR_img = temp_LR.copy()
            temp_GT = np.fliplr(GT_img)
            GT_img = temp_GT.copy()

            del temp_LR, temp_GT

        if ver_flip:
            temp_LR = np.flipud(LR_img)
            LR_img = temp_LR.copy()
            temp_GT = np.flipud(GT_img)
            GT_img = temp_GT.copy()

            del temp_LR, temp_GT

        if rot:
            LR_img = LR_img.transpose(1, 0, 2)
            GT_img = GT_img.transpose(1, 0, 2)

        return {'LR' : LR_img, 'GT' : GT_img}
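
A short sketch (not part of the commit) of how `crop` and `augmentation` compose with `torchvision.transforms.Compose`, since both operate on the `{'LR': ..., 'GT': ...}` sample dict:

```python
import numpy as np
from torchvision import transforms
from dataset import crop, augmentation

# For x4 training: a random 24x24 LR patch paired with the aligned 96x96 GT patch.
transform = transforms.Compose([crop(scale=4, patch_size=24), augmentation()])
sample = {'LR': np.zeros((100, 100, 3)), 'GT': np.zeros((400, 400, 3))}
out = transform(sample)
print(out['LR'].shape, out['GT'].shape)   # (24, 24, 3) (96, 96, 3)
```
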
2  SRGAN-PyTorch-master/information.txt  Normal file
@@ -0,0 +1,2 @@
Program written by dongheehand in Python
Available on GitHub at https://github.com/dongheehand/SRGAN-PyTorch
60  SRGAN-PyTorch-master/losses.py  Normal file
@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
from torchvision import transforms


class MeanShift(nn.Conv2d):
    # A frozen 1x1 convolution that applies per-channel normalization
    # (with sign = -1: subtract the mean, divide by the std).
    def __init__(
        self, rgb_range = 1,
        norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(norm_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(norm_mean) / std
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        for p in self.parameters():
            p.requires_grad = False


class perceptual_loss(nn.Module):

    def __init__(self, vgg):
        super(perceptual_loss, self).__init__()
        self.normalization_mean = [0.485, 0.456, 0.406]
        self.normalization_std = [0.229, 0.224, 0.225]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.transform = MeanShift(norm_mean = self.normalization_mean, norm_std = self.normalization_std).to(self.device)
        self.vgg = vgg
        self.criterion = nn.MSELoss()

    def forward(self, HR, SR, layer = 'relu5_4'):
        ## HR and SR should be normalized to [0, 1]
        hr = self.transform(HR)
        sr = self.transform(SR)

        hr_feat = getattr(self.vgg(hr), layer)
        sr_feat = getattr(self.vgg(sr), layer)

        return self.criterion(hr_feat, sr_feat), hr_feat, sr_feat


class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        # Total-variation loss: mean squared difference between neighboring pixels.
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()

        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]
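
To illustrate (not part of the commit): with the default `sign = -1`, `MeanShift` is exactly per-channel ImageNet normalization, `(x - mean) / std`, implemented as a frozen 1x1 convolution:

```python
import torch
from losses import MeanShift

shift = MeanShift()
x = torch.rand(1, 3, 4, 4)
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
print(torch.allclose(shift(x), (x - mean) / std, atol=1e-6))   # True
```
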
39  SRGAN-PyTorch-master/main.py  Normal file
@@ -0,0 +1,39 @@
from mode import *
import argparse


parser = argparse.ArgumentParser()

def str2bool(v):
    # `('true')` is just the string 'true', so the original membership test
    # matched any substring of it; an explicit comparison is what was meant.
    return v.lower() == 'true'

parser.add_argument("--LR_path", type = str, default = '../dataSet/DIV2K/DIV2K_train_LR_bicubic/X4')
parser.add_argument("--GT_path", type = str, default = '../dataSet/DIV2K/DIV2K_train_HR/')
parser.add_argument("--res_num", type = int, default = 16)
parser.add_argument("--num_workers", type = int, default = 0)
parser.add_argument("--batch_size", type = int, default = 16)
parser.add_argument("--L2_coeff", type = float, default = 1.0)
parser.add_argument("--adv_coeff", type = float, default = 1e-3)
parser.add_argument("--tv_loss_coeff", type = float, default = 0.0)
parser.add_argument("--pre_train_epoch", type = int, default = 8000)
parser.add_argument("--fine_train_epoch", type = int, default = 4000)
parser.add_argument("--scale", type = int, default = 4)
parser.add_argument("--patch_size", type = int, default = 24)
parser.add_argument("--feat_layer", type = str, default = 'relu5_4')
parser.add_argument("--vgg_rescale_coeff", type = float, default = 0.006)
parser.add_argument("--fine_tuning", type = str2bool, default = False)
parser.add_argument("--in_memory", type = str2bool, default = True)
parser.add_argument("--generator_path", type = str)
parser.add_argument("--mode", type = str, default = 'train')

args = parser.parse_args()

if args.mode == 'train':
    train(args)

elif args.mode == 'test':
    test(args)

elif args.mode == 'test_only':
    test_only(args)
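
For example (hypothetical paths; the checkpoint name follows the `pre_trained_model_%03d.pt` pattern used in `mode.py`), warm-starting the generator from a saved checkpoint would look like:

```bash
python main.py --mode train --fine_tuning True \
               --generator_path ./model/pre_trained_model_800.pt \
               --LR_path ./LR_imgs_dir --GT_path ./GT_imgs_dir
```
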
208  SRGAN-PyTorch-master/mode.py  Normal file
@@ -0,0 +1,208 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from losses import TVLoss, perceptual_loss
from dataset import *
from srgan_model import Generator, Discriminator
from vgg19 import vgg19
import numpy as np
from PIL import Image
from skimage.color import rgb2ycbcr
from skimage.measure import compare_psnr  # requires scikit-image < 0.16


def train(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([crop(args.scale, args.patch_size), augmentation()])
    dataset = mydata(GT_path = args.GT_path, LR_path = args.LR_path, in_memory = args.in_memory, transform = transform)
    loader = DataLoader(dataset, batch_size = args.batch_size, shuffle = True, num_workers = args.num_workers)

    generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num, scale = args.scale)

    if args.fine_tuning:
        generator.load_state_dict(torch.load(args.generator_path))
        print("pre-trained model is loaded")
        print("path : %s"%(args.generator_path))

    generator = generator.to(device)
    generator.train()

    l2_loss = nn.MSELoss()
    g_optim = optim.Adam(generator.parameters(), lr = 1e-4)

    pre_epoch = 0
    fine_epoch = 0

    #### Stage 1: pre-train the generator (SRResNet) with L2 loss only
    while pre_epoch < args.pre_train_epoch:
        for i, tr_data in enumerate(loader):
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            output, _ = generator(lr)
            loss = l2_loss(gt, output)

            g_optim.zero_grad()
            loss.backward()
            g_optim.step()

        pre_epoch += 1

        if pre_epoch % 2 == 0:
            print(pre_epoch)
            print(loss.item())
            print('=========')

        # Note: the './model' directory must exist before checkpoints are saved.
        if pre_epoch % 800 == 0:
            torch.save(generator.state_dict(), './model/pre_trained_model_%03d.pt'%pre_epoch)


    #### Stage 2: fine-tune with perceptual & adversarial loss
    vgg_net = vgg19().to(device)
    vgg_net = vgg_net.eval()

    discriminator = Discriminator(patch_size = args.patch_size * args.scale)
    discriminator = discriminator.to(device)
    discriminator.train()

    d_optim = optim.Adam(discriminator.parameters(), lr = 1e-4)
    scheduler = optim.lr_scheduler.StepLR(g_optim, step_size = 2000, gamma = 0.1)

    VGG_loss = perceptual_loss(vgg_net)
    cross_ent = nn.BCELoss()
    tv_loss = TVLoss()
    # Label tensors assume every batch has exactly batch_size samples.
    real_label = torch.ones((args.batch_size, 1)).to(device)
    fake_label = torch.zeros((args.batch_size, 1)).to(device)

    while fine_epoch < args.fine_train_epoch:

        # Decay the generator learning rate once per epoch.
        scheduler.step()

        for i, tr_data in enumerate(loader):
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            ## Train the discriminator
            output, _ = generator(lr)
            fake_prob = discriminator(output)
            real_prob = discriminator(gt)

            d_loss_real = cross_ent(real_prob, real_label)
            d_loss_fake = cross_ent(fake_prob, fake_label)

            d_loss = d_loss_real + d_loss_fake

            g_optim.zero_grad()
            d_optim.zero_grad()
            d_loss.backward()
            d_optim.step()

            ## Train the generator
            output, _ = generator(lr)
            fake_prob = discriminator(output)

            _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer = args.feat_layer)

            L2_loss = l2_loss(output, gt)
            percep_loss = args.vgg_rescale_coeff * _percep_loss
            adversarial_loss = args.adv_coeff * cross_ent(fake_prob, real_label)
            total_variance_loss = args.tv_loss_coeff * tv_loss(args.vgg_rescale_coeff * (hr_feat - sr_feat)**2)

            g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss

            g_optim.zero_grad()
            d_optim.zero_grad()
            g_loss.backward()
            g_optim.step()


        fine_epoch += 1

        if fine_epoch % 2 == 0:
            print(fine_epoch)
            print(g_loss.item())
            print(d_loss.item())
            print('=========')

        if fine_epoch % 500 == 0:
            torch.save(generator.state_dict(), './model/SRGAN_gene_%03d.pt'%fine_epoch)
            torch.save(discriminator.state_dict(), './model/SRGAN_discrim_%03d.pt'%fine_epoch)


def test(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = mydata(GT_path = args.GT_path, LR_path = args.LR_path, in_memory = False, transform = None)
    loader = DataLoader(dataset, batch_size = 1, shuffle = False, num_workers = args.num_workers)

    generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num)
    generator.load_state_dict(torch.load(args.generator_path))
    generator = generator.to(device)
    generator.eval()

    f = open('./result.txt', 'w')
    psnr_list = []

    with torch.no_grad():
        for i, te_data in enumerate(loader):
            gt = te_data['GT'].to(device)
            lr = te_data['LR'].to(device)

            bs, c, h, w = lr.size()
            gt = gt[:, :, : h * args.scale, : w * args.scale]

            output, _ = generator(lr)

            output = output[0].cpu().numpy()
            output = np.clip(output, -1.0, 1.0)
            gt = gt[0].cpu().numpy()

            # Map outputs from [-1, 1] back to [0, 1].
            output = (output + 1.0) / 2.0
            gt = (gt + 1.0) / 2.0

            output = output.transpose(1, 2, 0)
            gt = gt.transpose(1, 2, 0)

            # PSNR on the Y channel, with a `scale`-pixel border cropped off.
            y_output = rgb2ycbcr(output)[args.scale:-args.scale, args.scale:-args.scale, :1]
            y_gt = rgb2ycbcr(gt)[args.scale:-args.scale, args.scale:-args.scale, :1]

            psnr = compare_psnr(y_output / 255.0, y_gt / 255.0, data_range = 1.0)
            psnr_list.append(psnr)
            f.write('psnr : %04f \n' % psnr)

            # Note: the './result' directory must exist.
            result = Image.fromarray((output * 255.0).astype(np.uint8))
            result.save('./result/res_%04d.png'%i)

        f.write('avg psnr : %04f' % np.mean(psnr_list))


def test_only(args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = testOnly_data(LR_path = args.LR_path, in_memory = False, transform = None)
    loader = DataLoader(dataset, batch_size = 1, shuffle = False, num_workers = args.num_workers)

    generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num)
    generator.load_state_dict(torch.load(args.generator_path))
    generator = generator.to(device)
    generator.eval()

    with torch.no_grad():
        for i, te_data in enumerate(loader):
            lr = te_data['LR'].to(device)
            output, _ = generator(lr)
            output = output[0].cpu().numpy()
            output = (output + 1.0) / 2.0
            output = output.transpose(1, 2, 0)
            result = Image.fromarray((output * 255.0).astype(np.uint8))
            result.save('./result/res_%04d.png'%i)
97  SRGAN-PyTorch-master/ops.py  Normal file
@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class _conv(nn.Conv2d):
    # Conv2d initialized with N(0, 0.02) weights and zero bias.
    # Note: the `padding` and `bias` arguments are accepted but overridden
    # with kernel_size // 2 and True.
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
        super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels,
                                    kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True)

        self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02)
        self.bias.data = torch.zeros((out_channels))

        for p in self.parameters():
            p.requires_grad = True


class conv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True):
        super(conv, self).__init__()
        m = []
        m.append(_conv(in_channels = in_channel, out_channels = out_channel,
                       kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True))

        if BN:
            m.append(nn.BatchNorm2d(num_features = out_channel))

        if act is not None:
            m.append(act)

        self.body = nn.Sequential(*m)

    def forward(self, x):
        out = self.body(x)
        return out


class ResBlock(nn.Module):
    def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True):
        super(ResBlock, self).__init__()
        m = []
        m.append(conv(channels, channels, kernel_size, BN = True, act = act))
        m.append(conv(channels, channels, kernel_size, BN = True, act = None))
        self.body = nn.Sequential(*m)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, num_res_block, act = nn.ReLU(inplace = True)):
        super(BasicBlock, self).__init__()
        m = []

        self.conv = conv(in_channels, out_channels, kernel_size, BN = False, act = act)
        for i in range(num_res_block):
            m.append(ResBlock(out_channels, kernel_size, act))

        m.append(conv(out_channels, out_channels, kernel_size, BN = True, act = None))

        self.body = nn.Sequential(*m)

    def forward(self, x):
        res = self.conv(x)
        out = self.body(res)
        out += res

        return out


class Upsampler(nn.Module):
    # Sub-pixel upsampling: a conv expands channels by scale^2, then
    # PixelShuffle rearranges them into a (scale x) larger feature map.
    def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)):
        super(Upsampler, self).__init__()
        m = []
        m.append(conv(channel, channel * scale * scale, kernel_size))
        m.append(nn.PixelShuffle(scale))

        if act is not None:
            m.append(act)

        self.body = nn.Sequential(*m)

    def forward(self, x):
        out = self.body(x)
        return out


class discrim_block(nn.Module):
    def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)):
        super(discrim_block, self).__init__()
        m = []
        m.append(conv(in_feats, out_feats, kernel_size, BN = True, act = act))
        m.append(conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2))
        self.body = nn.Sequential(*m)

    def forward(self, x):
        out = self.body(x)
        return out
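
A quick shape check (not part of the commit) for `Upsampler`: the convolution expands channels by `scale**2` and `PixelShuffle` folds them back into spatial resolution:

```python
import torch
from ops import Upsampler

up = Upsampler(channel=64, kernel_size=3, scale=2)
x = torch.rand(1, 64, 24, 24)
print(up(x).shape)   # torch.Size([1, 64, 48, 48])
```
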
68  SRGAN-PyTorch-master/readme.md  Normal file
@@ -0,0 +1,68 @@
# Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

## Overview

An unofficial PyTorch implementation of SRGAN, as described in the paper:
* [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802)

Published in CVPR 2017

## Requirements
- Python 3.6.5
- PyTorch 1.1.0
- Pillow 5.1.0
- numpy 1.14.5
- scikit-image 0.15.0

## Datasets
- [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/)

## Pre-trained model
- [SRResNet](https://drive.google.com/open?id=15F2zOrOg2hIjdI0WsrOwF1y8REOkmmm0)
- [SRGAN](https://drive.google.com/open?id=1-HmcV5X94u411HRa-KEMcGhAO1OXAjAc)

## Train & Test
Train

```
python main.py --LR_path ./LR_imgs_dir --GT_path ./GT_imgs_dir
```

Test

```
python main.py --mode test --LR_path ./LR_imgs_dir --GT_path ./GT_imgs_dir --generator_path ./model/SRGAN.pt
```

Run inference on your own images

```
python main.py --mode test_only --LR_path ./LR_imgs_dir --generator_path ./model/SRGAN.pt
```

## Experimental Results
Experimental results on benchmarks.

### Quantitative Results

Average PSNR (dB), evaluated on the Y channel.

| Method | Set5 | Set14 | B100 |
|--------|------|-------|------|
| Bicubic | 28.43 | 25.99 | 25.94 |
| SRResNet (paper) | 32.05 | 28.49 | 27.58 |
| SRResNet (my model) | 31.96 | 28.48 | 27.49 |
| SRGAN (paper) | 29.40 | 26.02 | 25.16 |
| SRGAN (my model) | 29.93 | 26.95 | 26.10 |

### Qualitative Results

| Bicubic | SRResNet | SRGAN |
| --- | --- | --- |
| <img src="result/Set14_BIx4/comic_LRBI_x4.png"> | <img src="result/set14_srres_result/res_0004.png"> | <img src="result/set14_srgan_result/res_0004.png"> |
| <img src="result/Set5_BIx4/woman_LRBI_x4.png"> | <img src="result/set5_srres_result/res_0004.png"> | <img src="result/set5_srgan_result/res_0004.png"> |
| <img src="result/Set14_BIx4/baboon_LRBI_x4.png"> | <img src="result/set14_srres_result/res_0000.png"> | <img src="result/set14_srgan_result/res_0000.png"> |

## Comments
If you have any questions or comments on my code, please email me: [son1113@snu.ac.kr](mailto:son1113@snu.ac.kr)
74  SRGAN-PyTorch-master/srgan_model.py  Normal file
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
from ops import *


class Generator(nn.Module):

    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, num_block = 16, act = nn.PReLU(), scale = 4):
        super(Generator, self).__init__()

        self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = act)

        resblocks = [ResBlock(channels = n_feats, kernel_size = 3, act = act) for _ in range(num_block)]
        self.body = nn.Sequential(*resblocks)

        self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = True, act = None)

        # x4 upsampling is done as two x2 PixelShuffle stages.
        if(scale == 4):
            upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = 2, act = act) for _ in range(2)]
        else:
            upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = scale, act = act)]

        self.tail = nn.Sequential(*upsample_blocks)

        self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = 3, BN = False, act = nn.Tanh())

    def forward(self, x):

        x = self.conv01(x)
        _skip_connection = x

        x = self.body(x)
        x = self.conv02(x)
        feat = x + _skip_connection

        x = self.tail(feat)
        x = self.last_conv(x)

        # Returns the SR image and the pre-upsampling feature map.
        return x, feat


class Discriminator(nn.Module):

    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, act = nn.LeakyReLU(inplace = True), num_of_block = 3, patch_size = 96):
        super(Discriminator, self).__init__()
        self.act = act

        self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act)
        self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act, stride = 2)

        body = [discrim_block(in_feats = n_feats * (2 ** i), out_feats = n_feats * (2 ** (i + 1)), kernel_size = 3, act = self.act) for i in range(num_of_block)]
        self.body = nn.Sequential(*body)

        # Each of the (num_of_block + 1) stride-2 convs halves the spatial size.
        self.linear_size = ((patch_size // (2 ** (num_of_block + 1))) ** 2) * (n_feats * (2 ** num_of_block))

        tail = []

        tail.append(nn.Linear(self.linear_size, 1024))
        tail.append(self.act)
        tail.append(nn.Linear(1024, 1))
        tail.append(nn.Sigmoid())

        self.tail = nn.Sequential(*tail)

    def forward(self, x):

        x = self.conv01(x)
        x = self.conv02(x)
        x = self.body(x)
        x = x.view(-1, self.linear_size)
        x = self.tail(x)

        return x
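
A minimal shape check (not part of the commit): with `scale = 4`, a 24x24 LR patch maps to a 96x96 SR image, which matches the discriminator's default `patch_size` of 96:

```python
import torch
from srgan_model import Generator, Discriminator

g = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=16, scale=4)
d = Discriminator(patch_size=96)
lr = torch.rand(1, 3, 24, 24)
sr, feat = g(lr)
print(sr.shape)      # torch.Size([1, 3, 96, 96])
print(d(sr).shape)   # torch.Size([1, 1]) -- probability that the patch is real
```
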
80  SRGAN-PyTorch-master/vgg19.py  Normal file
@@ -0,0 +1,80 @@
from torchvision import models
from collections import namedtuple
import torch
import torch.nn as nn


class vgg19(nn.Module):

    def __init__(self, pre_trained = True, require_grad = False):
        super(vgg19, self).__init__()
        self.vgg_feature = models.vgg19(pretrained = pre_trained).features
        self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature]
        self.vgg_layer = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
                          'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
                          'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
                          'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
                          'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5']

        if not require_grad:
            for parameter in self.parameters():
                parameter.requires_grad = False

    def forward(self, x):

        conv1_1 = self.seq_list[0](x)
        relu1_1 = self.seq_list[1](conv1_1)
        conv1_2 = self.seq_list[2](relu1_1)
        relu1_2 = self.seq_list[3](conv1_2)
        pool1 = self.seq_list[4](relu1_2)

        conv2_1 = self.seq_list[5](pool1)
        relu2_1 = self.seq_list[6](conv2_1)
        conv2_2 = self.seq_list[7](relu2_1)
        relu2_2 = self.seq_list[8](conv2_2)
        pool2 = self.seq_list[9](relu2_2)

        conv3_1 = self.seq_list[10](pool2)
        relu3_1 = self.seq_list[11](conv3_1)
        conv3_2 = self.seq_list[12](relu3_1)
        relu3_2 = self.seq_list[13](conv3_2)
        conv3_3 = self.seq_list[14](relu3_2)
        relu3_3 = self.seq_list[15](conv3_3)
        conv3_4 = self.seq_list[16](relu3_3)
        relu3_4 = self.seq_list[17](conv3_4)
        pool3 = self.seq_list[18](relu3_4)

        conv4_1 = self.seq_list[19](pool3)
        relu4_1 = self.seq_list[20](conv4_1)
        conv4_2 = self.seq_list[21](relu4_1)
        relu4_2 = self.seq_list[22](conv4_2)
        conv4_3 = self.seq_list[23](relu4_2)
        relu4_3 = self.seq_list[24](conv4_3)
        conv4_4 = self.seq_list[25](relu4_3)
        relu4_4 = self.seq_list[26](conv4_4)
        pool4 = self.seq_list[27](relu4_4)

        conv5_1 = self.seq_list[28](pool4)
        relu5_1 = self.seq_list[29](conv5_1)
        conv5_2 = self.seq_list[30](relu5_1)
        relu5_2 = self.seq_list[31](conv5_2)
        conv5_3 = self.seq_list[32](relu5_2)
        relu5_3 = self.seq_list[33](conv5_3)
        conv5_4 = self.seq_list[34](relu5_3)
        relu5_4 = self.seq_list[35](conv5_4)
        pool5 = self.seq_list[36](relu5_4)

        vgg_output = namedtuple("vgg_output", self.vgg_layer)

        vgg_list = [conv1_1, relu1_1, conv1_2, relu1_2, pool1,
                    conv2_1, relu2_1, conv2_2, relu2_2, pool2,
                    conv3_1, relu3_1, conv3_2, relu3_2, conv3_3, relu3_3, conv3_4, relu3_4, pool3,
                    conv4_1, relu4_1, conv4_2, relu4_2, conv4_3, relu4_3, conv4_4, relu4_4, pool4,
                    conv5_1, relu5_1, conv5_2, relu5_2, conv5_3, relu5_3, conv5_4, relu5_4, pool5]

        out = vgg_output(*vgg_list)

        return out
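
To illustrate (not part of the commit) how `perceptual_loss` selects a layer: the forward pass returns every intermediate activation as a namedtuple, so a single attribute access retrieves, e.g., `relu5_4`:

```python
import torch
from vgg19 import vgg19

net = vgg19(pre_trained=True).eval()   # downloads ImageNet weights on first use
x = torch.rand(1, 3, 96, 96)
with torch.no_grad():
    feats = net(x)
print(feats.relu5_4.shape)   # torch.Size([1, 512, 6, 6]) -- 96 halved by pool1..pool4
```
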
211
Super-résolution.pyw
Normal file
211
Super-résolution.pyw
Normal file
|
@ -0,0 +1,211 @@
|
||||||
|
from tkinter import *
|
||||||
|
from tkinter.filedialog import askopenfilename
|
||||||
|
from tkinter.messagebox import showerror
|
||||||
|
import os, subprocess, shutil, glob, time
|
||||||
|
from PIL import Image, ImageTk
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
class SRGUI():
|
||||||
|
def __init__(self):
|
||||||
|
self.root = Tk()
|
||||||
|
|
||||||
|
self.input_img_path = StringVar()
|
||||||
|
|
||||||
|
self.canvas = {}
|
||||||
|
self.image = {}
|
||||||
|
self.imagetk = {}
|
||||||
|
self.id_canvas = {}
|
||||||
|
self.label = {}
|
||||||
|
|
||||||
|
self.default_diff = 20
|
||||||
|
self.diff = self.default_diff
|
||||||
|
|
||||||
|
        ##### INPUT
        self.frame_input = Frame(self.root)
        self.frame_input.grid(row=1, column=1, rowspan=2)

        Label(self.frame_input, text="Input", font=("Purisa", 18)).grid(row=1, column=1, columnspan=2)
        self.canvas["input"] = Canvas(self.frame_input, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["input"].grid(row=2, column=1, columnspan=2)
        Entry(self.frame_input, textvariable=self.input_img_path, width=60).grid(row=3, column=1, sticky="NEWS")

        Button(self.frame_input, text="...", relief=RIDGE, command=self.select_img).grid(row=3, column=2, sticky="NEWS")
        Button(self.frame_input, text="Compute", relief=RIDGE, command=self.calcul).grid(row=4, column=1, columnspan=2, sticky="NEWS")

        def zoom(event):
            # Map the cursor position on the 400x400 canvas back to source-image
            # coordinates and display a (2 * self.diff)-wide crop, magnified.
            for type in ["input", "bil", "bic", "SRCNN", "SRGAN", "original"]:
                if type in self.id_canvas:
                    x1 = ((event.x - self.diff) * self.image[type].width) // 400
                    x2 = ((event.x + self.diff) * self.image[type].width) // 400
                    y1 = ((event.y - self.diff) * self.image[type].height) // 400
                    y2 = ((event.y + self.diff) * self.image[type].height) // 400

                    _img = self.image[type].crop((x1, y1, x2, y2))
                    self.imagetk[type] = ImageTk.PhotoImage(_img.resize((400, 400)))
                    self.canvas[type].itemconfig(self.id_canvas[type], image=self.imagetk[type])

        def change_diff(event):
            # Mouse wheel narrows or widens the zoom window.
            if event.delta > 0: self.diff -= 5
            else: self.diff += 5
            zoom(event)

        def reset(event):
            # Restore every canvas to the full, unzoomed image.
            for type in ["input", "bil", "bic", "SRCNN", "SRGAN", "original"]:
                if type in self.id_canvas:
                    self.imagetk[type] = ImageTk.PhotoImage(self.image[type].resize((400, 400)))
                    self.canvas[type].itemconfig(self.id_canvas[type], image=self.imagetk[type])

        def activate_zoom(event):
            self.canvas["input"].bind("<Motion>", zoom)

        def deactivate_zoom(event):
            self.canvas["input"].unbind("<Motion>")

        # Hold the left button to zoom, right-click to reset, wheel to resize the window.
        self.canvas["input"].bind("<ButtonPress-1>", activate_zoom)
        self.canvas["input"].bind("<ButtonRelease-1>", deactivate_zoom)

        self.canvas["input"].bind("<Button-3>", reset)
        self.canvas["input"].bind("<MouseWheel>", change_diff)

        ##### BILINEAR
        self.frame_bil = Frame(self.root)
        self.frame_bil.grid(row=1, column=3)

        Label(self.frame_bil, text="Bilinear", font=("Purisa", 18)).grid(row=1, column=1)
        self.canvas["bil"] = Canvas(self.frame_bil, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["bil"].grid(row=2, column=1)

        self.label["bil"] = Label(self.frame_bil, fg="gray")
        self.label["bil"].grid(row=3, column=1)

        ##### BICUBIC
        self.frame_bic = Frame(self.root)
        self.frame_bic.grid(row=2, column=3)

        Label(self.frame_bic, text="Bicubic", font=("Purisa", 18)).grid(row=1, column=1)
        self.canvas["bic"] = Canvas(self.frame_bic, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["bic"].grid(row=2, column=1)

        self.label["bic"] = Label(self.frame_bic, fg="gray")
        self.label["bic"].grid(row=3, column=1)

        ##### SRCNN
        self.frame_SRCNN = Frame(self.root)
        self.frame_SRCNN.grid(row=1, column=4)

        Label(self.frame_SRCNN, text="SRCNN", font=("Purisa", 18)).grid(row=1, column=1)
        self.canvas["SRCNN"] = Canvas(self.frame_SRCNN, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["SRCNN"].grid(row=2, column=1)

        self.label["SRCNN"] = Label(self.frame_SRCNN, fg="gray")
        self.label["SRCNN"].grid(row=3, column=1)

        ##### SRGAN
        self.frame_SRGAN = Frame(self.root)
        self.frame_SRGAN.grid(row=2, column=4)

        Label(self.frame_SRGAN, text="SRGAN", font=("Purisa", 18)).grid(row=1, column=1)
        self.canvas["SRGAN"] = Canvas(self.frame_SRGAN, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["SRGAN"].grid(row=2, column=1)

        self.label["SRGAN"] = Label(self.frame_SRGAN, fg="gray")
        self.label["SRGAN"].grid(row=3, column=1)

        ##### ORIGINAL
        # Only gridded in select_img() when a matching "*_original.*" file
        # exists next to the input image.
        self.frame_original = Frame(self.root)

        Label(self.frame_original, text="Original", font=("Purisa", 18)).grid(row=1, column=1)
        self.canvas["original"] = Canvas(self.frame_original, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["original"].grid(row=2, column=1)

    def select_img(self):
        self.input_img_path.set(askopenfilename(filetypes=[('Images', '*.png *.jpg *.jpeg *.bmp *.gif')]))
        if os.path.exists(self.input_img_path.get()):
            self.image["input"] = Image.open(self.input_img_path.get())
            self.imagetk["input"] = ImageTk.PhotoImage(self.image["input"].resize((400, 400)))
            self.id_canvas["input"] = self.canvas["input"].create_image(204, 204, image=self.imagetk["input"])

            # Look for a ground-truth file named "<stem>_original.<ext>".
            _l = self.input_img_path.get().split(".")
            _l[-2] += "_original"
            original_img_path = ".".join(_l)
            if os.path.exists(original_img_path):
                self.image["original"] = Image.open(original_img_path)
                self.imagetk["original"] = ImageTk.PhotoImage(self.image["original"].resize((400, 400)))
                self.id_canvas["original"] = self.canvas["original"].create_image(204, 204, image=self.imagetk["original"])

                self.frame_original.grid(row=1, column=5, rowspan=2)
            else: self.frame_original.grid_forget()

        else: showerror("Error", "This file does not exist.")

    def calcul(self):
        def thread_func():
            type2func = {"bil": self.Bilinear, "bic": self.Bicubic, "SRCNN": self.SRCNN, "SRGAN": self.SRGAN}
            for type in type2func:
                try:
                    t1 = time.time()
                    type2func[type]()
                    t2 = time.time()
                # Report per-method status on that method's own label
                # (the original always wrote to self.label["bil"]).
                except Exception: self.label[type].config(text="Error", fg="red")
                else: self.label[type].config(text="Computation time: %0.2fs" % (t2 - t1), fg="gray")

        # Run the four upscalers off the UI thread so the window stays responsive.
        thfunc = Thread(target=thread_func)
        thfunc.daemon = True
        thfunc.start()
    def Bilinear(self):
        # Fixed 1600x1600 target (400 * 4); the aspect ratio is not preserved.
        self.image["bil"] = Image.open(self.input_img_path.get()).resize((400 * 4, 400 * 4), Image.BILINEAR)
        self.imagetk["bil"] = ImageTk.PhotoImage(self.image["bil"].resize((400, 400)))
        self.id_canvas["bil"] = self.canvas["bil"].create_image(204, 204, image=self.imagetk["bil"])

    def Bicubic(self):
        self.image["bic"] = Image.open(self.input_img_path.get()).resize((400 * 4, 400 * 4), Image.BICUBIC)
        self.imagetk["bic"] = ImageTk.PhotoImage(self.image["bic"].resize((400, 400)))
        self.id_canvas["bic"] = self.canvas["bic"].create_image(204, 204, image=self.imagetk["bic"])

    def SRCNN(self):
        # Pre-enlarge the input x4 with bicubic before handing it to test.py.
        SRCNN_img_path = "./SRCNN-pytorch-master/data/" + os.path.basename(self.input_img_path.get())
        self.image["input"].resize((self.image["input"].width * 4, self.image["input"].height * 4), Image.BICUBIC).save(SRCNN_img_path)

        subprocess.call(
            '"D:/ProgramData/Anaconda3/python.exe" \
            "./SRCNN-pytorch-master/test.py" \
            --weights-file "./SRCNN-pytorch-master/weights/srcnn_x4.pth" \
            --image-file "%s" \
            --scale 4' % SRCNN_img_path)

        # test.py writes its result next to the input as "<stem>_srcnn_x4.<ext>".
        l = SRCNN_img_path.split(".")
        l[-2] += "_srcnn_x4"
        self.SRCNN_img_path = ".".join(l)

        self.image["SRCNN"] = Image.open(self.SRCNN_img_path)
        self.imagetk["SRCNN"] = ImageTk.PhotoImage(self.image["SRCNN"].resize((400, 400)))
        self.id_canvas["SRCNN"] = self.canvas["SRCNN"].create_image(204, 204, image=self.imagetk["SRCNN"])
    def SRGAN(self):
        SRGAN_img_path = "./SRGAN-PyTorch-master/LR_imgs_dir/" + os.path.basename(self.input_img_path.get())
        shutil.copy(self.input_img_path.get(), SRGAN_img_path)

        # main.py resolves its paths relative to the SRGAN repository root.
        workingdir = os.getcwd()
        os.chdir("./SRGAN-PyTorch-master/")

        subprocess.call(
            '"D:/ProgramData/Anaconda3/python.exe" \
            main.py \
            --mode test_only \
            --LR_path LR_imgs_dir/ \
            --generator_path model/SRGAN.pt')

        os.chdir(workingdir)

        os.remove(SRGAN_img_path)

        # Take the last res_*.png produced in the result directory.
        self.SRGAN_img_path = glob.glob("./SRGAN-PyTorch-master/result/res_*.png")[-1]

        self.image["SRGAN"] = Image.open(self.SRGAN_img_path)
        self.imagetk["SRGAN"] = ImageTk.PhotoImage(self.image["SRGAN"].resize((400, 400)))
        self.id_canvas["SRGAN"] = self.canvas["SRGAN"].create_image(204, 204, image=self.imagetk["SRGAN"])

App = SRGUI()
mainloop()