forked from PaddlePaddle/PaddleGAN
[feature] add rcan model for remote sensing image super-resolution (PaddlePaddle#610)
* [feature] add rcan model for super-resolution
Showing 7 changed files with 484 additions and 0 deletions.
@@ -0,0 +1,104 @@
total_iters: 1000000
output_dir: output_dir
# tensor range for function tensor2img
min_max:
  (0., 255.)

model:
  name: RCANModel
  generator:
    name: RCAN
    scale: 4
    n_resgroups: 10
    n_resblocks: 20
  pixel_criterion:
    name: L1Loss

dataset:
  train:
    name: SRDataset
    gt_folder: data/DIV2K/DIV2K_train_HR_sub
    lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
    num_workers: 4
    batch_size: 16
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: SRPairedRandomCrop
            gt_patch_size: 192
            scale: 4
            keys: [image, image]
          - name: PairedRandomHorizontalFlip
            keys: [image, image]
          - name: PairedRandomVerticalFlip
            keys: [image, image]
          - name: PairedRandomTransposeHW
            keys: [image, image]
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., .0, 0.]
            std: [1., 1., 1.]
            keys: [image, image]
  test:
    name: SRDataset
    gt_folder: data/Set14/GTmod12
    lq_folder: data/Set14/LRbicx4
    scale: 4
    preprocess:
      - name: LoadImageFromFile
        key: lq
      - name: LoadImageFromFile
        key: gt
      - name: Transforms
        input_keys: [lq, gt]
        pipeline:
          - name: Transpose
            keys: [image, image]
          - name: Normalize
            mean: [0., .0, 0.]
            std: [1., 1., 1.]
            keys: [image, image]

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: 0.0001
  periods: [1000000]
  restart_weights: [1]
  eta_min: !!float 1e-7

optimizer:
  name: Adam
  # add parameters of net_name to optim
  # each name should be a key of self.nets
  net_names:
    - generator
  beta1: 0.9
  beta2: 0.99

validate:
  interval: 2500
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 4
      test_y_channel: True
    ssim:
      name: SSIM
      crop_border: 4
      test_y_channel: True

log_config:
  interval: 10
  visiual_interval: 5000

snapshot_config:
  interval: 2500
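With `periods: [1000000]` equal to `total_iters` and `restart_weights: [1]`, the `lr_scheduler` block above amounts to a single cosine decay from `learning_rate` (1e-4) down to `eta_min` (1e-7), with no restart during the run. The sketch below computes the learning rate per iteration under the assumption that `CosineAnnealingRestartLR` uses the common BasicSR-style cosine-with-restarts formula; the function name `cosine_restart_lr` is hypothetical and not a PaddleGAN API.

```python
import math


def cosine_restart_lr(it, base_lr=1e-4, periods=(1000000,),
                      restart_weights=(1,), eta_min=1e-7):
    """Learning rate at iteration `it` (hypothetical helper, not a PaddleGAN API).

    Assumes the BasicSR-style cosine-annealing-with-restarts formula; the
    defaults copy the lr_scheduler block of the config above.
    """
    start = 0
    for period, weight in zip(periods, restart_weights):
        if it <= start + period:
            progress = (it - start) / period  # 0.0 at a restart, 1.0 at its end
            return eta_min + weight * 0.5 * (base_lr - eta_min) * (
                1 + math.cos(math.pi * progress))
        start += period
    return eta_min  # this sketch simply holds the floor after the last period


for it in (0, 250000, 500000, 750000, 1000000):
    print(it, f"{cosine_restart_lr(it):.3e}")
```

Because the single period spans the whole run, the rate reaches `eta_min` exactly at the last iteration.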
docs/zh_CN/tutorials/remote_sensing_image_super-resolution.md (70 changes: 70 additions, 0 deletions)
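@@ -0,0 +1,70 @@
# 1. Single Remote Sensing Image Super-Resolution

## 1.1 Background and Principles

**Significance and applications**: Single image super-resolution (SISR) has long been a popular task in low-level vision. It can be used to restore old films and photographs, and it can supply higher-quality input data for downstream tasks such as image segmentation and object detection. It is also widely applicable in remote sensing: in many remote sensing applications, such as **ship detection and classification**, **increasing the resolution of remote sensing images is highly valuable**.

**Principle**: Super-resolution of a single remote sensing image is essentially the same problem as generic single image super-resolution: a three-channel (RGB) low-resolution image is used to generate a high-resolution image with sharp textures. This project reproduces the paper [Image Super-Resolution Using Very Deep Residual Channel Attention Networks](https://arxiv.org/abs/1807.02758) by [Yulun Zhang](http://yulunzhang.com/), [Kunpeng Li](https://kunpengli1994.github.io/), [Kai Li](http://kailigo.github.io/), [Lichen Wang](https://sites.google.com/site/lichenwang123/), [Bineng Zhong](https://scholar.google.de/citations?user=hvRBydsAAAAJ&hl=en), and [Yun Fu](http://www1.ece.neu.edu/~yunfu/), published at ECCV 2018.
The authors propose a very deep residual channel attention network (RCAN) that introduces a channel attention (CA) mechanism, which adaptively rescales features by modeling the interdependencies between channels. Because the model achieves strong performance, this project uses RCAN for x4 super-resolution of single remote sensing images.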
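A minimal NumPy sketch of one channel attention step for a single feature map of shape (C, H, W). It follows the structure of `CALayer` in the generator file of this commit: global average pooling, a 1x1 conv that reduces the channel count by `reduction`, ReLU, a 1x1 conv that restores it, sigmoid, then per-channel rescaling. On a 1x1 pooled map a 1x1 convolution is just a matrix-vector product, which is what the sketch uses; the helper name `channel_attention` and the random weights are illustrative assumptions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def channel_attention(x, w_down, b_down, w_up, b_up):
    """Channel attention on x of shape (C, H, W) (illustrative sketch).

    w_down: (C // r, C), w_up: (C, C // r); 1x1 convs on a 1x1 map are matmuls.
    """
    # Squeeze: global average pooling gives one descriptor per channel, shape (C,)
    z = x.mean(axis=(1, 2))
    # Excitation: channel downscale -> ReLU -> channel upscale -> sigmoid
    s = np.maximum(w_down @ z + b_down, 0.0)
    weights = sigmoid(w_up @ s + b_up)  # shape (C,), each value in (0, 1)
    # Rescale: multiply every channel of x by its attention weight
    return x * weights[:, None, None]


rng = np.random.default_rng(0)
C, r = 64, 16  # n_feats and reduction defaults of the RCAN generator
x = rng.standard_normal((C, 8, 8))
out = channel_attention(
    x,
    rng.standard_normal((C // r, C)), np.zeros(C // r),
    rng.standard_normal((C, C // r)), np.zeros(C),
)
print(out.shape)  # (64, 8, 8)
```

With all weights and biases set to zero, every channel weight is sigmoid(0) = 0.5, so the output is exactly half the input; learned weights let the network emphasize some channels relative to others.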
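## 1.2 How to Use

### 1.2.1 Data Preparation
Training in this project has two stages: first, an RCANx4 model is pretrained on the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K/); then, starting from that model, transfer learning is performed on the [remote sensing super-resolution dataset](https://aistudio.baidu.com/aistudio/datasetdetail/129011).
- For preparing the DIV2K data, see [this document](./single_image_super_resolution.md).
- Preparing the remote sensing super-resolution data
  - The data has been uploaded to AI Studio. It consists of remote sensing images sampled from the UC Merced Land-Use Dataset (a remote sensing dataset with 21 land-use classes), from which HR-LR image pairs were generated by bicubic (BI) degradation for training super-resolution models: 6720 pairs for training and 420 pairs for testing.
  - After downloading and extracting, the files are organized as follows:
    ```
    RSdata_for_SR
    ├── train_HR
    ├── train_LR
    │   └── x4
    ├── test_HR
    └── test_LR
        └── x4
    ```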
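Before training, it can help to check that every HR image has an LR partner and that the sizes differ by exactly the scale factor. This is a minimal standard-library sketch; it assumes, without confirmation from this document, that the dataset loader pairs HR and LR images by identical file names, and the helper names and the `data/RSdata_for_SR` location are hypothetical.

```python
from pathlib import Path


def list_names(folder):
    """File names (not full paths) directly inside `folder`."""
    return {p.name for p in Path(folder).iterdir() if p.is_file()}


def unpaired(hr_names, lr_names):
    """Names present on only one side, returned as (hr_only, lr_only) sorted lists."""
    hr, lr = set(hr_names), set(lr_names)
    return sorted(hr - lr), sorted(lr - hr)


def is_exact_scale(hr_hw, lr_hw, scale=4):
    """True if the HR (height, width) is exactly `scale` times the LR size."""
    return hr_hw[0] == lr_hw[0] * scale and hr_hw[1] == lr_hw[1] * scale


# Hypothetical usage, assuming the archive was extracted to data/RSdata_for_SR:
# hr_only, lr_only = unpaired(list_names("data/RSdata_for_SR/train_HR"),
#                             list_names("data/RSdata_for_SR/train_LR/x4"))
```

Both lists should come back empty for a correctly extracted dataset; image sizes can then be checked with any image reader and `is_exact_scale`.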
### 1.2.2 Training/Testing on DIV2K

First, the RCANx4 model is trained on DIV2K, with Set14 as the test set. Following the paper, RCANx2 weights are needed to initialize the model; they can be obtained from the table below.

| Model | Dataset | Download |
|---|---|---|
| RCANx2 | DIV2K | [RCANx2](https://paddlegan.bj.bcebos.com/models/RCAN_X2_DIV2K.pdparams) |
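Why an x2 checkpoint can initialize an x4 model: in the RCAN generator of this commit, the head and body do not depend on `scale`; only the `Upsampler` in the tail does (one conv + `PixelShuffle(2)` stage for x2, two stages for x4). The sketch below matches parameters by name and shape. The parameter names and shapes are assumptions derived from the layer structure with `n_feats=64` (biases omitted), and whether PaddleGAN's `--load` skips mismatched entries in exactly this way is not stated in this document.

```python
def filter_loadable(pretrained, model):
    """Split pretrained entries into loadable ones and model entries left untouched.

    Both arguments map parameter names to shapes; in practice they would be
    derived from the two models' state dicts.
    """
    loadable = {k: v for k, v in pretrained.items() if model.get(k) == v}
    missing = sorted(set(model) - set(loadable))
    return loadable, missing


# Assumed names/shapes for n_feats=64, n_colors=3, kernel_size=3 (biases omitted).
x2_params = {
    "head.0.weight": (64, 3, 3, 3),
    "tail.0.0.weight": (256, 64, 3, 3),  # Upsampler stage 1 conv (x2 and x4)
    "tail.1.weight": (3, 64, 3, 3),      # final conv back to RGB
}
x4_params = {
    "head.0.weight": (64, 3, 3, 3),
    "tail.0.0.weight": (256, 64, 3, 3),
    "tail.0.2.weight": (256, 64, 3, 3),  # Upsampler stage 2 conv (x4 only)
    "tail.1.weight": (3, 64, 3, 3),
}

loadable, missing = filter_loadable(x2_params, x4_params)
print(sorted(loadable))
print(missing)  # ['tail.0.2.weight']
```

Under these assumptions only the second upsampling stage keeps its fresh initialization; every other layer reuses the x2 weights.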
After preparing the DIV2K data as described in [this document](./single_image_super_resolution.md), run the following command to train the model; the `--load` argument is the path to the downloaded RCANx2 weights.

```shell
python -u tools/main.py --config-file configs/rcan_rssr_x4.yaml --load ${PATH_OF_WEIGHT}
```

After training, run the following command to evaluate on the Set14 test set; the `--load` argument is the path to the trained RCANx4 weights.
```shell
python tools/main.py --config-file configs/rcan_rssr_x4.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```

The weights obtained by this project after 57250 training iterations on DIV2K, [RCAN_X4_DIV2K](https://pan.baidu.com/s/1rI7yUdD4T1DE0RZB5yHXjA) (extraction code: aglw), achieve `PSNR:28.8959 SSIM:0.7896` on Set14.
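The config evaluates PSNR with `crop_border: 4` and `test_y_channel: True`. The NumPy sketch below shows what such a metric typically computes, assuming the common BasicSR convention: convert RGB in [0, 255] to the ITU-R BT.601 luma (Y) channel, crop 4 pixels from every border, then compute `10 * log10(255^2 / MSE)`. PaddleGAN's own implementation may differ in rounding details, so this sketch is not guaranteed to reproduce the reported numbers exactly; `rgb_to_y` and `psnr_y` are hypothetical names.

```python
import numpy as np


def rgb_to_y(img):
    """BT.601 Y channel (range 16-235) of an RGB image with values in [0, 255]."""
    img = img.astype(np.float64) / 255.0
    return img @ np.array([65.481, 128.553, 24.966]) + 16.0


def psnr_y(img1, img2, crop_border=4):
    """PSNR on the Y channel after cropping `crop_border` pixels on each side."""
    y1, y2 = rgb_to_y(img1), rgb_to_y(img2)
    if crop_border > 0:
        y1 = y1[crop_border:-crop_border, crop_border:-crop_border]
        y2 = y2[crop_border:-crop_border, crop_border:-crop_border]
    mse = np.mean((y1 - y2) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(255.0 ** 2 / mse)
```

Cropping the border excludes pixels whose reconstruction depends on padding, which is why the crop size is usually tied to the scale factor.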
### 1.2.3 Transfer Learning/Testing on the Remote Sensing SR Data
- To use this dataset, change the HR and LR image paths of the training and test sets in `rcan_rssr_x4.yaml`, i.e., `gt_folder` and `lq_folder`.
- Because transfer learning starts from the RCAN_X4_DIV2K weights trained on DIV2K, the number of training iterations `total_iters` can also be reduced; good results are obtained without many iterations. For training, the `--load` argument is the path to the downloaded RCANx4 weights.

Train the model:
```shell
python -u tools/main.py --config-file configs/rcan_rssr_x4.yaml --load ${PATH_OF_RCANx4_WEIGHT}
```
Test the model (`--load` is the path to the weights to evaluate):
```shell
python -u tools/main.py --config-file configs/rcan_rssr_x4.yaml --evaluate-only --load ${PATH_OF_RCANx4_WEIGHT}
```
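Instead of editing `gt_folder` and `lq_folder` by hand, the same change can be made programmatically on the parsed config. This sketch works on a plain nested dict (for example, what a YAML loader returns for `rcan_rssr_x4.yaml`); `set_by_path`, `make_rs_config`, and the `data/RSdata_for_SR` root are hypothetical, and the document does not give a specific value for a reduced `total_iters`.

```python
import copy


def set_by_path(cfg, dotted_key, value):
    """Set cfg["a"]["b"]["c"] = value for dotted_key "a.b.c" (hypothetical helper)."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node[key]
    node[leaf] = value


def make_rs_config(base_cfg, root="data/RSdata_for_SR", total_iters=None):
    """Return a copy of base_cfg pointing at the remote sensing SR data.

    `root` is an assumed extraction location; the sub-folders follow the
    directory tree shown in the data preparation section.
    """
    cfg = copy.deepcopy(base_cfg)  # leave the original config untouched
    overrides = {
        "dataset.train.gt_folder": f"{root}/train_HR",
        "dataset.train.lq_folder": f"{root}/train_LR/x4",
        "dataset.test.gt_folder": f"{root}/test_HR",
        "dataset.test.lq_folder": f"{root}/test_LR/x4",
    }
    if total_iters is not None:
        overrides["total_iters"] = total_iters
    for key, value in overrides.items():
        set_by_path(cfg, key, value)
    return cfg
```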
## 1.3 Experimental Results

- RCANx4 super-resolution results on remote sensing images

<img src=../../imgs/RSSR.png></img>

- [RCAN remote sensing image super-resolution: try the AI Studio project online](https://aistudio.baidu.com/aistudio/projectdetail/3508912)
@@ -0,0 +1,202 @@
# based on https://github.com/kongdebug/RCAN-Paddle
import math
import paddle
import paddle.nn as nn

from .builder import GENERATORS


def default_conv(in_channels, out_channels, kernel_size, bias=True):
    # "same" padding for odd kernel sizes, Xavier-uniform weight init
    weight_attr = paddle.ParamAttr(
        initializer=paddle.nn.initializer.XavierUniform(), need_clip=True)
    return nn.Conv2D(in_channels,
                     out_channels,
                     kernel_size,
                     padding=(kernel_size // 2),
                     weight_attr=weight_attr,
                     bias_attr=bias)


class MeanShift(nn.Conv2D):
    # fixed 1x1 conv: subtracts (sign=-1) or adds back (sign=1) the RGB mean

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = paddle.to_tensor(rgb_std)
        self.weight.set_value(paddle.eye(3).reshape([3, 3, 1, 1]))
        self.weight.set_value(self.weight / (std.reshape([3, 1, 1, 1])))

        mean = paddle.to_tensor(rgb_mean)
        self.bias.set_value(sign * rgb_range * mean / std)

        # constant transform, excluded from training
        self.weight.trainable = False
        self.bias.trainable = False


## Channel Attention (CA) Layer
class CALayer(nn.Layer):

    def __init__(self, channel, reduction=16):
        super(CALayer, self).__init__()
        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2D(1)
        # feature channel downscale and upscale --> channel weight
        self.conv_du = nn.Sequential(
            nn.Conv2D(channel,
                      channel // reduction,
                      1,
                      padding=0,
                      bias_attr=True), nn.ReLU(),
            nn.Conv2D(channel // reduction,
                      channel,
                      1,
                      padding=0,
                      bias_attr=True), nn.Sigmoid())

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv_du(y)
        return x * y


## Residual Channel Attention Block (RCAB)
class RCAB(nn.Layer):

    def __init__(self,
                 conv,
                 n_feat,
                 kernel_size,
                 reduction=16,
                 bias=True,
                 bn=False,
                 act=nn.ReLU(),
                 res_scale=1):
        super(RCAB, self).__init__()
        modules_body = []
        for i in range(2):
            modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn: modules_body.append(nn.BatchNorm2D(n_feat))
            if i == 0: modules_body.append(act)
        modules_body.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*modules_body)
        # note: res_scale is stored but not applied in forward()
        self.res_scale = res_scale

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


## Residual Group (RG)
class ResidualGroup(nn.Layer):

    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale,
                 n_resblocks):
        super(ResidualGroup, self).__init__()
        modules_body = []
        modules_body = [
            RCAB(conv,
                 n_feat,
                 kernel_size,
                 reduction,
                 bias=True,
                 bn=False,
                 act=nn.ReLU(),
                 res_scale=1) for _ in range(n_resblocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


class Upsampler(nn.Sequential):

    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        m = []
        if (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            for _ in range(int(math.log(scale, 2))):
                m.append(conv(n_feats, 4 * n_feats, 3, bias))
                m.append(nn.PixelShuffle(2))
                if bn: m.append(nn.BatchNorm2D(n_feats))

                if act == 'relu':
                    m.append(nn.ReLU())
                elif act == 'prelu':
                    m.append(nn.PReLU(n_feats))

        elif scale == 3:
            m.append(conv(n_feats, 9 * n_feats, 3, bias))
            m.append(nn.PixelShuffle(3))
            if bn: m.append(nn.BatchNorm2D(n_feats))

            if act == 'relu':
                m.append(nn.ReLU())
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))
        else:
            raise NotImplementedError

        super(Upsampler, self).__init__(*m)


@GENERATORS.register()
class RCAN(nn.Layer):

    def __init__(
        self,
        scale,
        n_resgroups,
        n_resblocks,
        n_feats=64,
        n_colors=3,
        rgb_range=255,
        kernel_size=3,
        reduction=16,
        conv=default_conv,
    ):
        super(RCAN, self).__init__()
        self.scale = scale
        act = nn.ReLU()

        n_resgroups = n_resgroups
        n_resblocks = n_resblocks
        n_feats = n_feats
        kernel_size = kernel_size
        reduction = reduction
        scale = scale
        act = nn.ReLU()

        # per-channel RGB mean (in [0, 1]); MeanShift scales it by rgb_range
        rgb_mean = (0.4488, 0.4371, 0.4040)
        rgb_std = (1.0, 1.0, 1.0)
        self.sub_mean = MeanShift(rgb_range, rgb_mean, rgb_std)

        # define head module
        modules_head = [conv(n_colors, n_feats, kernel_size)]

        # define body module
        modules_body = [
            ResidualGroup(conv,
                          n_feats,
                          kernel_size,
                          reduction,
                          act=act,
                          res_scale=1,
                          n_resblocks=n_resblocks) for _ in range(n_resgroups)
        ]

        modules_body.append(conv(n_feats, n_feats, kernel_size))

        # define tail module
        modules_tail = [
            Upsampler(conv, scale, n_feats, act=False),
            conv(n_feats, n_colors, kernel_size)
        ]

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

        self.add_mean = MeanShift(rgb_range, rgb_mean, rgb_std, 1)

    def forward(self, x):
        x = self.sub_mean(x)
        x = self.head(x)

        # long skip connection around all residual groups
        res = self.body(x)
        res += x

        x = self.tail(res)
        x = self.add_mean(x)

        return x
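The `Upsampler` above handles power-of-two scales with `log2(scale)` stages, each a 3x3 conv that multiplies the channel count by 4 followed by `nn.PixelShuffle(2)`. This NumPy sketch shows the pixel-shuffle rearrangement in NCHW layout (assumed to match `paddle.nn.PixelShuffle` with its default data format) and the resulting shape flow for `scale: 4`. Replacing the conv with channel tiling is purely illustrative, since only the shapes matter here.

```python
import math

import numpy as np


def pixel_shuffle(x, r):
    """Rearrange (N, C*r*r, H, W) into (N, C, H*r, W*r) in NCHW layout."""
    n, c_in, h, w = x.shape
    c = c_in // (r * r)
    x = x.reshape(n, c, r, r, h, w)
    # (n, c, i, j, h, w) -> (n, c, h, i, w, j): sub-pixel offsets (i, j)
    # become the fine row/column position inside each upscaled cell
    x = x.transpose(0, 1, 4, 2, 5, 3)
    return x.reshape(n, c, h * r, w * r)


scale, n_feats = 4, 64
feat = np.zeros((1, n_feats, 12, 12))
for _ in range(int(math.log2(scale))):
    # stand-in for conv(n_feats, 4 * n_feats, 3): only the channel count matters
    feat = np.concatenate([feat] * 4, axis=1)  # (1, 256, H, W)
    feat = pixel_shuffle(feat, 2)              # (1, 64, 2H, 2W)
print(feat.shape)  # (1, 64, 48, 48)
```

Each stage trades a 4x increase in channels for a 2x increase in height and width, so the expensive body always runs at the low input resolution.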