208 lines
7.1 KiB
Python
208 lines
7.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader
|
|
from torchvision import transforms
|
|
from losses import TVLoss, perceptual_loss
|
|
from dataset import *
|
|
from srgan_model import Generator, Discriminator
|
|
from vgg19 import vgg19
|
|
import numpy as np
|
|
from PIL import Image
|
|
from skimage.color import rgb2ycbcr
|
|
from skimage.measure import compare_psnr
|
|
|
|
|
|
def train(args):
    """Train the generator in two phases and write checkpoints to ``./model``.

    Phase 1 optimises the generator with an MSE loss only for
    ``args.pre_train_epoch`` epochs. Phase 2 alternates discriminator and
    generator updates for ``args.fine_train_epoch`` epochs; the generator loss
    is the sum of the perceptual, adversarial, total-variation and MSE terms.

    Args:
        args: Namespace providing ``scale``, ``patch_size``, ``GT_path``,
            ``LR_path``, ``in_memory``, ``batch_size``, ``num_workers``,
            ``res_num``, ``fine_tuning``, ``generator_path``,
            ``pre_train_epoch``, ``fine_train_epoch``, ``feat_layer``,
            ``vgg_rescale_coeff``, ``adv_coeff`` and ``tv_loss_coeff``.
    """
    import os  # function-scope import: only used to create the checkpoint dir

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([crop(args.scale, args.patch_size), augmentation()])
    dataset = mydata(
        GT_path=args.GT_path,
        LR_path=args.LR_path,
        in_memory=args.in_memory,
        transform=transform,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
    )

    generator = Generator(
        img_feat=3,
        n_feats=64,
        kernel_size=3,
        num_block=args.res_num,
        scale=args.scale,
    )

    if args.fine_tuning:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        generator.load_state_dict(torch.load(args.generator_path, map_location=device))
        print("pre-trained model is loaded")
        print("path : %s"%(args.generator_path))

    generator = generator.to(device)
    generator.train()

    l2_loss = nn.MSELoss()
    g_optim = optim.Adam(generator.parameters(), lr=1e-4)

    # Create the checkpoint directory up front so a missing folder cannot
    # throw away hundreds of epochs of work at the first save.
    os.makedirs('./model', exist_ok=True)

    pre_epoch = 0
    fine_epoch = 0

    #### Train using L2_loss
    while pre_epoch < args.pre_train_epoch:
        for tr_data in loader:
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            output, _ = generator(lr)
            loss = l2_loss(gt, output)

            g_optim.zero_grad()
            loss.backward()
            g_optim.step()

        pre_epoch += 1

        # Progress report: loss of the last batch of the epoch.
        if pre_epoch % 2 == 0:
            print(pre_epoch)
            print(loss.item())
            print('=========')

        if pre_epoch % 800 == 0:
            torch.save(generator.state_dict(), './model/pre_trained_model_%03d.pt'%pre_epoch)

    #### Train using perceptual & adversarial loss
    vgg_net = vgg19().to(device)
    vgg_net = vgg_net.eval()

    discriminator = Discriminator(patch_size=args.patch_size * args.scale)
    discriminator = discriminator.to(device)
    discriminator.train()

    d_optim = optim.Adam(discriminator.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.StepLR(g_optim, step_size=2000, gamma=0.1)

    VGG_loss = perceptual_loss(vgg_net)
    cross_ent = nn.BCELoss()
    tv_loss = TVLoss()

    # Allocated once at the nominal batch size and sliced per batch below.
    real_label = torch.ones((args.batch_size, 1)).to(device)
    fake_label = torch.zeros((args.batch_size, 1)).to(device)

    while fine_epoch < args.fine_train_epoch:
        for tr_data in loader:
            gt = tr_data['GT'].to(device)
            lr = tr_data['LR'].to(device)

            # The final batch of an epoch can be smaller than args.batch_size;
            # BCELoss requires the target to match the prediction's shape.
            n_samples = gt.size(0)
            real_target = real_label[:n_samples]
            fake_target = fake_label[:n_samples]

            ## Training Discriminator
            # The generator is not updated in this step, so skip building its
            # graph instead of back-propagating into it and discarding the result.
            with torch.no_grad():
                fake_img, _ = generator(lr)
            fake_prob = discriminator(fake_img)
            real_prob = discriminator(gt)

            d_loss_real = cross_ent(real_prob, real_target)
            d_loss_fake = cross_ent(fake_prob, fake_target)
            d_loss = d_loss_real + d_loss_fake

            d_optim.zero_grad()
            d_loss.backward()
            d_optim.step()

            ## Training Generator
            output, _ = generator(lr)
            fake_prob = discriminator(output)

            # (x + 1) / 2 rescales both images before they are fed to the VGG loss.
            _percep_loss, hr_feat, sr_feat = VGG_loss(
                (gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer=args.feat_layer
            )

            L2_loss = l2_loss(output, gt)
            percep_loss = args.vgg_rescale_coeff * _percep_loss
            # The generator is rewarded when the discriminator labels its output "real".
            adversarial_loss = args.adv_coeff * cross_ent(fake_prob, real_target)
            total_variance_loss = args.tv_loss_coeff * tv_loss(
                args.vgg_rescale_coeff * (hr_feat - sr_feat)**2
            )

            g_loss = percep_loss + adversarial_loss + total_variance_loss + L2_loss

            g_optim.zero_grad()
            d_optim.zero_grad()
            g_loss.backward()
            g_optim.step()

        # Step the LR schedule after the epoch's optimizer updates; stepping
        # first shifts the schedule by one epoch on PyTorch >= 1.1.
        scheduler.step()
        fine_epoch += 1

        # Progress report: losses of the last batch of the epoch.
        if fine_epoch % 2 == 0:
            print(fine_epoch)
            print(g_loss.item())
            print(d_loss.item())
            print('=========')

        if fine_epoch % 500 == 0:
            torch.save(generator.state_dict(), './model/SRGAN_gene_%03d.pt'%fine_epoch)
            torch.save(discriminator.state_dict(), './model/SRGAN_discrim_%03d.pt'%fine_epoch)
|
|
|
|
|
|
# In[ ]:
|
|
|
|
def test(args):
    """Evaluate a saved generator on paired LR/GT images.

    For every sample the generator output is saved as
    ``./result/res_XXXX.png`` and its PSNR against the ground truth (computed
    on the first ``rgb2ycbcr`` channel, with ``args.scale`` pixels cropped from
    each border) is written to ``./result.txt``, followed by the average.

    Args:
        args: Namespace providing ``GT_path``, ``LR_path``, ``num_workers``,
            ``res_num``, ``scale`` and ``generator_path``.
    """
    import os  # function-scope import: only used to create the output dir

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = mydata(GT_path=args.GT_path, LR_path=args.LR_path, in_memory=False, transform=None)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.num_workers)

    # Build the network with the same scale used in train() and in the GT
    # crop below, rather than relying on Generator's default.
    generator = Generator(
        img_feat=3,
        n_feats=64,
        kernel_size=3,
        num_block=args.res_num,
        scale=args.scale,
    )
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    generator.load_state_dict(torch.load(args.generator_path, map_location=device))
    generator = generator.to(device)
    generator.eval()

    os.makedirs('./result', exist_ok=True)
    psnr_list = []

    # The context manager guarantees the report is flushed and closed even if
    # a sample raises part-way through.
    with open('./result.txt', 'w') as f:
        with torch.no_grad():
            for i, te_data in enumerate(loader):
                gt = te_data['GT'].to(device)
                lr = te_data['LR'].to(device)

                # Trim GT to exactly scale x the LR size so it matches the output.
                bs, c, h, w = lr.size()
                gt = gt[:, :, : h * args.scale, : w * args.scale]

                output, _ = generator(lr)

                output = output[0].cpu().numpy()
                output = np.clip(output, -1.0, 1.0)
                gt = gt[0].cpu().numpy()

                # Map [-1, 1] to [0, 1].
                output = (output + 1.0) / 2.0
                gt = (gt + 1.0) / 2.0

                # CHW -> HWC for skimage / PIL.
                output = output.transpose(1, 2, 0)
                gt = gt.transpose(1, 2, 0)

                # Keep only the first YCbCr channel and drop a scale-wide border.
                y_output = rgb2ycbcr(output)[args.scale:-args.scale, args.scale:-args.scale, :1]
                y_gt = rgb2ycbcr(gt)[args.scale:-args.scale, args.scale:-args.scale, :1]

                psnr = compare_psnr(y_output / 255.0, y_gt / 255.0, data_range=1.0)
                psnr_list.append(psnr)
                f.write('psnr : %04f \n' % psnr)

                result = Image.fromarray((output * 255.0).astype(np.uint8))
                result.save('./result/res_%04d.png'%i)

        f.write('avg psnr : %04f' % np.mean(psnr_list))
|
|
|
|
|
|
def test_only(args):
    """Run a saved generator on LR images only and save the outputs.

    Each generator output is written to ``./result/res_XXXX.png``; no ground
    truth is loaded and no metric is computed.

    Args:
        args: Namespace providing ``LR_path``, ``num_workers``, ``res_num``
            and ``generator_path``.
    """
    import os  # function-scope import: only used to create the output dir

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = testOnly_data(LR_path=args.LR_path, in_memory=False, transform=None)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.num_workers)

    # NOTE(review): train() passes scale=args.scale to Generator but this call
    # relies on the constructor default -- confirm they agree for the checkpoint.
    generator = Generator(img_feat=3, n_feats=64, kernel_size=3, num_block=args.res_num)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    generator.load_state_dict(torch.load(args.generator_path, map_location=device))
    generator = generator.to(device)
    generator.eval()

    os.makedirs('./result', exist_ok=True)

    with torch.no_grad():
        for i, te_data in enumerate(loader):
            lr = te_data['LR'].to(device)
            output, _ = generator(lr)

            output = output[0].cpu().numpy()
            # Clip before the uint8 cast: values outside [-1, 1] would
            # otherwise wrap around and show up as speckle in the saved image.
            output = np.clip(output, -1.0, 1.0)
            # Map [-1, 1] to [0, 1], then CHW -> HWC for PIL.
            output = (output + 1.0) / 2.0
            output = output.transpose(1, 2, 0)

            result = Image.fromarray((output * 255.0).astype(np.uint8))
            result.save('./result/res_%04d.png'%i)
|
|
|
|
|
|
|