Add BasicVSR++ (PaddlePaddle#383)
* add BasicVSR++
wangna11BD authored Nov 6, 2021
1 parent 8c7878d commit acd1615
Showing 9 changed files with 633 additions and 7 deletions.
93 changes: 93 additions & 0 deletions configs/basicvsr++_reds.yaml
@@ -0,0 +1,93 @@
total_iters: 600000
output_dir: output_dir
find_unused_parameters: True
checkpoints_dir: checkpoints
use_dataset: True
# tensor range for function tensor2img
min_max:
  (0., 1.)

model:
  name: BasicVSRModel
  fix_iter: 5000
  lr_mult: 0.25
  generator:
    name: BasicVSRPlusPlus
    mid_channels: 64
    num_blocks: 7
    is_low_res_input: True
  pixel_criterion:
    name: CharbonnierLoss
    reduction: mean

dataset:
  train:
    name: RepeatDataset
    times: 1000
    num_workers: 4
    batch_size: 2  #4 gpus
    dataset:
      name: SRREDSMultipleGTDataset
      mode: train
      lq_folder: data/REDS/train_sharp_bicubic/X4
      gt_folder: data/REDS/train_sharp/X4
      crop_size: 256
      interval_list: [1]
      random_reverse: False
      number_frames: 30
      use_flip: True
      use_rot: True
      scale: 4
      val_partition: REDS4

  test:
    name: SRREDSMultipleGTDataset
    mode: test
    lq_folder: data/REDS/REDS4_test_sharp_bicubic/X4
    gt_folder: data/REDS/REDS4_test_sharp/X4
    interval_list: [1]
    random_reverse: False
    number_frames: 100
    use_flip: False
    use_rot: False
    scale: 4
    val_partition: REDS4
    num_workers: 0
    batch_size: 1

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: !!float 1e-4
  periods: [600000]
  restart_weights: [1]
  eta_min: !!float 1e-7

optimizer:
  name: Adam
  # add parameters of net_name to optim
  # name should in self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99

validate:
  interval: 5000
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 0
      test_y_channel: False
    ssim:
      name: SSIM
      crop_border: 0
      test_y_channel: False

log_config:
  interval: 10
  visiual_interval: 500

snapshot_config:
  interval: 5000
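
The `lr_scheduler` block above names `CosineAnnealingRestartLR` with a single period of 600000 iterations, a base rate of 1e-4 and a floor of 1e-7. The sketch below shows what such a schedule would produce, assuming it follows the usual cosine-annealing-with-restarts formula. The helper `cosine_restart_lr` and its signature are illustrative only and are not PaddleGAN's API.

```python
import math


def cosine_restart_lr(step, base_lr=1e-4, eta_min=1e-7,
                      periods=(600000,), restart_weights=(1.0,)):
    """Hypothetical helper: learning rate at `step` under cosine annealing with restarts."""
    start = 0
    for period, weight in zip(periods, restart_weights):
        if step < start + period:
            # Fraction of the current period that has elapsed, in [0, 1).
            progress = (step - start) / period
            cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
            return eta_min + weight * (base_lr - eta_min) * cosine
        start += period
    # Past the last period: stay at the floor (an assumption of this sketch).
    return eta_min


lr_start = cosine_restart_lr(0)
lr_mid = cosine_restart_lr(300000)
lr_end = cosine_restart_lr(600000)
```

With one period equal to `total_iters`, the schedule has no actual restart: the rate decays once from `learning_rate` to `eta_min` over the whole run.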
1 change: 1 addition & 0 deletions configs/basicvsr_reds.yaml
@@ -10,6 +10,7 @@ min_max:
 model:
   name: BasicVSRModel
   fix_iter: 5000
+  lr_mult: 0.125
   generator:
     name: BasicVSRNet
     mid_channels: 64
1 change: 1 addition & 0 deletions configs/iconvsr_reds.yaml
@@ -10,6 +10,7 @@ min_max:
 model:
   name: BasicVSRModel
   fix_iter: 5000
+  lr_mult: 0.125
   generator:
     name: IconVSR
     mid_channels: 64
15 changes: 14 additions & 1 deletion docs/en_US/tutorials/video_super_resolution.md
@@ -3,7 +3,7 @@

## 1.1 Principle

-Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low resolution (LR) images. The difference between them is that the video is composed of multiple frames, so the video super-resolution usually uses the information between frames to repair. Here we provide the video super-resolution model [EDVR](https://arxiv.org/pdf/1905.02716.pdf).
+Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low-resolution (LR) images. The difference is that a video is composed of multiple frames, so video super-resolution usually exploits inter-frame information for restoration. Here we provide the video super-resolution models [EDVR](https://arxiv.org/pdf/1905.02716.pdf), [BasicVSR](https://arxiv.org/pdf/2012.02181.pdf), [IconVSR](https://arxiv.org/pdf/2012.02181.pdf), and [BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf).

[EDVR](https://arxiv.org/pdf/1905.02716.pdf) won first place in all four tracks of the NTIRE19 video restoration and enhancement challenges, outperforming the runner-up by a large margin. Video super-resolution poses two main difficulties: (1) how to align multiple frames given large motions, and (2) how to effectively fuse different frames with diverse motion and blur. First, to handle large motions, EDVR devises a Pyramid, Cascading and Deformable (PCD) alignment module, in which frame alignment is done at the feature level using deformable convolutions in a coarse-to-fine manner. Second, EDVR proposes a Temporal and Spatial Attention (TSA) fusion module, in which attention is applied both temporally and spatially to emphasize important features for subsequent restoration.

@@ -79,6 +79,7 @@ The metrics are PSNR / SSIM.
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
| BasicVSR_x4 | 31.4325 / 0.8913 |
| IconVSR_x4 | 31.6882 / 0.8950 |
| BasicVSR++_x4 | 32.4018 / 0.9071 |
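
The PSNR figures in the table can be read against the standard definition, 10 · log10(MAX² / MSE). The following is a minimal sketch of that definition, assuming pixel values in [0, 1] as in `min_max: (0., 1.)` from the config. The helper `psnr` is illustrative and is not PaddleGAN's metric implementation, which may differ in color handling and border cropping.

```python
import math


def psnr(pred, target, max_val=1.0):
    """PSNR in dB between two equally sized flat sequences of pixel values."""
    if len(pred) != len(target):
        raise ValueError("inputs must have the same number of pixels")
    # Mean squared error over all pixels.
    mse = sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


# Toy 4-pixel example: two pixels are off by 0.1, two match exactly.
example = psnr([0.0, 0.5, 1.0, 0.25], [0.1, 0.5, 0.9, 0.25])
```

Because the scale is logarithmic, the roughly 0.7 dB gap between IconVSR_x4 and BasicVSR++_x4 corresponds to about a 15% reduction in mean squared error.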


## 1.4 Model Download
@@ -92,6 +93,7 @@ The metrics are PSNR / SSIM.
| EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
| BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
| IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
| BasicVSR++_x4 | REDS | [BasicVSR++_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams)



@@ -120,3 +122,14 @@ The metrics are PSNR / SSIM.
year = {2021}
}
```

- 3. [BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment](https://arxiv.org/pdf/2104.13371v1.pdf)

```
@article{chan2021basicvsr++,
author = {Chan, Kelvin C.K. and Zhou, Shangchen and Xu, Xiangyu and Loy, Chen Change},
title = {BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment},
journal = {arXiv preprint arXiv:2104.13371},
year = {2021}
}
```
14 changes: 13 additions & 1 deletion docs/zh_CN/tutorials/video_super_resolution.md
@@ -3,7 +3,7 @@

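## 1.1 Principle

-Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low-resolution (LR) images. The difference between the two is clear: because a video consists of multiple frames, video super-resolution usually exploits inter-frame information for restoration. Here we provide the video super-resolution model [EDVR](https://arxiv.org/pdf/1905.02716.pdf).
+Video super-resolution originates from image super-resolution, which aims to recover high-resolution (HR) images from one or more low-resolution (LR) images. The difference between the two is clear: because a video consists of multiple frames, video super-resolution usually exploits inter-frame information for restoration. Here we provide the video super-resolution models [EDVR](https://arxiv.org/pdf/1905.02716.pdf), [BasicVSR](https://arxiv.org/pdf/2012.02181.pdf), [IconVSR](https://arxiv.org/pdf/2012.02181.pdf), and [BasicVSR++](https://arxiv.org/pdf/2104.13371v1.pdf).

[EDVR](https://arxiv.org/pdf/1905.02716.pdf) won first place in all four tracks of the NTIRE19 video restoration and enhancement challenges, beating the runner-up by a large margin. The main difficulties of video super-resolution are (1) how to align multiple frames under large motion, and (2) how to effectively fuse frames with differing motion and blur. First, to handle large motion, EDVR designs a Pyramid, Cascading and Deformable (PCD) alignment module, in which coarse-to-fine deformable convolutions perform frame alignment at the feature level. Second, EDVR uses a Temporal and Spatial Attention (TSA) fusion module, which applies attention both temporally and spatially to emphasize the features that matter for subsequent restoration.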

@@ -75,6 +75,7 @@
| EDVR_L_w_tsa_deblur | 35.1473 / 0.9526 |
| BasicVSR_x4 | 31.4325 / 0.8913 |
| IconVSR_x4 | 31.6882 / 0.8950 |
| BasicVSR++_x4 | 32.4018 / 0.9071 |

## 1.4 Model Download
| Model | Dataset | Download |
@@ -87,6 +88,7 @@
| EDVR_L_w_tsa_deblur | REDS | [EDVR_L_w_tsa_deblur](https://paddlegan.bj.bcebos.com/models/EDVR_L_w_tsa_deblur.pdparams)
| BasicVSR_x4 | REDS | [BasicVSR_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR_reds_x4.pdparams)
| IconVSR_x4 | REDS | [IconVSR_x4](https://paddlegan.bj.bcebos.com/models/IconVSR_reds_x4.pdparams)
| BasicVSR++_x4 | REDS | [BasicVSR++_x4](https://paddlegan.bj.bcebos.com/models/BasicVSR%2B%2B_reds_x4.pdparams)



@@ -113,3 +115,13 @@
year = {2021}
}
```
- 3. [BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment](https://arxiv.org/pdf/2104.13371v1.pdf)

```
@article{chan2021basicvsr++,
author = {Chan, Kelvin C.K. and Zhou, Shangchen and Xu, Xiangyu and Loy, Chen Change},
title = {BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment},
journal = {arXiv preprint arXiv:2104.13371},
year = {2021}
}
```
5 changes: 3 additions & 2 deletions ppgan/models/basicvsr_model.py
@@ -29,7 +29,7 @@ class BasicVSRModel(BaseSRModel):
     Paper: BasicVSR: The Search for Essential Components in Video Super-Resolution and Beyond, CVPR, 2021
     """
-    def __init__(self, generator, fix_iter, pixel_criterion=None):
+    def __init__(self, generator, fix_iter, lr_mult, pixel_criterion=None):
         """Initialize the BasicVSR class.
         Args:
@@ -41,6 +41,7 @@ def __init__(self, generator, fix_iter, pixel_criterion=None):
         self.fix_iter = fix_iter
         self.current_iter = 1
         self.flag = True
+        self.lr_mult = lr_mult
         init_basicvsr_weight(self.nets['generator'])

     def setup_input(self, input):
@@ -65,7 +66,7 @@ def train_iter(self, optims=None):
                 for name, param in self.nets['generator'].named_parameters():
                     param.trainable = True
                     if 'spynet' in name:
-                        param.optimize_attr['learning_rate'] = 0.125
+                        param.optimize_attr['learning_rate'] = self.lr_mult
                 self.flag = False
                 for net in self.nets.values():
                     net.find_unused_parameters = False
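
The hunk above replaces the hard-coded SPyNet multiplier `0.125` with the configurable `lr_mult`. The sketch below shows the intended effect on per-parameter learning rates. It rests on two assumptions that the visible hunk implies but does not show: that SPyNet parameters are frozen for the first `fix_iter` iterations, and that `optimize_attr['learning_rate']` acts as a multiplier on the optimizer's global rate. The function `effective_lr` is hypothetical and not part of PaddleGAN.

```python
def effective_lr(name, base_lr, current_iter, fix_iter, lr_mult):
    """Hypothetical helper: learning rate applied to parameter `name`."""
    is_flow_param = 'spynet' in name
    if is_flow_param and current_iter <= fix_iter:
        # Assumed: the optical-flow network is frozen during warm-up.
        return 0.0
    if is_flow_param:
        # After warm-up the flow network trains at a reduced rate.
        return base_lr * lr_mult
    return base_lr


# Values taken from configs/basicvsr++_reds.yaml in this commit.
frozen = effective_lr('spynet.basic_module.0.weight', 1e-4, 1, 5000, 0.25)
unfrozen = effective_lr('spynet.basic_module.0.weight', 1e-4, 5001, 5000, 0.25)
other = effective_lr('upsample1.weight', 1e-4, 5001, 5000, 0.25)
```

Moving the multiplier into the config lets BasicVSR++ use 0.25 while BasicVSR and IconVSR keep their previous value of 0.125.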
1 change: 1 addition & 0 deletions ppgan/models/generators/__init__.py
@@ -35,3 +35,4 @@
 from .iconvsr import IconVSR
 from .gpen import GPEN
 from .pan import PAN
+from .basicvsr_plus_plus import BasicVSRPlusPlus
71 changes: 68 additions & 3 deletions ppgan/models/generators/basicvsr.py
@@ -1,14 +1,13 @@
# Copyright (c) MMEditing Authors.

import paddle

import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.ops import DeformConv2D
from ...utils.download import get_path_from_url
from ...modules.init import kaiming_normal_, constant_

from .builder import GENERATORS


@@ -607,3 +606,69 @@ def forward(self, lrs):
outputs[i] = out

return paddle.stack(outputs, axis=1)


class SecondOrderDeformableAlignment(nn.Layer):
    """Second-order deformable alignment module.
    Args:
        in_channels (int): Same as nn.Conv2d.
        out_channels (int): Same as nn.Conv2d.
        kernel_size (int or tuple[int]): Same as nn.Conv2d.
        stride (int or tuple[int]): Same as nn.Conv2d.
        padding (int or tuple[int]): Same as nn.Conv2d.
        dilation (int or tuple[int]): Same as nn.Conv2d.
        groups (int): Same as nn.Conv2d.
        deformable_groups (int).
    """
    def __init__(self,
                 in_channels=128,
                 out_channels=64,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 deformable_groups=16):
        super(SecondOrderDeformableAlignment, self).__init__()

        self.conv_offset = nn.Sequential(
            nn.Conv2D(3 * out_channels + 4, out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2D(out_channels, out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2D(out_channels, out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2D(out_channels, 27 * deformable_groups, 3, 1, 1),
        )
        self.dcn = DeformConv2D(in_channels,
                                out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                deformable_groups=deformable_groups)
        self.init_offset()

    def init_offset(self):
        constant_(self.conv_offset[-1].weight, 0)
        constant_(self.conv_offset[-1].bias, 0)

    def forward(self, x, extra_feat, flow_1, flow_2):
        extra_feat = paddle.concat([extra_feat, flow_1, flow_2], axis=1)
        out = self.conv_offset(extra_feat)
        o1, o2, mask = paddle.chunk(out, 3, axis=1)

        # offset
        offset = 10 * paddle.tanh(paddle.concat((o1, o2), axis=1))
        offset_1, offset_2 = paddle.chunk(offset, 2, axis=1)
        offset_1 = offset_1 + flow_1.flip(1).tile(
            [1, offset_1.shape[1] // 2, 1, 1])
        offset_2 = offset_2 + flow_2.flip(1).tile(
            [1, offset_2.shape[1] // 2, 1, 1])
        offset = paddle.concat([offset_1, offset_2], axis=1)

        # mask
        mask = F.sigmoid(mask)

        out = self.dcn(x, offset, mask)
        return out
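
The channel counts in `SecondOrderDeformableAlignment` are easy to lose track of, so the bookkeeping is sketched below in plain Python with no Paddle dependency. It assumes each flow has 2 channels, which the `flip(1)` and `shape[1] // 2` tiling in `forward` imply, and that the deformable convolution expects `2 * deformable_groups * k * k` offset channels and `deformable_groups * k * k` mask channels. The function `alignment_channels` is a hypothetical helper, not part of the commit.

```python
def alignment_channels(out_channels=64, deformable_groups=16, kernel_size=3):
    """Hypothetical helper: channel bookkeeping for the alignment module."""
    taps = kernel_size * kernel_size           # sampling points per kernel (9 for 3x3)
    conv_offset_in = 3 * out_channels + 4      # extra_feat plus two 2-channel flows
    conv_offset_out = 27 * deformable_groups   # hard-coded in the module (3 * 9 * groups)
    chunk = conv_offset_out // 3               # channels in each of o1, o2 and mask
    return {
        "conv_offset_in": conv_offset_in,
        "o1": chunk,
        "o2": chunk,
        "mask": chunk,
        "offset": 2 * chunk,                   # offset_1 and offset_2 concatenated
        "flow_tile_repeats": chunk // 2,       # copies of a 2-channel flow per offset
        "dcn_expected_offset": 2 * deformable_groups * taps,
        "dcn_expected_mask": deformable_groups * taps,
    }


channels = alignment_channels()
```

With the defaults, the offset has 288 channels and the mask 144, matching what the deformable convolution expects. The factor 27 only works out for a 3x3 kernel, so a different `kernel_size` would need a matching change to the last `conv_offset` layer.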