Add drn model (PaddlePaddle#153)
* add drn model
LielinJiang authored Jan 25, 2021
1 parent 827d0cf commit 95e5f4f
Showing 6 changed files with 579 additions and 5 deletions.
117 changes: 117 additions & 0 deletions configs/drn_psnr_x4_div2k.yaml
@@ -0,0 +1,117 @@
total_iters: 1000000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
  (0., 255.)

model:
  name: DRN
  generator:
    name: DRNGenerator
    scale: (2, 4)
    n_blocks: 30
    n_feats: 16
    n_colors: 3
    rgb_range: 255
    negval: 0.2
  pixel_criterion:
    name: L1Loss

dataset:
  train:
    name: SRDataset
    gt_folder: data/DIV2K/DIV2K_train_HR_sub
    lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
    num_workers: 4
    batch_size: 8
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        output_keys: [lq, lqx2, gt]
        pipeline:
          - name: SRPairedRandomCrop
            gt_patch_size: 384
            scale: 4
            scale_list: True
            keys: [image, image]
          - name: PairedRandomHorizontalFlip
            keys: [image, image, image]
          - name: PairedRandomVerticalFlip
            keys: [image, image, image]
          - name: PairedRandomTransposeHW
            keys: [image, image, image]
          - name: Transpose
            keys: [image, image, image]
          - name: Normalize
            mean: [0., 0., 0.]
            std: [1., 1., 1.]
            keys: [image, image, image]
  test:
    name: SRDataset
    gt_folder: data/DIV2K/val_set14/Set14
    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., 0., 0.]
            std: [1., 1., 1.]
            keys: [image, image]

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: 0.0001
  periods: [1000000]
  restart_weights: [1]
  eta_min: !!float 1e-7

optimizer:
  optimG:
    name: Adam
    net_names:
      - generator
    weight_decay: 0.0
    beta1: 0.9
    beta2: 0.999
  optimD:
    name: Adam
    net_names:
      - dual_model_0
      - dual_model_1
    weight_decay: 0.0
    beta1: 0.9
    beta2: 0.999

validate:
  interval: 5000
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 4
      test_y_channel: True
    ssim:
      name: SSIM
      crop_border: 4
      test_y_channel: True

log_config:
  interval: 10
  visiual_interval: 500

snapshot_config:
  interval: 5000
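With `periods: [1000000]` and `restart_weights: [1]`, the `lr_scheduler` section of this config describes a single cosine cycle spanning all `total_iters`, decaying from `learning_rate: 0.0001` to `eta_min: 1e-7`. The plain-Python sketch below shows the learning rate this implies. It assumes the BasicSR-style restart formula; PaddleGAN's own `CosineAnnealingRestartLR` may differ in edge cases (for example, behaviour past the last period), and `cosine_annealing_restart_lr` is a hypothetical helper, not a PaddleGAN function.

```python
import math


def cosine_annealing_restart_lr(it, base_lr=1e-4, periods=(1000000,),
                                restart_weights=(1.0,), eta_min=1e-7):
    """Learning rate at iteration `it` (assumed BasicSR-style formula).

    Inside cycle k the rate falls from eta_min + w_k * (base_lr - eta_min)
    down to eta_min along half a cosine wave, where w_k = restart_weights[k].
    """
    cycle_start = 0
    for period, weight in zip(periods, restart_weights):
        if it < cycle_start + period:
            progress = (it - cycle_start) / period  # fraction of cycle, in [0, 1)
            return eta_min + 0.5 * weight * (base_lr - eta_min) * (
                1.0 + math.cos(math.pi * progress))
        cycle_start += period
    # Assumption: hold at eta_min once every cycle has finished.
    return eta_min


# Learning rates implied by the drn_psnr_x4_div2k.yaml settings.
for it in (0, 250000, 500000, 750000, 999999):
    print(it, cosine_annealing_restart_lr(it))
```

Because there is only one period and its weight is 1, no restart ever happens here; extra entries in `periods` and `restart_weights` would add further cycles whose peak is scaled by the corresponding weight.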
21 changes: 16 additions & 5 deletions docs/en_US/tutorials/super_resolution.md
@@ -2,11 +2,11 @@

## 1.1 Principle

Super resolution is a process of upscaling and improving the details within an image. It usually takes a low-resolution image as input and upscales the same image to a higher resolution as output.
Here we provide three super-resolution models, namely [RealSR](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf), [ESRGAN](https://arxiv.org/abs/1809.00219v2), [LESRCNN](https://arxiv.org/abs/2007.04344).
[RealSR](https://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf) proposed a real-world super-resolution model aiming at better perception.
[ESRGAN](https://arxiv.org/abs/1809.00219v2) is an enhanced SRGAN that improves the three key components of SRGAN.
[LESRCNN](https://arxiv.org/abs/2007.04344) is a lightweight enhanced SR CNN (LESRCNN) with three successive sub-blocks.


## 1.2 How to use
@@ -32,7 +32,7 @@
├── DIV2K_valid_LR_bicubic
...
```

The structures of Set5 and Set14 are similar. Taking Set5 as an example, the structure is as follows:
```
Set5
@@ -71,7 +71,7 @@ The metrics are PSNR / SSIM.
| lesrcnn_x4 | 31.9476 / 0.8909 | 28.4110 / 0.7770 | 30.231 / 0.8326 |
| esrgan_psnr_x4 | 32.5512 / 0.8991 | 28.8114 / 0.7871 | 30.7565 / 0.8449 |
| esrgan_x4 | 28.7647 / 0.8187 | 25.0065 / 0.6762 | 26.9013 / 0.7542 |


<!-- ![](../../imgs/horse2zebra.png) -->

@@ -85,6 +85,7 @@ The metrics are PSNR / SSIM.
| lesrcnn_x4 | DIV2K | [lesrcnn_x4](https://paddlegan.bj.bcebos.com/models/lesrcnn_x4.pdparams)
| esrgan_psnr_x4 | DIV2K | [esrgan_psnr_x4](https://paddlegan.bj.bcebos.com/models/esrgan_psnr_x4.pdparams)
| esrgan_x4 | DIV2K | [esrgan_x4](https://paddlegan.bj.bcebos.com/models/esrgan_x4.pdparams)
| drns_x4 | DIV2K | [drns_x4](https://paddlegan.bj.bcebos.com/models/DRNSx4.pdparams)


# References
@@ -126,3 +127,13 @@ The metrics are PSNR / SSIM.
publisher={Elsevier}
}
```
- 4. [Closed-loop Matters: Dual Regression Networks for Single Image Super-Resolution](https://arxiv.org/pdf/2003.07018.pdf)

```
@inproceedings{guo2020closed,
title={Closed-loop Matters: Dual Regression Networks for Single Image Super-Resolution},
author={Guo, Yong and Chen, Jian and Wang, Jingdong and Chen, Qi and Cao, Jiezhang and Deng, Zeshuai and Xu, Yanwu and Tan, Mingkui},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2020}
}
```
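The result tables report PSNR / SSIM, and the DRN config evaluates with `crop_border: 4` and `test_y_channel: True`. As a rough sketch of what those two options usually mean in SR evaluation, the plain-Python code below computes PSNR on the luma (Y) channel after discarding `crop_border` pixels on each side. It assumes the common ITU-R BT.601 conversion to Y in [16, 235]; PaddleGAN's own `PSNR` metric may differ in details such as rounding or dtype, and `psnr_y`, `rgb_to_y` and `flat` are hypothetical helpers.

```python
import math


def rgb_to_y(r, g, b):
    """ITU-R BT.601 luma for 0-255 RGB, giving Y in [16, 235]."""
    return 16.0 + (65.481 * r + 128.553 * g + 24.966 * b) / 255.0


def psnr_y(img1, img2, crop_border=4):
    """PSNR between two RGB images, measured on Y after cropping the borders.

    img1, img2: equally sized H x W nested lists of (R, G, B) tuples, 0-255.
    crop_border: pixels discarded on each side before comparison.
    """
    h, w = len(img1), len(img1[0])
    squared_error, count = 0.0, 0
    for i in range(crop_border, h - crop_border):
        for j in range(crop_border, w - crop_border):
            diff = rgb_to_y(*img1[i][j]) - rgb_to_y(*img2[i][j])
            squared_error += diff * diff
            count += 1
    mse = squared_error / count
    if mse == 0:
        return float('inf')  # identical once the borders are cropped
    return 10.0 * math.log10(255.0 ** 2 / mse)


def flat(h, w, value):
    """Hypothetical helper: an h x w uniform gray image."""
    return [[(value, value, value)] * w for _ in range(h)]


a = flat(16, 16, 100)
b = flat(16, 16, 110)
print(psnr_y(a, b))
```

Cropping the border matters because pixels near the edge of an upscaled image are reconstructed from incomplete context; with `crop_border: 4` at x4, differences confined to that frame do not affect the score.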
1 change: 1 addition & 0 deletions ppgan/models/__init__.py
@@ -21,6 +21,7 @@
from .esrgan_model import ESRGAN
from .ugatit_model import UGATITModel
from .dc_gan_model import DCGANModel
from .drn_model import DRN
from .animeganv2_model import AnimeGANV2Model, AnimeGANV2PreTrainModel
from .styleganv2_model import StyleGAN2Model
from .wav2lip_model import Wav2LipModel
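The new `from .drn_model import DRN` line matters because `DRN` is registered through the `@MODELS.register()` decorator, which only runs when its module is imported; the config then selects the class by `name: DRN`. The sketch below is a hypothetical, simplified registry illustrating that mechanism. PaddleGAN's real `MODELS` registry and model builder have their own implementation, and `Registry`, `build_model` and the stub `DRN` here are illustrative stand-ins.

```python
class Registry:
    """Hypothetical minimal name -> class registry (not PaddleGAN's code)."""

    def __init__(self, name):
        self.name = name
        self._obj_map = {}

    def register(self):
        """Return a class decorator that records the class under its name."""
        def decorator(cls):
            if cls.__name__ in self._obj_map:
                raise KeyError(f"{cls.__name__} is already registered in {self.name}")
            self._obj_map[cls.__name__] = cls
            return cls
        return decorator

    def get(self, name):
        return self._obj_map[name]


MODELS = Registry("models")


def build_model(cfg):
    """Instantiate the class named by cfg['name'], passing other keys as kwargs."""
    cfg = dict(cfg)
    cls = MODELS.get(cfg.pop("name"))
    return cls(**cfg)


# Registration is a side effect of executing the module that defines the
# class, which is why models/__init__.py has to import drn_model.
@MODELS.register()
class DRN:
    def __init__(self, generator, lq_loss_weight=0.1, dual_loss_weight=0.1):
        self.generator = generator
        self.lq_loss_weight = lq_loss_weight
        self.dual_loss_weight = dual_loss_weight


model = build_model({"name": "DRN", "generator": {"name": "DRNGenerator"}})
```

Without the import, `MODELS.get("DRN")` would raise a `KeyError` in this sketch; the same reasoning applies to the `from .drn import DRNGenerator` line added to `generators/__init__.py`.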
151 changes: 151 additions & 0 deletions ppgan/models/drn_model.py
@@ -0,0 +1,151 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .generators.drn import DownBlock
from .sr_model import BaseSRModel
from .builder import MODELS

from .criterions import build_criterion
from ..modules.init import init_weights
from ..utils.visual import tensor2img


@MODELS.register()
class DRN(BaseSRModel):
    """
    This class implements the DRN model.
    DRN paper: https://arxiv.org/pdf/2003.07018.pdf
    """
    def __init__(self,
                 generator,
                 lq_loss_weight=0.1,
                 dual_loss_weight=0.1,
                 discriminator=None,
                 pixel_criterion=None,
                 perceptual_criterion=None,
                 gan_criterion=None,
                 params=None):
        """Initialize the DRN class.
        Args:
            generator (dict): config of generator.
            lq_loss_weight (float): weight of the primary loss terms computed
                at the intermediate low-resolution scales.
            dual_loss_weight (float): weight of the dual loss terms after the
                first scale.
            discriminator (dict): config of discriminator.
            pixel_criterion (dict): config of pixel criterion.
            perceptual_criterion (dict): config of perceptual criterion.
            gan_criterion (dict): config of GAN criterion.
            params (dict): extra parameters, stored as self.params.
        """
        super(DRN, self).__init__(generator)
        self.lq_loss_weight = lq_loss_weight
        self.dual_loss_weight = dual_loss_weight
        self.params = params
        self.nets['generator'] = build_generator(generator)
        init_weights(self.nets['generator'])
        negval = generator.negval
        n_feats = generator.n_feats
        n_colors = generator.n_colors
        self.scale = generator.scale

        # One dual (x2 down-sampling) model per upscaling stage. Their names
        # must match optimizer.optimD.net_names in the config.
        for i in range(len(self.scale)):
            dual_model = DownBlock(negval, n_feats, n_colors, 2)
            self.nets['dual_model_' + str(i)] = dual_model
            init_weights(self.nets['dual_model_' + str(i)])

        if discriminator:
            self.nets['discriminator'] = build_discriminator(discriminator)

        if pixel_criterion:
            self.pixel_criterion = build_criterion(pixel_criterion)

        if perceptual_criterion:
            self.perceptual_criterion = build_criterion(perceptual_criterion)

        if gan_criterion:
            self.gan_criterion = build_criterion(gan_criterion)

    def setup_input(self, input):
        self.lq = paddle.to_tensor(input['lq'])
        self.visual_items['lq'] = self.lq

        # the x2 low-resolution target is only used when training with two scales
        if isinstance(self.scale, (list, tuple)) and len(
                self.scale) == 2 and 'lqx2' in input:
            self.lqx2 = paddle.to_tensor(input['lqx2'])

        if 'gt' in input:
            self.gt = paddle.to_tensor(input['gt'])
            self.visual_items['gt'] = self.gt
        self.image_paths = input['lq_path']

    def train_iter(self, optimizers=None):
        # low-resolution inputs, ordered from the smallest (lq) upwards
        lr = [self.lq]

        if hasattr(self, 'lqx2'):
            lr.append(self.lqx2)

        hr = self.gt

        # the generator returns a list of outputs; sr[-1] is the final SR image
        sr = self.nets['generator'](self.lq)

        # map the upscaled outputs back down with the dual models
        sr2lr = []

        for i in range(len(self.scale)):
            sr2lr_i = self.nets['dual_model_' + str(i)](sr[i - len(self.scale)])
            sr2lr.append(sr2lr_i)

        # compute primary loss
        loss_primary = self.pixel_criterion(sr[-1], hr)
        for i in range(1, len(sr)):
            if self.lq_loss_weight > 0.0:
                loss_primary += self.pixel_criterion(
                    sr[i - 1 - len(sr)], lr[i - len(sr)]) * self.lq_loss_weight

        # compute dual loss
        loss_dual = self.pixel_criterion(sr2lr[0], lr[0])
        for i in range(1, len(self.scale)):
            if self.dual_loss_weight > 0.0:
                loss_dual += self.pixel_criterion(sr2lr[i],
                                                  lr[i]) * self.dual_loss_weight

        loss_total = loss_primary + loss_dual

        # generator and dual models are updated jointly from a single backward pass
        optimizers['optimG'].clear_grad()
        optimizers['optimD'].clear_grad()
        loss_total.backward()
        optimizers['optimG'].step()
        optimizers['optimD'].step()

        self.losses['loss_primary'] = loss_primary
        self.losses['loss_dual'] = loss_dual
        self.losses['loss_total'] = loss_total

    def test_iter(self, metrics=None):
        self.nets['generator'].eval()
        with paddle.no_grad():
            # keep only the final, full-resolution output
            self.output = self.nets['generator'](self.lq)[-1]
            self.visual_items['output'] = self.output
        self.nets['generator'].train()

        out_img = []
        gt_img = []
        for out_tensor, gt_tensor in zip(self.output, self.gt):
            out_img.append(tensor2img(out_tensor, (0., 255.)))
            gt_img.append(tensor2img(gt_tensor, (0., 255.)))

        if metrics is not None:
            for metric in metrics.values():
                metric.update(out_img, gt_img)
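The negative indices in `train_iter` are easy to misread. The plain-Python sketch below replays the same index expressions on string labels instead of tensors, so the compared pairs can be read off directly. It assumes, as `sr[-1]` being compared with `hr` and the `len(sr) - 1` intermediate terms imply, that the generator returns `len(scale) + 1` outputs ordered from LR size up to the final SR image; `generators/drn.py` itself is not shown on this page. `drn_loss_pairs` and `dual_model_names` are hypothetical helpers, not part of PaddleGAN.

```python
def dual_model_names(scale):
    """Names DRN.__init__ gives its dual models: one per entry in `scale`."""
    return ['dual_model_' + str(i) for i in range(len(scale))]


def drn_loss_pairs(num_scales, lq_loss_weight=0.1, dual_loss_weight=0.1):
    """Replay DRN.train_iter's indexing on labels; return (primary, dual) terms.

    Each term is (prediction, target, weight). Assumes the generator returns
    num_scales + 1 outputs ordered from LR size up to the final SR image, and
    that lr = [lq, lqx2, ...] is ordered from the smallest input upwards.
    """
    sr = ['sr[%d]' % k for k in range(num_scales + 1)]
    lr = ['lr[%d]' % k for k in range(num_scales)]

    # primary loss: final output vs. hr, plus weighted intermediate terms
    primary = [(sr[-1], 'hr', 1.0)]
    for i in range(1, len(sr)):
        if lq_loss_weight > 0.0:
            primary.append((sr[i - 1 - len(sr)], lr[i - len(sr)], lq_loss_weight))

    # dual loss: dual model i maps sr[i - num_scales] back to lr[i]'s size;
    # only the terms after the first are scaled by dual_loss_weight
    dual = [('dual_model_%d(%s)' % (0, sr[-num_scales]), lr[0], 1.0)]
    for i in range(1, num_scales):
        if dual_loss_weight > 0.0:
            dual.append(('dual_model_%d(%s)' % (i, sr[i - num_scales]), lr[i],
                         dual_loss_weight))
    return primary, dual


primary, dual = drn_loss_pairs(2)
for term in primary + dual:
    print(term)
```

With `scale: (2, 4)`, `dual_model_names` yields `dual_model_0` and `dual_model_1`, the two names listed under `optimD.net_names` in `drn_psnr_x4_div2k.yaml`, so the two lists need to stay in sync if `scale` changes.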
1 change: 1 addition & 0 deletions ppgan/models/generators/__init__.py
@@ -25,3 +25,4 @@
from .resnet_ugatit_p2c import ResnetUGATITP2CGenerator
from .generator_styleganv2 import StyleGANv2Generator
from .generator_pixel2style2pixel import Pixel2Style2Pixel
from .drn import DRNGenerator