import torch
import numpy as np


def convert_rgb_to_y(img):
    # Extract the luma (Y) channel from an 8-bit RGB image using the ITU-R BT.601
    # conversion. Accepts an HWC numpy array or a CHW (or NCHW, batch size 1) tensor.
    if isinstance(img, np.ndarray):
        return 16. + (65.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
    elif isinstance(img, torch.Tensor):
        if img.ndim == 4:
            img = img.squeeze(0)
        return 16. + (65.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
    else:
        raise TypeError('unsupported image type: {}'.format(type(img)))


def convert_rgb_to_ycbcr(img):
    # Convert an 8-bit RGB image to YCbCr (ITU-R BT.601). The result is always
    # HWC, whether the input was an HWC numpy array or a CHW/NCHW torch tensor.
    if isinstance(img, np.ndarray):
        y = 16. + (65.738 * img[:, :, 0] + 129.057 * img[:, :, 1] + 25.064 * img[:, :, 2]) / 256.
        cb = 128. + (-37.945 * img[:, :, 0] - 74.494 * img[:, :, 1] + 112.439 * img[:, :, 2]) / 256.
        cr = 128. + (112.439 * img[:, :, 0] - 94.154 * img[:, :, 1] - 18.285 * img[:, :, 2]) / 256.
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    elif isinstance(img, torch.Tensor):
        if img.ndim == 4:
            img = img.squeeze(0)
        y = 16. + (65.738 * img[0, :, :] + 129.057 * img[1, :, :] + 25.064 * img[2, :, :]) / 256.
        cb = 128. + (-37.945 * img[0, :, :] - 74.494 * img[1, :, :] + 112.439 * img[2, :, :]) / 256.
        cr = 128. + (112.439 * img[0, :, :] - 94.154 * img[1, :, :] - 18.285 * img[2, :, :]) / 256.
        # Stack the three 2-D channel maps into (3, H, W), then permute to HWC.
        return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
    else:
        raise TypeError('unsupported image type: {}'.format(type(img)))


def convert_ycbcr_to_rgb(img):
    # Inverse BT.601 conversion: map an 8-bit YCbCr image (HWC numpy array or
    # CHW/NCHW torch tensor) back to RGB in roughly the [0, 255] range.
    if isinstance(img, np.ndarray):
        r = 298.082 * img[:, :, 0] / 256. + 408.583 * img[:, :, 2] / 256. - 222.921
        g = 298.082 * img[:, :, 0] / 256. - 100.291 * img[:, :, 1] / 256. - 208.120 * img[:, :, 2] / 256. + 135.576
        b = 298.082 * img[:, :, 0] / 256. + 516.412 * img[:, :, 1] / 256. - 276.836
        return np.array([r, g, b]).transpose([1, 2, 0])
    elif isinstance(img, torch.Tensor):
        if img.ndim == 4:
            img = img.squeeze(0)
        r = 298.082 * img[0, :, :] / 256. + 408.583 * img[2, :, :] / 256. - 222.921
        g = 298.082 * img[0, :, :] / 256. - 100.291 * img[1, :, :] / 256. - 208.120 * img[2, :, :] / 256. + 135.576
        b = 298.082 * img[0, :, :] / 256. + 516.412 * img[1, :, :] / 256. - 276.836
        # Stack the three 2-D channel maps into (3, H, W), then permute to HWC.
        return torch.stack([r, g, b], 0).permute(1, 2, 0)
    else:
        raise TypeError('unsupported image type: {}'.format(type(img)))
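

# Usage sketch: a hypothetical colour-space round trip with the helpers above
# (the name _roundtrip_example is illustrative, not part of the module's API).
# The converters expect 8-bit intensities in [0, 255]: numpy arrays in HWC
# layout, torch tensors in CHW (or NCHW with a batch size of 1).
def _roundtrip_example():
    rgb = np.random.randint(0, 256, size=(16, 16, 3)).astype(np.float32)
    ycbcr = convert_rgb_to_ycbcr(rgb)       # HWC in, HWC (Y, Cb, Cr) out
    rgb_back = convert_ycbcr_to_rgb(ycbcr)  # should recover the input up to
    return np.abs(rgb - rgb_back).max()     # a small numerical residual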


def calc_psnr(img1, img2):
    # Peak signal-to-noise ratio in dB, assuming pixel values scaled to [0, 1]
    # (i.e. a peak value of 1.0).
    return 10. * torch.log10(1. / torch.mean((img1 - img2) ** 2))


class AverageMeter(object):
    # Tracks the most recent value and the running average of a metric.

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
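

# Usage sketch: a hypothetical evaluation loop (the tensors below are random
# stand-ins for real images). PSNR is typically measured on the Y channel,
# rescaled to [0, 1] to match the peak value assumed by calc_psnr, and the
# per-image scores are averaged with AverageMeter.
if __name__ == '__main__':
    psnr_meter = AverageMeter()
    for _ in range(4):
        sr = torch.rand(1, 3, 32, 32) * 255.  # fake "super-resolved" output, NCHW
        hr = torch.rand(1, 3, 32, 32) * 255.  # fake ground-truth image, NCHW
        sr_y = convert_rgb_to_y(sr) / 255.    # luma channel scaled to [0, 1]
        hr_y = convert_rgb_to_y(hr) / 255.
        psnr_meter.update(calc_psnr(sr_y, hr_y).item(), n=1)
    print('average PSNR: {:.2f} dB'.format(psnr_meter.avg))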