67 lines
2.1 KiB
Python
67 lines
2.1 KiB
Python
import argparse
import os

import numpy as np
import PIL.Image as pil_image
import torch
import torch.backends.cudnn as cudnn

from models import SRCNN
from utils import convert_rgb_to_ycbcr, convert_ycbcr_to_rgb, calc_psnr
def make_output_path(image_file: str, scale: int) -> str:
    """Return the path the processed image is written to.

    The suffix ``_srcnn_x<scale>`` is inserted between the file stem and its
    extension, e.g. ``photo.png`` -> ``photo_srcnn_x3.png``.

    ``os.path.splitext`` is used instead of splitting the whole path on every
    ``.`` so that a path without an extension (``image``) no longer raises
    ``IndexError``, and a path whose only dots are in a directory component
    (``./image``, ``run.v2/image``) no longer has its directory part rewritten.

    Args:
        image_file: Path of the input image as given on the command line.
        scale: Upscaling factor embedded in the suffix.

    Returns:
        The input path with ``_srcnn_x<scale>`` appended to the stem.
    """
    stem, ext = os.path.splitext(image_file)
    return '{}_srcnn_x{}{}'.format(stem, scale, ext)


def main():
    """Run SRCNN on a single image, print a PSNR value and save the result.

    Command-line arguments:
        --weights-file: checkpoint holding the model parameters (required).
        --image-file: image to process (required).
        --scale: integer resize factor used to degrade the input (default 3).

    Raises:
        KeyError: if the checkpoint contains a parameter name that the model
            does not have.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--weights-file', type=str, required=True)
    parser.add_argument('--image-file', type=str, required=True)
    parser.add_argument('--scale', type=int, default=3)
    args = parser.parse_args()

    # The scale is used as a divisor below; reject values that would crash
    # with ZeroDivisionError or produce negative image sizes.
    if args.scale < 1:
        parser.error('--scale must be a positive integer')

    cudnn.benchmark = True
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model = SRCNN().to(device)

    # Copy the checkpoint into the model parameter by parameter. Unknown
    # names are an error; names missing from the checkpoint are left at their
    # initial values (unlike a strict ``load_state_dict``).
    state_dict = model.state_dict()
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source. map_location keeps tensors on the CPU storage
    # they were deserialized into; copy_ then moves them to the model device.
    checkpoint = torch.load(args.weights_file, map_location=lambda storage, loc: storage)
    for n, p in checkpoint.items():
        if n not in state_dict:
            raise KeyError(n)
        state_dict[n].copy_(p)

    model.eval()

    # convert() returns a fully loaded copy, so the file handle can be
    # released immediately instead of lingering until garbage collection.
    with pil_image.open(args.image_file) as source:
        image = source.convert('RGB')

    # Make both sides a multiple of the scale, then shrink by `scale` and
    # enlarge back with bicubic interpolation. The result has the adjusted
    # size but has gone through a down/up resize round trip.
    image_width = (image.width // args.scale) * args.scale
    image_height = (image.height // args.scale) * args.scale
    image = image.resize((image_width, image_height), resample=pil_image.BICUBIC)
    image = image.resize((image.width // args.scale, image.height // args.scale), resample=pil_image.BICUBIC)
    image = image.resize((image.width * args.scale, image.height * args.scale), resample=pil_image.BICUBIC)

    image = np.array(image).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(image)

    # Only channel 0 (presumably luma, going by the helper's name) is fed to
    # the network. The in-place division rescales it from [0, 255] to [0, 1];
    # only channels 1 and 2 of `ycbcr` are read afterwards.
    y = ycbcr[..., 0]
    y /= 255.
    y = torch.from_numpy(y).to(device)
    # Add two leading singleton dimensions (batch and channel, presumably).
    y = y.unsqueeze(0).unsqueeze(0)

    with torch.no_grad():
        preds = model(y).clamp(0.0, 1.0)

    # NOTE(review): this compares the network output with its own resized
    # input `y`, not with the image as it was before the down/up round trip.
    # Confirm that this is the metric intended to be reported.
    psnr = calc_psnr(y, preds)
    print('PSNR: {:.2f}'.format(psnr))

    # Back to a 2-D array in the [0, 255] range.
    preds = preds.mul(255.0).cpu().numpy().squeeze(0).squeeze(0)

    # Recombine the predicted channel 0 with the untouched channels 1 and 2,
    # moving the channel axis last again, then convert back to 8-bit RGB.
    output = np.array([preds, ycbcr[..., 1], ycbcr[..., 2]]).transpose([1, 2, 0])
    output = np.clip(convert_ycbcr_to_rgb(output), 0.0, 255.0).astype(np.uint8)
    output = pil_image.fromarray(output)

    output.save(make_output_path(args.image_file, args.scale))


if __name__ == '__main__':
    main()