Skip to content

Commit

Permalink
add esrgan x2 (PaddlePaddle#529)
Browse files Browse the repository at this point in the history
* add esrgan x2
  • Loading branch information
LielinJiang authored Dec 16, 2021
1 parent 424ab9e commit 5bf728d
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 2 deletions.
106 changes: 106 additions & 0 deletions configs/esrgan_psnr_x2_div2k.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
total_iters: 1000000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
(0., 1.)

model:
name: BaseSRModel
generator:
name: RRDBNet
in_nc: 3
out_nc: 3
nf: 64
nb: 23
scale: 2
pixel_criterion:
name: L1Loss

dataset:
train:
name: SRDataset
gt_folder: data/DIV2K/DIV2K_train_HR_sub
lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X2_sub
num_workers: 4
batch_size: 8
scale: 2
preprocess:
- name: LoadImageFromFile
key: lq
- name: LoadImageFromFile
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: SRPairedRandomCrop
gt_patch_size: 128
scale: 2
keys: [image, image]
- name: PairedRandomHorizontalFlip
keys: [image, image]
- name: PairedRandomVerticalFlip
keys: [image, image]
- name: PairedRandomTransposeHW
keys: [image, image]
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
std: [255., 255., 255.]
keys: [image, image]
test:
name: SRDataset
gt_folder: data/Set14/GTmod12
lq_folder: data/Set14/LRbicx2
scale: 2
preprocess:
- name: LoadImageFromFile
key: lq
- name: LoadImageFromFile
key: gt
- name: Transforms
input_keys: [lq, gt]
pipeline:
- name: Transpose
keys: [image, image]
- name: Normalize
mean: [0., .0, 0.]
std: [255., 255., 255.]
keys: [image, image]

lr_scheduler:
name: CosineAnnealingRestartLR
learning_rate: 0.0002
periods: [250000, 250000, 250000, 250000]
restart_weights: [1, 1, 1, 1]
eta_min: !!float 1e-7

optimizer:
name: Adam
# add the parameters of each net listed in net_names to this optimizer
# each name should be in self.nets
net_names:
- generator
beta1: 0.9
beta2: 0.99

validate:
interval: 5000
save_img: false

metrics:
psnr: # metric name, can be arbitrary
name: PSNR
crop_border: 2
test_y_channel: True
ssim:
name: SSIM
crop_border: 2
test_y_channel: True

log_config:
interval: 100
visiual_interval: 500

snapshot_config:
interval: 5000
39 changes: 37 additions & 2 deletions ppgan/models/generators/rrdb_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@
from .builder import GENERATORS


def pixel_unshuffle(x, scale):
    """Pixel unshuffle: fold spatial blocks of a feature map into channels.

    Each non-overlapping ``scale x scale`` spatial block is moved onto the
    channel axis, so the spatial size shrinks by ``scale`` while the channel
    count grows by ``scale ** 2``. No values are dropped or modified.

    Args:
        x (paddle.Tensor): Input feature whose shape unpacks as
            ``(b, c, h, w)``.
        scale (int): Downsample ratio. Both ``h`` and ``w`` must be
            divisible by it.

    Returns:
        paddle.Tensor: the pixel unshuffled feature, with shape
        ``[b, c * scale**2, h // scale, w // scale]``. Output channel
        ``ci * scale**2 + p * scale + q`` holds the input pixels of channel
        ``ci`` at row offset ``p`` and column offset ``q`` inside each block.

    Raises:
        ValueError: If ``h`` or ``w`` is not divisible by ``scale``.
    """
    b, c, h, w = x.shape
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and the failure would then surface as an obscure reshape error.
    if h % scale != 0 or w % scale != 0:
        raise ValueError(
            "pixel_unshuffle expects height and width divisible by scale, "
            "got h={}, w={}, scale={}".format(h, w, scale))

    hh = h // scale
    ww = w // scale
    out_channel = c * (scale**2)

    # Split each spatial axis into (coarse position, offset inside block) ...
    x_reshaped = x.reshape([b, c, hh, scale, ww, scale])
    # ... then move both in-block offsets next to the channel axis and merge
    # them into it.
    x_permuted = x_reshaped.transpose([0, 1, 3, 5, 2, 4])
    return x_permuted.reshape([b, out_channel, hh, ww])


class ResidualDenseBlock_5C(nn.Layer):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
Expand Down Expand Up @@ -66,13 +86,21 @@ def make_layer(block, n_layers):

@GENERATORS.register()
class RRDBNet(nn.Layer):
def __init__(self, in_nc, out_nc, nf, nb, gc=32):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4):
super(RRDBNet, self).__init__()

self.scale = scale
if scale == 2:
in_nc = in_nc * 4
elif scale == 1:
in_nc = in_nc * 16

RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)

self.conv_first = nn.Conv2D(in_nc, nf, 3, 1, 1, bias_attr=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)

#### upsampling
self.upconv1 = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)
self.upconv2 = nn.Conv2D(nf, nf, 3, 1, 1, bias_attr=True)
Expand All @@ -82,7 +110,14 @@ def __init__(self, in_nc, out_nc, nf, nb, gc=32):
self.lrelu = nn.LeakyReLU(negative_slope=0.2)

def forward(self, x):
fea = self.conv_first(x)
if self.scale == 2:
fea = pixel_unshuffle(x, scale=2)
elif self.scale == 1:
fea = pixel_unshuffle(x, scale=4)
else:
fea = x

fea = self.conv_first(fea)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk

Expand Down

0 comments on commit 5bf728d

Please sign in to comment.