Initial commit

ℝλℙℍλΣʟ 2020-11-12 23:30:33 +01:00 committed by GitHub
commit 50e4c082d8
26 changed files with 1477 additions and 0 deletions

@@ -0,0 +1,129 @@
# SRCNN
This repository is an implementation of ["Image Super-Resolution Using Deep Convolutional Networks"](https://arxiv.org/abs/1501.00092).
<center><img src="./thumbnails/fig1.png"></center>
## Differences from the original
- Added zero-padding
- Used Adam instead of SGD
- Removed the weight initialization
## Requirements
- PyTorch 1.0.0
- Numpy 1.15.4
- Pillow 5.4.1
- h5py 2.8.0
- tqdm 4.30.0
## Train
The 91-image and Set5 datasets, converted to HDF5, can be downloaded from the links below.
| Dataset | Scale | Type | Link |
|---------|-------|------|------|
| 91-image | 2 | Train | [Download](https://www.dropbox.com/s/2hsah93sxgegsry/91-image_x2.h5?dl=0) |
| 91-image | 3 | Train | [Download](https://www.dropbox.com/s/curldmdf11iqakd/91-image_x3.h5?dl=0) |
| 91-image | 4 | Train | [Download](https://www.dropbox.com/s/22afykv4amfxeio/91-image_x4.h5?dl=0) |
| Set5 | 2 | Eval | [Download](https://www.dropbox.com/s/r8qs6tp395hgh8g/Set5_x2.h5?dl=0) |
| Set5 | 3 | Eval | [Download](https://www.dropbox.com/s/58ywjac4te3kbqq/Set5_x3.h5?dl=0) |
| Set5 | 4 | Eval | [Download](https://www.dropbox.com/s/0rz86yn3nnrodlb/Set5_x4.h5?dl=0) |
Otherwise, you can use `prepare.py` to create a custom dataset.
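For example, you can build your own x3 training file from a directory of images like this (a sketch: the `BLAH_BLAH` paths are placeholders, the flags are those defined in `prepare.py` later in this commit, and adding `--eval` switches to the evaluation-set layout):
```bash
python prepare.py --images-dir "BLAH_BLAH/91-image" \
                  --output-path "BLAH_BLAH/91-image_x3.h5" \
                  --scale 3
```
Training is then launched with: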
```bash
python train.py --train-file "BLAH_BLAH/91-image_x3.h5" \
--eval-file "BLAH_BLAH/Set5_x3.h5" \
--outputs-dir "BLAH_BLAH/outputs" \
--scale 3 \
--lr 1e-4 \
--batch-size 16 \
--num-epochs 400 \
--num-workers 8 \
--seed 123
```
## Test
Pre-trained weights can be downloaded from the links below.
| Model | Scale | Link |
|-------|-------|------|
| 9-5-5 | 2 | [Download](https://www.dropbox.com/s/rxluu1y8ptjm4rn/srcnn_x2.pth?dl=0) |
| 9-5-5 | 3 | [Download](https://www.dropbox.com/s/zn4fdobm2kw0c58/srcnn_x3.pth?dl=0) |
| 9-5-5 | 4 | [Download](https://www.dropbox.com/s/pd5b2ketm0oamhj/srcnn_x4.pth?dl=0) |
The result is saved in the same directory as the input image, with `_srcnn_x{scale}` appended to the file name.
```bash
python test.py --weights-file "BLAH_BLAH/srcnn_x3.pth" \
--image-file "data/butterfly_GT.bmp" \
--scale 3
```
## Results
We used the following network settings for our experiments: f<sub>1</sub>=9, f<sub>2</sub>=5, f<sub>3</sub>=5, n<sub>1</sub>=64, n<sub>2</sub>=32, n<sub>3</sub>=1.
PSNR was calculated on the Y channel.
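For reference, a minimal sketch of that measurement using the helpers from `utils.py` in this commit (the image file names are illustrative; both images must have the same dimensions):
```python
import numpy as np
import PIL.Image as pil_image
import torch

from utils import convert_rgb_to_y, calc_psnr

# illustrative paths: a ground-truth image and a super-resolved output
hr = np.array(pil_image.open('hr.bmp').convert('RGB')).astype(np.float32)
sr = np.array(pil_image.open('sr.bmp').convert('RGB')).astype(np.float32)

# extract the luminance (Y) channel and rescale to [0, 1] before comparing
y_hr = torch.from_numpy(convert_rgb_to_y(hr) / 255.)
y_sr = torch.from_numpy(convert_rgb_to_y(sr) / 255.)
print('PSNR: {:.2f}'.format(calc_psnr(y_hr, y_sr)))
```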
### Set5
| Eval. Mat | Scale | SRCNN | SRCNN (Ours) |
|-----------|-------|-------|--------------|
| PSNR | 2 | 36.66 | 36.65 |
| PSNR | 3 | 32.75 | 33.29 |
| PSNR | 4 | 30.49 | 30.25 |
<table>
<tr>
<td><center>Original</center></td>
<td><center>BICUBIC x3</center></td>
<td><center>SRCNN x3 (27.53 dB)</center></td>
</tr>
<tr>
<td>
<center><img src="./data/butterfly_GT.bmp""></center>
</td>
<td>
<center><img src="./data/butterfly_GT_bicubic_x3.bmp"></center>
</td>
<td>
<center><img src="./data/butterfly_GT_srcnn_x3.bmp"></center>
</td>
</tr>
<tr>
<td><center>Original</center></td>
<td><center>BICUBIC x3</center></td>
<td><center>SRCNN x3 (29.30 dB)</center></td>
</tr>
<tr>
<td>
<center><img src="./data/zebra.bmp""></center>
</td>
<td>
<center><img src="./data/zebra_bicubic_x3.bmp"></center>
</td>
<td>
<center><img src="./data/zebra_srcnn_x3.bmp"></center>
</td>
</tr>
<tr>
<td><center>Original</center></td>
<td><center>BICUBIC x3</center></td>
<td><center>SRCNN x3 (28.58 dB)</center></td>
</tr>
<tr>
<td>
<center><img src="./data/ppt3.bmp""></center>
</td>
<td>
<center><img src="./data/ppt3_bicubic_x3.bmp"></center>
</td>
<td>
<center><img src="./data/ppt3_srcnn_x3.bmp"></center>
</td>
</tr>
</table>

Binary file not shown.

Binary file not shown.

@@ -0,0 +1,31 @@
import h5py
import numpy as np
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    def __init__(self, h5_file):
        super(TrainDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][idx] / 255., 0), np.expand_dims(f['hr'][idx] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])


class EvalDataset(Dataset):
    def __init__(self, h5_file):
        super(EvalDataset, self).__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        with h5py.File(self.h5_file, 'r') as f:
            return np.expand_dims(f['lr'][str(idx)][:, :] / 255., 0), np.expand_dims(f['hr'][str(idx)][:, :] / 255., 0)

    def __len__(self):
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])
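As a quick usage sketch (the HDF5 path is a placeholder; `datasets` is the module name used by `train.py` in this commit), each training item is a `(lr, hr)` pair of `(1, H, W)` arrays scaled to `[0, 1]`, ready for a standard `DataLoader`:
```python
from torch.utils.data import DataLoader
from datasets import TrainDataset

dataset = TrainDataset('BLAH_BLAH/91-image_x3.h5')  # placeholder path
loader = DataLoader(dataset, batch_size=16, shuffle=True)
lr, hr = next(iter(loader))
print(lr.shape, hr.shape)  # torch.Size([16, 1, 33, 33]) twice, with the default 33x33 patches
```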

@@ -0,0 +1,2 @@
Program written by: yjn870
Code available on GitHub at: https://github.com/yjn870/SRCNN-pytorch

@@ -0,0 +1,16 @@
from torch import nn


class SRCNN(nn.Module):
    def __init__(self, num_channels=1):
        super(SRCNN, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=9 // 2)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=5 // 2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=5 // 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)
        return x
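As a quick shape check (a sketch; the module is imported as `models` elsewhere in this commit), the added zero-padding keeps the spatial size unchanged:
```python
import torch
from models import SRCNN

model = SRCNN(num_channels=1)
x = torch.randn(1, 1, 33, 33)  # one 33x33 Y-channel patch
print(model(x).shape)          # torch.Size([1, 1, 33, 33]) -- same height and width
```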

@@ -0,0 +1,78 @@
import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y


def train(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_patches = []
    hr_patches = []

    for image_path in sorted(glob.glob('{}/*'.format(args.images_dir))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        for i in range(0, lr.shape[0] - args.patch_size + 1, args.stride):
            for j in range(0, lr.shape[1] - args.patch_size + 1, args.stride):
                lr_patches.append(lr[i:i + args.patch_size, j:j + args.patch_size])
                hr_patches.append(hr[i:i + args.patch_size, j:j + args.patch_size])

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)

    h5_file.create_dataset('lr', data=lr_patches)
    h5_file.create_dataset('hr', data=hr_patches)

    h5_file.close()


def eval(args):
    h5_file = h5py.File(args.output_path, 'w')

    lr_group = h5_file.create_group('lr')
    hr_group = h5_file.create_group('hr')

    for i, image_path in enumerate(sorted(glob.glob('{}/*'.format(args.images_dir)))):
        hr = pil_image.open(image_path).convert('RGB')
        hr_width = (hr.width // args.scale) * args.scale
        hr_height = (hr.height // args.scale) * args.scale
        hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
        lr = hr.resize((hr_width // args.scale, hr_height // args.scale), resample=pil_image.BICUBIC)
        lr = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)
        hr = np.array(hr).astype(np.float32)
        lr = np.array(lr).astype(np.float32)
        hr = convert_rgb_to_y(hr)
        lr = convert_rgb_to_y(lr)

        lr_group.create_dataset(str(i), data=lr)
        hr_group.create_dataset(str(i), data=hr)

    h5_file.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--images-dir', type=str, required=True)
    parser.add_argument('--output-path', type=str, required=True)
    parser.add_argument('--patch-size', type=int, default=33)
    parser.add_argument('--stride', type=int, default=14)
    parser.add_argument('--scale', type=int, default=2)
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()

    if not args.eval:
        train(args)
    else:
        eval(args)

@@ -0,0 +1,67 @@
import argparse

import torch
import torch.backends.cudnn as cudnn
import numpy as np
import PIL.Image as pil_image

from models import SRCNN
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)

    state_dict = model.state_dict()
    for n, p in torch.load(args.weights_file, map_location=lambda storage, loc: storage).items():
        if n in state_dict.keys():
            state_dict[n].copy_(p)
        else:
            raise KeyError(n)

    model.eval()

    image = pil_image.open(args.image_file).convert('RGB')

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)

    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    y = ycbcr[..., 0]
    y /= 255.
    y = torch.from_numpy(y).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    psnr = calc_psnr(y, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)

    l = args.image_file.split(".")
    l[-2] += '_srcnn_x{}'.format(args.scale)
    args.image_file = ".".join(l)
    output.save(args.image_file)

@@ -0,0 +1,112 @@
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from models import SRCNN
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train-file', type=str, required=True)
    parser.add_argument('--eval-file', type=str, required=True)
    parser.add_argument('--outputs-dir', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--num-epochs', type=int, default=400)
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--seed', type=int, default=123)
    args = parser.parse_args()

    args.outputs_dir = os.path.join(args.outputs_dir, 'x{}'.format(args.scale))

    if not os.path.exists(args.outputs_dir):
        os.makedirs(args.outputs_dir)

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(args.seed)

    model = SRCNN().to(device)
    criterion = nn.MSELoss()
    # as in the paper, the last layer is trained with a 10x smaller learning rate
    optimizer = optim.Adam([
        {'params': model.conv1.parameters()},
        {'params': model.conv2.parameters()},
        {'params': model.conv3.parameters(), 'lr': args.lr * 0.1}
    ], lr=args.lr)

    train_dataset = TrainDataset(args.train_file)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  pin_memory=True,
                                  drop_last=True)
    eval_dataset = EvalDataset(args.eval_file)
    eval_dataloader = DataLoader(dataset=eval_dataset, batch_size=1)

    best_weights = copy.deepcopy(model.state_dict())
    best_epoch = 0
    best_psnr = 0.0

    for epoch in range(args.num_epochs):
        model.train()
        epoch_losses = AverageMeter()

        with tqdm(total=(len(train_dataset) - len(train_dataset) % args.batch_size)) as t:
            t.set_description('epoch: {}/{}'.format(epoch, args.num_epochs - 1))

            for data in train_dataloader:
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)

                preds = model(inputs)
                loss = criterion(preds, labels)
                epoch_losses.update(loss.item(), len(inputs))

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                t.set_postfix(loss='{:.6f}'.format(epoch_losses.avg))
                t.update(len(inputs))

        torch.save(model.state_dict(), os.path.join(args.outputs_dir, 'epoch_{}.pth'.format(epoch)))

        model.eval()
        epoch_psnr = AverageMeter()

        for data in eval_dataloader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                preds = model(inputs).clamp(0.0, 1.0)

            epoch_psnr.update(calc_psnr(preds, labels), len(inputs))

        print('eval psnr: {:.2f}'.format(epoch_psnr.avg))

        if epoch_psnr.avg > best_psnr:
            best_epoch = epoch
            best_psnr = epoch_psnr.avg
            best_weights = copy.deepcopy(model.state_dict())

    print('best epoch: {}, psnr: {:.2f}'.format(best_epoch, best_psnr))
    torch.save(best_weights, os.path.join(args.outputs_dir, 'best.pth'))

@@ -0,0 +1,68 @@
import torch
import numpy as np


def convert_rgb_to_y(img):
    if type(img) == np.ndarray:
        return 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        return 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
    else:
        raise Exception('Unknown Type', type(img))


def convert_rgb_to_ycbcr(img):
    if type(img) == np.ndarray:
        y = 16. + (64.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        y = 16. + (64.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        # stack the three 2D channel maps into (3, H, W) before permuting to (H, W, 3);
        # torch.cat on 2D tensors would produce a 2D result that permute(1, 2, 0) rejects
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))


def convert_ycbcr_to_rgb(img):
    if type(img) == np.ndarray:
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif type(img) == torch.Tensor:
        if len(img.shape) == 4:
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        # same fix as above: stack, then permute to channel-last layout
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    else:
        raise Exception('Unknown Type', type(img))


def calc_psnr(img1, img2):
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
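A small self-check of the color conversions (a sketch using random data; a round-trip error of about one intensity level out of 255 is expected from the rounded coefficients):
```python
import numpy as np
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb

rgb = (np.random.rand(8, 8, 3) * 255.).astype(np.float32)
ycbcr = convert_rgb_to_ycbcr(rgb)
restored = convert_ycbcr_to_rgb(ycbcr)
print(np.abs(restored - rgb).max())  # expect a small round-trip error, roughly one level
```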

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

@@ -0,0 +1,135 @@
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np
import random


class mydata(Dataset):
    def __init__(self, LR_path, GT_path, in_memory = True, transform = None):
        self.LR_path = LR_path
        self.GT_path = GT_path
        self.in_memory = in_memory
        self.transform = transform

        self.LR_img = sorted(os.listdir(LR_path))
        self.GT_img = sorted(os.listdir(GT_path))

        if in_memory:
            self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr)).convert("RGB")).astype(np.uint8) for lr in self.LR_img]
            self.GT_img = [np.array(Image.open(os.path.join(self.GT_path, gt)).convert("RGB")).astype(np.uint8) for gt in self.GT_img]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        img_item = {}

        if self.in_memory:
            GT = self.GT_img[i].astype(np.float32)
            LR = self.LR_img[i].astype(np.float32)
        else:
            GT = np.array(Image.open(os.path.join(self.GT_path, self.GT_img[i])).convert("RGB"))
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])).convert("RGB"))

        img_item['GT'] = (GT / 127.5) - 1.0
        img_item['LR'] = (LR / 127.5) - 1.0

        if self.transform is not None:
            img_item = self.transform(img_item)

        img_item['GT'] = img_item['GT'].transpose(2, 0, 1).astype(np.float32)
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)

        return img_item


class testOnly_data(Dataset):
    def __init__(self, LR_path, in_memory = True, transform = None):
        self.LR_path = LR_path
        self.LR_img = sorted(os.listdir(LR_path))
        self.in_memory = in_memory

        if in_memory:
            self.LR_img = [np.array(Image.open(os.path.join(self.LR_path, lr))) for lr in self.LR_img]

    def __len__(self):
        return len(self.LR_img)

    def __getitem__(self, i):
        img_item = {}

        if self.in_memory:
            LR = self.LR_img[i]
        else:
            LR = np.array(Image.open(os.path.join(self.LR_path, self.LR_img[i])))

        img_item['LR'] = (LR / 127.5) - 1.0
        img_item['LR'] = img_item['LR'].transpose(2, 0, 1).astype(np.float32)

        return img_item


class crop(object):
    def __init__(self, scale, patch_size):
        self.scale = scale
        self.patch_size = patch_size

    def __call__(self, sample):
        LR_img, GT_img = sample['LR'], sample['GT']
        ih, iw = LR_img.shape[:2]

        ix = random.randrange(0, iw - self.patch_size + 1)
        iy = random.randrange(0, ih - self.patch_size + 1)

        tx = ix * self.scale
        ty = iy * self.scale

        LR_patch = LR_img[iy : iy + self.patch_size, ix : ix + self.patch_size]
        GT_patch = GT_img[ty : ty + (self.scale * self.patch_size), tx : tx + (self.scale * self.patch_size)]

        return {'LR' : LR_patch, 'GT' : GT_patch}


class augmentation(object):
    def __call__(self, sample):
        LR_img, GT_img = sample['LR'], sample['GT']

        hor_flip = random.randrange(0, 2)
        ver_flip = random.randrange(0, 2)
        rot = random.randrange(0, 2)

        if hor_flip:
            temp_LR = np.fliplr(LR_img)
            LR_img = temp_LR.copy()
            temp_GT = np.fliplr(GT_img)
            GT_img = temp_GT.copy()
            del temp_LR, temp_GT

        if ver_flip:
            temp_LR = np.flipud(LR_img)
            LR_img = temp_LR.copy()
            temp_GT = np.flipud(GT_img)
            GT_img = temp_GT.copy()
            del temp_LR, temp_GT

        if rot:
            LR_img = LR_img.transpose(1, 0, 2)
            GT_img = GT_img.transpose(1, 0, 2)

        return {'LR' : LR_img, 'GT' : GT_img}

@@ -0,0 +1,2 @@
Program written by dongheehand in Python
Available on GitHub at https://github.com/dongheehand/SRGAN-PyTorch

@@ -0,0 +1,60 @@
import torch
import torch.nn as nn
from torchvision import transforms


class MeanShift(nn.Conv2d):
    def __init__(
        self, rgb_range = 1,
        norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225), sign=-1):

        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(norm_std)
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(norm_mean) / std
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for p in self.parameters():
            p.requires_grad = False


class perceptual_loss(nn.Module):
    def __init__(self, vgg):
        super(perceptual_loss, self).__init__()
        self.normalization_mean = [0.485, 0.456, 0.406]
        self.normalization_std = [0.229, 0.224, 0.225]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.transform = MeanShift(norm_mean = self.normalization_mean, norm_std = self.normalization_std).to(self.device)
        self.vgg = vgg
        self.criterion = nn.MSELoss()

    def forward(self, HR, SR, layer = 'relu5_4'):
        ## HR and SR should be normalized [0,1]
        hr = self.transform(HR)
        sr = self.transform(SR)

        hr_feat = getattr(self.vgg(hr), layer)
        sr_feat = getattr(self.vgg(sr), layer)

        return self.criterion(hr_feat, sr_feat), hr_feat, sr_feat


class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]
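A usage sketch for `TVLoss` (random input; the shapes are arbitrary, and `losses` is the module name imported by `mode.py` in this commit):
```python
import torch
from losses import TVLoss

tv = TVLoss(tv_loss_weight=1.0)
x = torch.rand(4, 3, 24, 24)  # a batch of image-like tensors
print(tv(x))                  # scalar total-variation penalty
```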

@@ -0,0 +1,39 @@
from mode import *
import argparse

parser = argparse.ArgumentParser()

def str2bool(v):
    return v.lower() == 'true'

parser.add_argument("--LR_path", type = str, default = '../dataSet/DIV2K/DIV2K_train_LR_bicubic/X4')
parser.add_argument("--GT_path", type = str, default = '../dataSet/DIV2K/DIV2K_train_HR/')
parser.add_argument("--res_num", type = int, default = 16)
parser.add_argument("--num_workers", type = int, default = 0)
parser.add_argument("--batch_size", type = int, default = 16)
parser.add_argument("--L2_coeff", type = float, default = 1.0)
parser.add_argument("--adv_coeff", type = float, default = 1e-3)
parser.add_argument("--tv_loss_coeff", type = float, default = 0.0)
parser.add_argument("--pre_train_epoch", type = int, default = 8000)
parser.add_argument("--fine_train_epoch", type = int, default = 4000)
parser.add_argument("--scale", type = int, default = 4)
parser.add_argument("--patch_size", type = int, default = 24)
parser.add_argument("--feat_layer", type = str, default = 'relu5_4')
parser.add_argument("--vgg_rescale_coeff", type = float, default = 0.006)
parser.add_argument("--fine_tuning", type = str2bool, default = False)
parser.add_argument("--in_memory", type = str2bool, default = True)
parser.add_argument("--generator_path", type = str)
parser.add_argument("--mode", type = str, default = 'train')
args = parser.parse_args()

if args.mode == 'train':
    train(args)
elif args.mode == 'test':
    test(args)
elif args.mode == 'test_only':
    test_only(args)

@@ -0,0 +1,208 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from losses import TVLoss, perceptual_loss
from dataset import *
from srgan_model import Generator, Discriminator
from vgg19 import vgg19
import numpy as np
from PIL import Image
from skimage.color import rgb2ycbcr
from skimage.measure import compare_psnr


def train(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([crop(args.scale, args.patch_size), augmentation()])
    dataset = mydata(GT_path = args.GT_path, LR_path = args.LR_path, in_memory = args.in_memory, transform = transform)
    loader = DataLoader(dataset, batch_size = args.batch_size, shuffle = True, num_workers = args.num_workers)

    generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num, scale=args.scale)

    if args.fine_tuning:
        generator.load_state_dict(torch.load(args.generator_path))
        print("pre-trained model is loaded")
        print("path : %s"%(args.generator_path))

    generator = generator.to(device)
    generator.train()

    l2_loss = nn.MSELoss()
    g_optim = optim.Adam(generator.parameters(), lr = 1e-4)

    pre_epoch = 0
    fine_epoch = 0

    #### Train using L2_loss
    while pre_epoch < args.pre_train_epoch:
        for i, tr_data in enumerate(loader):
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            output, _ = generator(lr)
            loss = l2_loss(gt, output)

            g_optim.zero_grad()
            loss.backward()
            g_optim.step()

        pre_epoch += 1

        if pre_epoch % 2 == 0:
            print(pre_epoch)
            print(loss.item())
            print('=========')

        if pre_epoch % 800 == 0:
            torch.save(generator.state_dict(), './model/pre_trained_model_%03d.pt'%pre_epoch)

    #### Train using perceptual & adversarial loss
    vgg_net = vgg19().to(device)
    vgg_net = vgg_net.eval()

    discriminator = Discriminator(patch_size = args.patch_size * args.scale)
    discriminator = discriminator.to(device)
    discriminator.train()

    d_optim = optim.Adam(discriminator.parameters(), lr = 1e-4)
    scheduler = optim.lr_scheduler.StepLR(g_optim, step_size = 2000, gamma = 0.1)

    VGG_loss = perceptual_loss(vgg_net)
    cross_ent = nn.BCELoss()
    tv_loss = TVLoss()

    real_label = torch.ones((args.batch_size, 1)).to(device)
    fake_label = torch.zeros((args.batch_size, 1)).to(device)

    while fine_epoch < args.fine_train_epoch:
        scheduler.step()

        for i, tr_data in enumerate(loader):
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            ## Training Discriminator
            output, _ = generator(lr)
            fake_prob = discriminator(output)
            real_prob = discriminator(gt)

            d_loss_real = cross_ent(real_prob, real_label)
            d_loss_fake = cross_ent(fake_prob, fake_label)

            d_loss = d_loss_real + d_loss_fake

            g_optim.zero_grad()
            d_optim.zero_grad()
            d_loss.backward()
            d_optim.step()

            ## Training Generator
            output, _ = generator(lr)
            fake_prob = discriminator(output)

            _percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer = args.feat_layer)

            L2_loss = l2_loss(output, gt)
            percep_loss = args.vgg_rescale_coeff * _percep_loss
            adversarial_loss = args.adv_coeff * cross_ent(fake_prob, real_label)
            total_variance_loss = args.tv_loss_coeff * tv_loss(args.vgg_rescale_coeff * (hr_feat - sr_feat)**2)

            g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss

            g_optim.zero_grad()
            d_optim.zero_grad()
            g_loss.backward()
            g_optim.step()

        fine_epoch += 1

        if fine_epoch % 2 == 0:
            print(fine_epoch)
            print(g_loss.item())
            print(d_loss.item())
            print('=========')

        if fine_epoch % 500 == 0:
            torch.save(generator.state_dict(), './model/SRGAN_gene_%03d.pt'%fine_epoch)
            torch.save(discriminator.state_dict(), './model/SRGAN_discrim_%03d.pt'%fine_epoch)


def test(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = mydata(GT_path = args.GT_path, LR_path = args.LR_path, in_memory = False, transform = None)
    loader = DataLoader(dataset, batch_size = 1, shuffle = False, num_workers = args.num_workers)

    generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num)
    generator.load_state_dict(torch.load(args.generator_path))
    generator = generator.to(device)
    generator.eval()

    f = open('./result.txt', 'w')
    psnr_list = []

    with torch.no_grad():
        for i, te_data in enumerate(loader):
            gt = te_data['GT'].to(device)
            lr = te_data['LR'].to(device)

            bs, c, h, w = lr.size()
            gt = gt[:, :, : h * args.scale, : w * args.scale]

            output, _ = generator(lr)

            output = output[0].cpu().numpy()
            output = np.clip(output, -1.0, 1.0)
            gt = gt[0].cpu().numpy()

            output = (output + 1.0) / 2.0
            gt = (gt + 1.0) / 2.0

            output = output.transpose(1, 2, 0)
            gt = gt.transpose(1, 2, 0)

            y_output = rgb2ycbcr(output)[args.scale:-args.scale, args.scale:-args.scale, :1]
            y_gt = rgb2ycbcr(gt)[args.scale:-args.scale, args.scale:-args.scale, :1]

            psnr = compare_psnr(y_output / 255.0, y_gt / 255.0, data_range = 1.0)
            psnr_list.append(psnr)
            f.write('psnr : %04f \n' % psnr)

            result = Image.fromarray((output * 255.0).astype(np.uint8))
            result.save('./result/res_%04d.png'%i)

        f.write('avg psnr : %04f' % np.mean(psnr_list))


def test_only(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = testOnly_data(LR_path = args.LR_path, in_memory = False, transform = None)
    loader = DataLoader(dataset, batch_size = 1, shuffle = False, num_workers = args.num_workers)

    generator = Generator(img_feat = 3, n_feats = 64, kernel_size = 3, num_block = args.res_num)
    generator.load_state_dict(torch.load(args.generator_path))
    generator = generator.to(device)
    generator.eval()

    with torch.no_grad():
        for i, te_data in enumerate(loader):
            lr = te_data['LR'].to(device)
            output, _ = generator(lr)
            output = output[0].cpu().numpy()
            output = (output + 1.0) / 2.0
            output = output.transpose(1, 2, 0)
            result = Image.fromarray((output * 255.0).astype(np.uint8))
            result.save('./result/res_%04d.png'%i)

@@ -0,0 +1,97 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class _conv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
        super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels,
                                    kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True)

        self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02)
        self.bias.data = torch.zeros((out_channels))

        for p in self.parameters():
            p.requires_grad = True


class conv(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True):
        super(conv, self).__init__()
        m = []
        m.append(_conv(in_channels = in_channel, out_channels = out_channel,
                       kernel_size = kernel_size, stride = stride, padding = (kernel_size) // 2, bias = True))

        if BN:
            m.append(nn.BatchNorm2d(num_features = out_channel))

        if act is not None:
            m.append(act)

        self.body = nn.Sequential(*m)

    def forward(self, x):
        out = self.body(x)
        return out


class ResBlock(nn.Module):
    def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True):
        super(ResBlock, self).__init__()
        m = []
        m.append(conv(channels, channels, kernel_size, BN = True, act = act))
        m.append(conv(channels, channels, kernel_size, BN = True, act = None))
        self.body = nn.Sequential(*m)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, num_res_block, act = nn.ReLU(inplace = True)):
        super(BasicBlock, self).__init__()
        m = []

        self.conv = conv(in_channels, out_channels, kernel_size, BN = False, act = act)
        for i in range(num_res_block):
            m.append(ResBlock(out_channels, kernel_size, act))

        m.append(conv(out_channels, out_channels, kernel_size, BN = True, act = None))

        self.body = nn.Sequential(*m)

    def forward(self, x):
        res = self.conv(x)
        out = self.body(res)
        out += res
        return out


class Upsampler(nn.Module):
    def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)):
        super(Upsampler, self).__init__()
        m = []
        m.append(conv(channel, channel * scale * scale, kernel_size))
        m.append(nn.PixelShuffle(scale))

        if act is not None:
            m.append(act)

        self.body = nn.Sequential(*m)

    def forward(self, x):
        out = self.body(x)
        return out


class discrim_block(nn.Module):
    def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)):
        super(discrim_block, self).__init__()
        m = []
        m.append(conv(in_feats, out_feats, kernel_size, BN = True, act = act))
        m.append(conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2))
        self.body = nn.Sequential(*m)

    def forward(self, x):
        out = self.body(x)
        return out
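A quick shape check for the `Upsampler` block (a sketch; `ops` is the module name imported by `srgan_model.py` below):
```python
import torch
from ops import Upsampler

up = Upsampler(channel=64, kernel_size=3, scale=2)
x = torch.randn(1, 64, 24, 24)
print(up(x).shape)  # torch.Size([1, 64, 48, 48]) -- PixelShuffle doubles H and W
```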

@@ -0,0 +1,68 @@
# Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
## Overview
An unofficial PyTorch implementation of SRGAN, as described in the following paper:
* [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802), published in CVPR 2017
## Requirements
- Python 3.6.5
- PyTorch 1.1.0
- Pillow 5.1.0
- numpy 1.14.5
- scikit-image 0.15.0
## Datasets
- [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/)
## Pre-trained model
- [SRResNet](https://drive.google.com/open?id=15F2zOrOg2hIjdI0WsrOwF1y8REOkmmm0)
- [SRGAN](https://drive.google.com/open?id=1-HmcV5X94u411HRa-KEMcGhAO1OXAjAc)
## Train & Test
Train
```
python main.py --LR_path ./LR_imgs_dir --GT_path ./GT_imgs_dir
```
Test
```
python main.py --mode test --LR_path ./LR_imgs_dir --GT_path ./GT_imgs_dir --generator_path ./model/SRGAN.pt
```
Run inference on your own images
```
python main.py --mode test_only --LR_path ./LR_imgs_dir --generator_path ./model/SRGAN.pt
```
## Experimental Results
Experimental results on benchmark datasets; values are PSNR (dB).
### Quantitative Results
| Method | Set5 | Set14 | B100 |
|--------|------|-------|------|
| Bicubic | 28.43 | 25.99 | 25.94 |
| SRResNet (paper) | 32.05 | 28.49 | 27.58 |
| SRResNet (my model) | 31.96 | 28.48 | 27.49 |
| SRGAN (paper) | 29.40 | 26.02 | 25.16 |
| SRGAN (my model) | 29.93 | 26.95 | 26.10 |
### Qualitative Results
| Bicubic | SRResNet | SRGAN |
| --- | --- | --- |
| <img src="result/Set14_BIx4/comic_LRBI_x4.png"> |<img src="result/set14_srres_result/res_0004.png"> | <img src="result/set14_srgan_result/res_0004.png"> |
| <img src="result/Set5_BIx4/woman_LRBI_x4.png"> |<img src="result/set5_srres_result/res_0004.png"> | <img src="result/set5_srgan_result/res_0004.png"> |
| <img src="result/Set14_BIx4/baboon_LRBI_x4.png"> |<img src="result/set14_srres_result/res_0000.png"> | <img src="result/set14_srgan_result/res_0000.png"> |
## Comments
If you have any questions or comments about my code, please email me: [son1113@snu.ac.kr](mailto:son1113@snu.ac.kr)

@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
from ops import *


class Generator(nn.Module):
    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, num_block = 16, act = nn.PReLU(), scale=4):
        super(Generator, self).__init__()

        self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 9, BN = False, act = act)

        resblocks = [ResBlock(channels = n_feats, kernel_size = 3, act = act) for _ in range(num_block)]
        self.body = nn.Sequential(*resblocks)

        self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = True, act = None)

        if(scale == 4):
            upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = 2, act = act) for _ in range(2)]
        else:
            upsample_blocks = [Upsampler(channel = n_feats, kernel_size = 3, scale = scale, act = act)]

        self.tail = nn.Sequential(*upsample_blocks)

        self.last_conv = conv(in_channel = n_feats, out_channel = img_feat, kernel_size = 3, BN = False, act = nn.Tanh())

    def forward(self, x):
        x = self.conv01(x)
        _skip_connection = x

        x = self.body(x)
        x = self.conv02(x)
        feat = x + _skip_connection

        x = self.tail(feat)
        x = self.last_conv(x)

        return x, feat


class Discriminator(nn.Module):
    def __init__(self, img_feat = 3, n_feats = 64, kernel_size = 3, act = nn.LeakyReLU(inplace = True), num_of_block = 3, patch_size = 96):
        super(Discriminator, self).__init__()
        self.act = act

        self.conv01 = conv(in_channel = img_feat, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act)
        self.conv02 = conv(in_channel = n_feats, out_channel = n_feats, kernel_size = 3, BN = False, act = self.act, stride = 2)

        body = [discrim_block(in_feats = n_feats * (2 ** i), out_feats = n_feats * (2 ** (i + 1)), kernel_size = 3, act = self.act) for i in range(num_of_block)]
        self.body = nn.Sequential(*body)

        self.linear_size = ((patch_size // (2 ** (num_of_block + 1))) ** 2) * (n_feats * (2 ** num_of_block))

        tail = []
        tail.append(nn.Linear(self.linear_size, 1024))
        tail.append(self.act)
        tail.append(nn.Linear(1024, 1))
        tail.append(nn.Sigmoid())
        self.tail = nn.Sequential(*tail)

    def forward(self, x):
        x = self.conv01(x)
        x = self.conv02(x)
        x = self.body(x)
        x = x.view(-1, self.linear_size)
        x = self.tail(x)

        return x
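A sketch of the default x4 generator and discriminator shapes (random input; a 24x24 LR patch maps to the 96x96 patch size the discriminator expects by default):
```python
import torch
from srgan_model import Generator, Discriminator

g = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=16, scale=4)
d = Discriminator(patch_size=96)
lr = torch.randn(1, 3, 24, 24)
sr, feat = g(lr)
print(sr.shape)     # torch.Size([1, 3, 96, 96]) -- two PixelShuffle x2 stages
print(d(sr).shape)  # torch.Size([1, 1]) -- real/fake probability
```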

@@ -0,0 +1,80 @@
from torchvision import models
from collections import namedtuple
import torch
import torch.nn as nn


class vgg19(nn.Module):
    def __init__(self, pre_trained = True, require_grad = False):
        super(vgg19, self).__init__()
        self.vgg_feature = models.vgg19(pretrained = pre_trained).features
        self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature]
        self.vgg_layer = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',
                          'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
                          'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3',
                          'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4',
                          'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5']

        if not require_grad:
            for parameter in self.parameters():
                parameter.requires_grad = False

    def forward(self, x):
        conv1_1 = self.seq_list[0](x)
        relu1_1 = self.seq_list[1](conv1_1)
        conv1_2 = self.seq_list[2](relu1_1)
        relu1_2 = self.seq_list[3](conv1_2)
        pool1 = self.seq_list[4](relu1_2)

        conv2_1 = self.seq_list[5](pool1)
        relu2_1 = self.seq_list[6](conv2_1)
        conv2_2 = self.seq_list[7](relu2_1)
        relu2_2 = self.seq_list[8](conv2_2)
        pool2 = self.seq_list[9](relu2_2)

        conv3_1 = self.seq_list[10](pool2)
        relu3_1 = self.seq_list[11](conv3_1)
        conv3_2 = self.seq_list[12](relu3_1)
        relu3_2 = self.seq_list[13](conv3_2)
        conv3_3 = self.seq_list[14](relu3_2)
        relu3_3 = self.seq_list[15](conv3_3)
        conv3_4 = self.seq_list[16](relu3_3)
        relu3_4 = self.seq_list[17](conv3_4)
        pool3 = self.seq_list[18](relu3_4)

        conv4_1 = self.seq_list[19](pool3)
        relu4_1 = self.seq_list[20](conv4_1)
        conv4_2 = self.seq_list[21](relu4_1)
        relu4_2 = self.seq_list[22](conv4_2)
        conv4_3 = self.seq_list[23](relu4_2)
        relu4_3 = self.seq_list[24](conv4_3)
        conv4_4 = self.seq_list[25](relu4_3)
        relu4_4 = self.seq_list[26](conv4_4)
        pool4 = self.seq_list[27](relu4_4)

        conv5_1 = self.seq_list[28](pool4)
        relu5_1 = self.seq_list[29](conv5_1)
        conv5_2 = self.seq_list[30](relu5_1)
        relu5_2 = self.seq_list[31](conv5_2)
        conv5_3 = self.seq_list[32](relu5_2)
        relu5_3 = self.seq_list[33](conv5_3)
        conv5_4 = self.seq_list[34](relu5_3)
        relu5_4 = self.seq_list[35](conv5_4)
        pool5 = self.seq_list[36](relu5_4)

        vgg_output = namedtuple("vgg_output", self.vgg_layer)

        vgg_list = [conv1_1, relu1_1, conv1_2, relu1_2, pool1,
                    conv2_1, relu2_1, conv2_2, relu2_2, pool2,
                    conv3_1, relu3_1, conv3_2, relu3_2, conv3_3, relu3_3, conv3_4, relu3_4, pool3,
                    conv4_1, relu4_1, conv4_2, relu4_2, conv4_3, relu4_3, conv4_4, relu4_4, pool4,
                    conv5_1, relu5_1, conv5_2, relu5_2, conv5_3, relu5_3, conv5_4, relu5_4, pool5]

        out = vgg_output(*vgg_list)

        return out
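A sketch of extracting one named activation from this wrapper (the first call downloads the pretrained VGG19 weights; `vgg19` is imported under this name by `mode.py`):
```python
import torch
from vgg19 import vgg19

net = vgg19(pre_trained=True).eval()
x = torch.randn(1, 3, 96, 96)
with torch.no_grad():
    feats = net(x)
print(feats.relu5_4.shape)  # torch.Size([1, 512, 6, 6]) for a 96x96 input
```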

Super-résolution.pyw Normal file
@@ -0,0 +1,211 @@
from tkinter import *
from tkinter.filedialog import askopenfilename
from tkinter.messagebox import showerror
import os, subprocess, shutil, glob, time
from PIL import Image, ImageTk
from threading import Thread
class SRGUI():
    def __init__(self):
        self.root = Tk()
        self.input_img_path = StringVar()
        self.canvas = {}
        self.image = {}
        self.imagetk = {}
        self.id_canvas = {}
        self.label = {}
        self.default_diff = 20
        self.diff = self.default_diff

        ##### INPUT
        self.frame_input = Frame(self.root)
        self.frame_input.grid(row=1, column=1, rowspan=2)
        Label(self.frame_input, text="Entrée", font=("Purisa", 18)).grid(row=1, column=1, columnspan=2)
        self.canvas["input"] = Canvas(self.frame_input, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["input"].grid(row=2, column=1, columnspan=2)
        Entry(self.frame_input, textvariable=self.input_img_path, width=60).grid(row=3, column=1, sticky="NEWS")
        Button(self.frame_input, text="...", relief=RIDGE, command=self.select_img).grid(row=3, column=2, sticky="NEWS")
        Button(self.frame_input, text="Calculer", relief=RIDGE, command=self.calcul).grid(row=4, column=1, columnspan=2, sticky="NEWS")
        def zoom(event):
            for type in ["input", "bil", "bic", "SRCNN", "SRGAN", "original"]:
                if type in self.id_canvas:
                    x1 = ((event.x - self.diff) * self.image[type].width) // 400
                    x2 = ((event.x + self.diff) * self.image[type].width) // 400
                    y1 = ((event.y - self.diff) * self.image[type].height) // 400
                    y2 = ((event.y + self.diff) * self.image[type].height) // 400
                    _img = self.image[type].crop((x1, y1, x2, y2))
                    self.imagetk[type] = ImageTk.PhotoImage(_img.resize((400, 400)))
                    self.canvas[type].itemconfig(self.id_canvas[type], image = self.imagetk[type])

        def change_diff(event):
            if event.delta > 0: self.diff -= 5
            else: self.diff += 5
            zoom(event)

        def reset(event):
            for type in ["input", "bil", "bic", "SRCNN", "SRGAN", "original"]:
                if type in self.id_canvas:
                    self.imagetk[type] = ImageTk.PhotoImage(self.image[type].resize((400, 400)))
                    self.canvas[type].itemconfig(self.id_canvas[type], image = self.imagetk[type])

        def activate_zoom(event):
            self.canvas["input"].bind("<Motion>", zoom)

        def desactivate_zoom(event):
            self.canvas["input"].unbind("<Motion>")

        self.canvas["input"].bind("<ButtonPress-1>", activate_zoom)
        self.canvas["input"].bind("<ButtonRelease-1>", desactivate_zoom)
        self.canvas["input"].bind("<Button-3>", reset)
        self.canvas["input"].bind("<MouseWheel>", change_diff)
        ##### BILINEAR
        self.frame_bil = Frame(self.root)
        self.frame_bil.grid(row=1, column=3)
        Label(self.frame_bil, text="Bilinéaire", font=("Purisa", 18)).grid(row=1, column=1)
        self.canvas["bil"] = Canvas(self.frame_bil, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["bil"].grid(row=2, column=1)
        self.label["bil"] = Label(self.frame_bil, fg="gray")
        self.label["bil"].grid(row=3, column=1)

        ##### BICUBIC
        self.frame_bic = Frame(self.root)
        self.frame_bic.grid(row=2, column=3)
        Label(self.frame_bic, text="Bicubique", font=("Purisa", 18)).grid(row=1, column=1)
        self.canvas["bic"] = Canvas(self.frame_bic, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["bic"].grid(row=2, column=1)
        self.label["bic"] = Label(self.frame_bic, fg="gray")
        self.label["bic"].grid(row=3, column=1)

        ##### SRCNN
        self.frame_SRCNN = Frame(self.root)
        self.frame_SRCNN.grid(row=1, column=4)
        Label(self.frame_SRCNN, text="SRCNN", font=("Purisa", 18)).grid(row=1, column=1)
        self.canvas["SRCNN"] = Canvas(self.frame_SRCNN, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["SRCNN"].grid(row=2, column=1)
        self.label["SRCNN"] = Label(self.frame_SRCNN, fg="gray")
        self.label["SRCNN"].grid(row=3, column=1)

        ##### SRGAN
        self.frame_SRGAN = Frame(self.root)
        self.frame_SRGAN.grid(row=2, column=4)
        Label(self.frame_SRGAN, text="SRGAN", font=("Purisa", 18)).grid(row=1, column=1)
        self.canvas["SRGAN"] = Canvas(self.frame_SRGAN, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["SRGAN"].grid(row=2, column=1)
        self.label["SRGAN"] = Label(self.frame_SRGAN, fg="gray")
        self.label["SRGAN"].grid(row=3, column=1)

        ##### ORIGINAL
        self.frame_original = Frame(self.root)
        Label(self.frame_original, text="Original", font=("Purisa", 18)).grid(row=1, column=1)
        self.canvas["original"] = Canvas(self.frame_original, relief=SOLID, borderwidth=2, width=400, height=400)
        self.canvas["original"].grid(row=2, column=1)
    def select_img(self):
        self.input_img_path.set(askopenfilename(filetypes = [('Images', '*.png *.jpg *.jpeg *.bmp *.gif')]))
        if os.path.exists(self.input_img_path.get()):
            self.image["input"] = Image.open(self.input_img_path.get())
            self.imagetk["input"] = ImageTk.PhotoImage(self.image["input"].resize((400, 400)))
            self.id_canvas["input"] = self.canvas["input"].create_image(204, 204, image = self.imagetk["input"])
            _l = self.input_img_path.get().split(".")
            _l[-2] += "_original"
            original_img_path = ".".join(_l)
            if os.path.exists(original_img_path):
                self.image["original"] = Image.open(original_img_path)
                self.imagetk["original"] = ImageTk.PhotoImage(self.image["original"].resize((400, 400)))
                self.id_canvas["original"] = self.canvas["original"].create_image(204, 204, image=self.imagetk["original"])
                self.frame_original.grid(row=1, column=5, rowspan=2)
            else: self.frame_original.grid_forget()
        else: showerror("Erreur", "Ce fichier n'existe pas.")
    def calcul(self):
        def thread_func():
            type2func = {"bil": self.Bilinear, "bic": self.Bicubic, "SRCNN": self.SRCNN, "SRGAN": self.SRGAN}
            for type in type2func:
                try:
                    t1 = time.time()
                    type2func[type]()
                    t2 = time.time()
                # update the label of the method that actually ran, not always "bil"
                except: self.label[type].config(text="Erreur", fg="red")
                else: self.label[type].config(text="Temps de calcul : %0.2fs" % (t2 - t1), fg="gray")

        thfunc = Thread(target=thread_func)
        thfunc.setDaemon(True)
        thfunc.start()
    def Bilinear(self):
        self.image["bil"] = Image.open(self.input_img_path.get()).resize((400 * 4, 400 * 4), Image.BILINEAR)
        self.imagetk["bil"] = ImageTk.PhotoImage(self.image["bil"].resize((400, 400)))
        self.id_canvas["bil"] = self.canvas["bil"].create_image(204, 204, image=self.imagetk["bil"])

    def Bicubic(self):
        self.image["bic"] = Image.open(self.input_img_path.get()).resize((400 * 4, 400 * 4), Image.BICUBIC)
        self.imagetk["bic"] = ImageTk.PhotoImage(self.image["bic"].resize((400, 400)))
        self.id_canvas["bic"] = self.canvas["bic"].create_image(204, 204, image=self.imagetk["bic"])

    def SRCNN(self):
        SRCNN_img_path = "./SRCNN-pytorch-master/data/" + os.path.basename(self.input_img_path.get())
        self.image["input"].resize((self.image["input"].width * 4, self.image["input"].height * 4), Image.BICUBIC).save(SRCNN_img_path)
        subprocess.call(
            '"D:/ProgramData/Anaconda3/python.exe" \
            "./SRCNN-pytorch-master/test.py" \
            --weights-file "./SRCNN-pytorch-master/weights/srcnn_x4.pth" \
            --image-file "%s" \
            --scale 4' % SRCNN_img_path)
        l = SRCNN_img_path.split(".")
        l[-2] += "_srcnn_x4"
        self.SRCNN_img_path = ".".join(l)
        self.image["SRCNN"] = Image.open(self.SRCNN_img_path)
        self.imagetk["SRCNN"] = ImageTk.PhotoImage(self.image["SRCNN"].resize((400, 400)))
        self.id_canvas["SRCNN"] = self.canvas["SRCNN"].create_image(204, 204, image=self.imagetk["SRCNN"])

    def SRGAN(self):
        SRGAN_img_path = "./SRGAN-PyTorch-master/LR_imgs_dir/" + os.path.basename(self.input_img_path.get())
        shutil.copy(self.input_img_path.get(), SRGAN_img_path)
        workingdir = os.getcwd()
        os.chdir("./SRGAN-PyTorch-master/")
        subprocess.call(
            '"D:/ProgramData/Anaconda3/python.exe" \
            main.py \
            --mode test_only \
            --LR_path LR_imgs_dir/ \
            --generator_path model/SRGAN.pt')
        os.chdir(workingdir)
        os.remove(SRGAN_img_path)
        self.SRGAN_img_path = glob.glob("./SRGAN-PyTorch-master/result/res_*.png")[-1]
        self.image["SRGAN"] = Image.open(self.SRGAN_img_path)
        self.imagetk["SRGAN"] = ImageTk.PhotoImage(self.image["SRGAN"].resize((400, 400)))
        self.id_canvas["SRGAN"] = self.canvas["SRGAN"].create_image(204, 204, image=self.imagetk["SRGAN"])
App = SRGUI()
mainloop()