L1-SR-GUI/SRCNN-pytorch-master/test.py
2020-11-12 23:30:33 +01:00

67 lines
2.1 KiB
Python

import argparse
import os

import numpy as np
import PIL.Image as pil_image
import torch
import torch.backends.cudnn as cudnn

from models import SRCNN
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr
def build_output_path(image_file, scale):
    """Return the path under which the super-resolved image is saved.

    The suffix ``_srcnn_x<scale>`` is inserted between the file name stem and
    its extension, e.g. ``dir/img.png`` -> ``dir/img_srcnn_x3.png``.

    Args:
        image_file: Path of the input image as given on the command line.
        scale: Integer upscaling factor used for the suffix.

    Returns:
        The output path as a string. When ``image_file`` has no extension,
        ``.png`` is appended so the image writer has an extension to pick the
        file format from.
    """
    # splitext only looks at the last path component, so dots in directory
    # names (``./data/img``, ``run.v2/img``) are never mistaken for the
    # extension separator, and names without any dot do not raise.
    root, ext = os.path.splitext(image_file)
    return '{}_srcnn_x{}{}'.format(root, scale, ext or '.png')


def main():
    """Run SRCNN on a single image, print the PSNR and save the result.

    The input image is fitted to a multiple of ``--scale``, downscaled and
    upscaled again with bicubic interpolation, and channel 0 (Y) of its YCbCr
    conversion is passed through the network. The network output is combined
    with the bicubic Cb/Cr channels, converted back to RGB and written next
    to the input file (see ``build_output_path``).

    Raises:
        KeyError: If the weights file contains a parameter name that the
            model does not have.
        ValueError: If the image is smaller than ``--scale`` in either
            dimension.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()
    if args.scale < 1:
        # A zero scale would divide by zero below; negative scales make no sense.
        parser.error('--scale must be a positive integer')

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)
    state_dict = model.state_dict()
    # NOTE(review): torch.load relies on pickle in PyTorch releases of this
    # file's era; only load weight files from a trusted source.
    weights = torch.load(args.weights_file, map_location=lambda storage, loc: storage)
    for name, param in weights.items():
        if name not in state_dict:
            raise KeyError(name)
        state_dict[name].copy_(param)
    model.eval()

    # Close the file handle as soon as the pixel data has been copied out.
    with pil_image.open(args.image_file) as source:
        image = source.convert('RGB')

    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    if image_width == 0 or image_height == 0:
        raise ValueError(
            'image {}x{} is smaller than scale {}'.format(image.width, image.height, args.scale))

    # Ground truth: the input fitted to a size divisible by the scale.
    hr = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr.width // args.scale, hr.height // args.scale), resample=pil_image.BICUBIC)
    bicubic = lr.resize((lr.width * args.scale, lr.height * args.scale), resample=pil_image.BICUBIC)

    ycbcr = convert_rgb_to_ycbcr(np.array(bicubic).astype(np.float32))
    # Out-of-place division: leaves ``ycbcr`` untouched instead of scaling a view of it.
    y = torch.from_numpy(ycbcr[..., 0] / 255.).to(device)
    y = y.unsqueeze(0).unsqueeze(0)

    hr_ycbcr = convert_rgb_to_ycbcr(np.array(hr).astype(np.float32))
    hr_y = torch.from_numpy(hr_ycbcr[..., 0] / 255.).to(device)
    hr_y = hr_y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # Compare against the ground truth, not against the bicubic network input:
    # PSNR(input, output) only measures how much the model changed the image.
    psnr = calc_psnr(hr_y, preds)
    print('PSNR: {:.2f}'.format(psnr))

    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Only channel 0 goes through the network; Cb/Cr come from the bicubic image.
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)
    output.save(build_output_path(args.image_file, args.scale))


if __name__ == '__main__':
    main()