import torch
import torch.nn as nn
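# Building blocks for an SRGAN-style super-resolution network: a weight-initialized
# conv layer, residual blocks, a sub-pixel upsampler, and a discriminator stage.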
class _conv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias):
        # Honor the padding and bias arguments instead of hard-coding them.
        super(_conv, self).__init__(in_channels = in_channels, out_channels = out_channels,
                                    kernel_size = kernel_size, stride = stride, padding = padding, bias = bias)

        # Initialize weights from N(0, 0.02); zero the bias when present.
        self.weight.data = torch.normal(torch.zeros((out_channels, in_channels, kernel_size, kernel_size)), 0.02)
        if bias:
            self.bias.data = torch.zeros((out_channels))

        for p in self.parameters():
            p.requires_grad = True
|
class conv(nn.Module):
    # Conv layer with "same" padding and optional batch normalization and activation.
    def __init__(self, in_channel, out_channel, kernel_size, BN = False, act = None, stride = 1, bias = True):
        super(conv, self).__init__()
        m = []
        m.append(_conv(in_channels = in_channel, out_channels = out_channel,
                       kernel_size = kernel_size, stride = stride, padding = kernel_size // 2, bias = bias))

        if BN:
            m.append(nn.BatchNorm2d(num_features = out_channel))

        if act is not None:
            m.append(act)

        self.body = nn.Sequential(*m)

    def forward(self, x):
        out = self.body(x)
        return out
|
class ResBlock(nn.Module):
    # Two conv-BN layers (activation only after the first) with an identity skip connection.
    def __init__(self, channels, kernel_size, act = nn.ReLU(inplace = True), bias = True):
        super(ResBlock, self).__init__()
        m = []
        m.append(conv(channels, channels, kernel_size, BN = True, act = act, bias = bias))
        m.append(conv(channels, channels, kernel_size, BN = True, act = None, bias = bias))
        self.body = nn.Sequential(*m)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res
|
class BasicBlock(nn.Module):
    # Head conv followed by a stack of residual blocks and a trailing conv-BN,
    # with a long skip connection from the head output to the stack output.
    def __init__(self, in_channels, out_channels, kernel_size, num_res_block, act = nn.ReLU(inplace = True)):
        super(BasicBlock, self).__init__()
        m = []

        self.conv = conv(in_channels, out_channels, kernel_size, BN = False, act = act)
        for _ in range(num_res_block):
            m.append(ResBlock(out_channels, kernel_size, act))

        m.append(conv(out_channels, out_channels, kernel_size, BN = True, act = None))

        self.body = nn.Sequential(*m)

    def forward(self, x):
        res = self.conv(x)
        out = self.body(res)
        out += res

        return out
|
class Upsampler(nn.Module):
    # Sub-pixel upsampling: a conv expands channels by scale**2, then
    # nn.PixelShuffle rearranges them into a feature map `scale` times larger.
    def __init__(self, channel, kernel_size, scale, act = nn.ReLU(inplace = True)):
        super(Upsampler, self).__init__()
        m = []
        m.append(conv(channel, channel * scale * scale, kernel_size))
        m.append(nn.PixelShuffle(scale))

        if act is not None:
            m.append(act)

        self.body = nn.Sequential(*m)

    def forward(self, x):
        out = self.body(x)
        return out
|
class discrim_block(nn.Module):
    # Discriminator stage: a stride-1 conv followed by a stride-2 conv that
    # halves the spatial resolution, both with batch norm and LeakyReLU.
    def __init__(self, in_feats, out_feats, kernel_size, act = nn.LeakyReLU(inplace = True)):
        super(discrim_block, self).__init__()
        m = []
        m.append(conv(in_feats, out_feats, kernel_size, BN = True, act = act))
        m.append(conv(out_feats, out_feats, kernel_size, BN = True, act = act, stride = 2))
        self.body = nn.Sequential(*m)

    def forward(self, x):
        out = self.body(x)
        return out
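

# A minimal shape-check sketch, not part of the original module: the input size,
# channel counts, and kernel sizes below are illustrative assumptions.
if __name__ == '__main__':
    x = torch.randn(2, 3, 24, 24)

    feat = conv(3, 64, 9, BN = False, act = nn.ReLU(inplace = True))(x)  # (2, 64, 24, 24)
    res = ResBlock(64, 3)(feat)                                          # (2, 64, 24, 24)
    body = BasicBlock(3, 64, 3, num_res_block = 2)(x)                    # (2, 64, 24, 24)
    up = Upsampler(64, 3, scale = 2)(res)                                # (2, 64, 48, 48)
    d = discrim_block(64, 128, 3)(res)                                   # (2, 128, 12, 12)

    print(feat.shape, res.shape, body.shape, up.shape, d.shape)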