forked from PaddlePaddle/PaddleGAN
* [Xingzhi Cup Reproduction Competition] NAFNet
* [feature] add TIPC and README.md
Showing 20 changed files with 1,289 additions and 4 deletions.
@@ -0,0 +1,59 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import argparse

# Make the repository root importable when the script is run from it.
sys.path.insert(0, os.getcwd())
import paddle
from ppgan.apps import NAFNetPredictor

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_path",
                        type=str,
                        default='output_dir',
                        help="path to output image dir")

    parser.add_argument("--weight_path",
                        type=str,
                        default=None,
                        help="path to model checkpoint")

    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help="sample random seed for model's image generation")

    parser.add_argument('--images_path',
                        default=None,
                        required=True,
                        type=str,
                        help='Single image or images directory.')

    parser.add_argument("--cpu",
                        dest="cpu",
                        action="store_true",
                        help="cpu mode.")

    args = parser.parse_args()

    # Run inference on the CPU when --cpu is passed.
    if args.cpu:
        paddle.set_device('cpu')

    predictor = NAFNetPredictor(output_path=args.output_path,
                                weight_path=args.weight_path,
                                seed=args.seed)
    predictor.run(images_path=args.images_path)
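The script's command-line interface can be checked without installing Paddle. The sketch below rebuilds only the `argparse` parser from the script above (the `build_parser` name is hypothetical) to show which defaults apply when only the required `--images_path` flag is given.

```python
import argparse


def build_parser():
    """Rebuild the NAFNet denoising script's argument parser (hypothetical helper)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_path", type=str, default="output_dir",
                        help="path to output image dir")
    parser.add_argument("--weight_path", type=str, default=None,
                        help="path to model checkpoint")
    parser.add_argument("--seed", type=int, default=None,
                        help="sample random seed for model's image generation")
    parser.add_argument("--images_path", type=str, default=None, required=True,
                        help="Single image or images directory.")
    parser.add_argument("--cpu", dest="cpu", action="store_true", help="cpu mode.")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--images_path", "noisy.png", "--cpu"])
    # Flags that were not passed fall back to their defaults.
    print(args.output_path, args.weight_path, args.seed, args.cpu)
    # prints: output_dir None None True
```

Omitting `--images_path` makes `argparse` print a usage error and exit, because the flag is declared `required=True`.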
@@ -0,0 +1,72 @@
total_iters: 3200000
output_dir: output_dir

model:
  name: NAFNetModel
  generator:
    name: NAFNet
    img_channel: 3
    width: 64
    enc_blk_nums: [2, 2, 4, 8]
    middle_blk_num: 12
    dec_blk_nums: [2, 2, 2, 2]
  psnr_criterion:
    name: PSNRLoss

dataset:
  train:
    name: NAFNetTrain
    rgb_dir: data/SIDD/train
    num_workers: 16
    batch_size: 8 # 1GPU
    img_options:
      patch_size: 256
  test:
    name: NAFNetVal
    rgb_dir: data/SIDD/val
    num_workers: 1
    batch_size: 1
    img_options:
      patch_size: 256

export_model:
  - {name: 'generator', inputs_num: 1}

lr_scheduler:
  name: CosineAnnealingRestartLR
  learning_rate: !!float 125e-6 # num_gpu * 0.000125
  periods: [3200000]
  restart_weights: [1]
  eta_min: !!float 1e-7

validate:
  interval: 5000
  save_img: false

  metrics:
    psnr: # metric name, can be arbitrary
      name: PSNR
      crop_border: 4
      test_y_channel: True
    ssim:
      name: SSIM
      crop_border: 4
      test_y_channel: True

optimizer:
  name: AdamW
  # add parameters of net_name to optim
  # name should be in self.nets
  net_names:
    - generator
  weight_decay: 0.0
  beta1: 0.9
  beta2: 0.9
  # !!float is needed: YAML 1.1 loaders read a bare 1e-8 as a string
  epsilon: !!float 1e-8

log_config:
  interval: 10
  visiual_interval: 5000

snapshot_config:
  interval: 5000
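The `lr_scheduler` block above sets up a single cosine cycle spanning all 3,200,000 iterations. The sketch below assumes `CosineAnnealingRestartLR` follows the BasicSR scheduler of the same name, lr = eta_min + w * 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / period)), where t counts iterations since the last restart and w is the restart weight. It is an illustration of that formula, not PaddleGAN's scheduler code.

```python
import math


def cosine_restart_lr(it, base_lr=125e-6, periods=(3200000,),
                      restart_weights=(1.0,), eta_min=1e-7):
    """Learning rate at iteration `it` under an assumed BasicSR-style
    cosine-annealing-with-restarts schedule (defaults taken from the config)."""
    start = 0
    for period, weight in zip(periods, restart_weights):
        if it <= start + period:
            progress = (it - start) / period  # fraction of the current cycle
            return eta_min + weight * 0.5 * (base_lr - eta_min) * (
                1 + math.cos(math.pi * progress))
        start += period
    # Past the last period: stay at the floor (an assumption of this sketch).
    return eta_min


if __name__ == "__main__":
    for it in (0, 800000, 1600000, 2400000, 3200000):
        print(it, cosine_restart_lr(it))
```

With one period and a restart weight of 1, the rate decays smoothly from 1.25e-4 to 1e-7 and never restarts during training.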
@@ -0,0 +1,87 @@
English | [Chinese](../../zh_CN/tutorials/nafnet.md)

# NAFNet: Simple Baselines for Image Restoration

## 1 Introduction

NAFNet proposes an ultra-simple baseline that is computationally efficient and also outperforms the previous state-of-the-art (SOTA) methods. Simplifying this baseline further yields NAFNet, in which the nonlinear activation functions are removed and performance improves further. NAFNet achieves new SOTA results on both SIDD denoising and GoPro deblurring while substantially reducing computational cost. As shown in the figure below, the network uses a UNet with skip connections as its overall architecture, modifies the Transformer module of the Restormer block, removes the activation functions, adopts a simpler and more efficient SimpleGate design, and uses a simplified channel attention mechanism.

![NAFNet](https://ai-studio-static-online.cdn.bcebos.com/699b87449c7e495f8655ae5ac8bc0eb77bed4d9cd828451e8939ddbc5732a704)

For a more detailed introduction to the model, please refer to the original paper [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676). PaddleGAN currently provides weights for the denoising task.

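The two activation-free components named above can be illustrated without a deep-learning framework. Following the paper's description, SimpleGate splits a feature map into two halves along the channel axis and multiplies them elementwise in place of a nonlinear activation, while simplified channel attention keeps only global average pooling and a 1x1 convolution (a linear map over channels) before rescaling each channel. This is a minimal sketch on plain Python lists with the spatial dimensions flattened; the function names are hypothetical and it is not PaddleGAN's implementation.

```python
from typing import List

# A feature map as [channel][pixel], with the spatial dimensions flattened.
Feature = List[List[float]]


def simple_gate(x: Feature) -> Feature:
    """SimpleGate: split the channels into two halves and multiply them elementwise."""
    if len(x) % 2:
        raise ValueError("SimpleGate needs an even number of channels")
    half = len(x) // 2
    return [[a * b for a, b in zip(x[c], x[c + half])] for c in range(half)]


def simplified_channel_attention(x: Feature, weight: List[List[float]],
                                 bias: List[float]) -> Feature:
    """Simplified channel attention: global average pooling, a 1x1 conv
    (a linear map over channels), then per-channel rescaling. No activation."""
    pooled = [sum(ch) / len(ch) for ch in x]
    scales = [sum(w * p for w, p in zip(row, pooled)) + b
              for row, b in zip(weight, bias)]
    return [[v * s for v in ch] for ch, s in zip(x, scales)]


if __name__ == "__main__":
    x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # 4 channels, 2 pixels
    gated = simple_gate(x)                                # 2 channels remain
    print(gated)  # [[5.0, 12.0], [21.0, 32.0]]
    identity = [[1.0, 0.0], [0.0, 1.0]]
    print(simplified_channel_attention(gated, identity, [0.0, 0.0]))
    # [[42.5, 102.0], [556.5, 848.0]]
```

Both operations use only multiplications and a linear layer, which is the sense in which the network is "nonlinear activation free".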
## 2 How to use

### 2.1 Quick start

After installing PaddleGAN, enter the `PaddleGAN` directory and run the following command to generate the restored image `./output_dir/Denoising/image_name.png`:

```sh
python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE}
```

Here `PATH_OF_IMAGE` is the path to the image you want to denoise, or to the folder containing the images. To use your own model weights, run the following command, where `PATH_OF_MODEL` is the path to the model weights:

```sh
python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE} --weight_path ${PATH_OF_MODEL}
```

### 2.2 Prepare dataset

The denoising model is trained on SIDD, an image denoising dataset containing 30,000 noisy images captured under 10 different lighting conditions. It can be downloaded from the [training dataset](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) and [test dataset](https://drive.google.com/drive/folders/1S44fHXaVxAYW3KLNxK41NYCnyX9S79su) links.
After downloading, extract it into the `data` directory. The extracted `SIDD` dataset is organized as follows:

```sh
SIDD
├── train
│   ├── input
│   └── target
└── val
    ├── input
    └── target
```

Users can also use the [SIDD data](https://aistudio.baidu.com/aistudio/datasetdetail/149460) on AI Studio, but the folders `input_crops` and `gt_crops` must be renamed to `input` and `target`.

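The folder rename for the AI Studio copy can be scripted. The sketch below uses only the standard library; `rename_aistudio_folders` and `missing_dirs` are hypothetical helpers. Because the exact location of the crop folders inside the AI Studio archive is not stated here, the helper renames matching folders at any depth and then checks the `train`/`val` × `input`/`target` layout shown in the tree above.

```python
import os
import tempfile
from pathlib import Path

# Folder renames required for the AI Studio copy of SIDD.
RENAMES = {"input_crops": "input", "gt_crops": "target"}
SPLITS = ("train", "val")


def rename_aistudio_folders(root):
    """Rename every input_crops/gt_crops directory under `root`; return new paths."""
    hits = [Path(dirpath) / name
            for dirpath, dirnames, _ in os.walk(root)
            for name in dirnames if name in RENAMES]
    renamed = []
    # Rename the deepest folders first so a parent rename cannot invalidate a child path.
    for path in sorted(hits, key=lambda p: len(p.parts), reverse=True):
        target = path.with_name(RENAMES[path.name])
        if target.exists():
            raise FileExistsError(f"{target} already exists")
        path.rename(target)
        renamed.append(target)
    return renamed


def missing_dirs(root):
    """Return the expected SIDD sub-directories that do not exist under `root`."""
    expected = [Path(root) / split / sub for split in SPLITS for sub in ("input", "target")]
    return [p for p in expected if not p.is_dir()]


if __name__ == "__main__":
    # Demo on a throwaway directory that mimics the AI Studio folder names.
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp) / "SIDD"
        for split in SPLITS:
            for name in RENAMES:
                (root / split / name).mkdir(parents=True)
        print(len(missing_dirs(root)))  # 4 before renaming
        rename_aistudio_folders(root)
        print(missing_dirs(root))       # [] after renaming
```

On real data, point `root` at `data/SIDD` and confirm `missing_dirs` returns an empty list before starting training.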
### 2.3 Training
This example trains a denoising model. To train for other tasks, replace the dataset and modify the config file accordingly.

```sh
python -u tools/main.py --config-file configs/nafnet_denoising.yaml
```

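The config comments `# num_gpu * 0.000125` (learning rate) and `# 1GPU` (batch size 8) suggest linear scaling when training on more than one GPU. The sketch below applies that rule, assuming `batch_size` in the config is per GPU; `scaled_hparams` is a hypothetical helper, not part of PaddleGAN.

```python
BASE_LR_PER_GPU = 0.000125  # from the config comment "num_gpu * 0.000125"
BATCH_SIZE_PER_GPU = 8      # from "batch_size: 8 # 1GPU" (assumed per-GPU)


def scaled_hparams(num_gpus: int) -> dict:
    """Learning rate and total batch size for a multi-GPU run,
    following the linear-scaling comments in the config."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    return {
        "learning_rate": num_gpus * BASE_LR_PER_GPU,
        "total_batch_size": num_gpus * BATCH_SIZE_PER_GPU,
    }


if __name__ == "__main__":
    for n in (1, 2, 4, 8):
        print(n, scaled_hparams(n))
```

The computed `learning_rate` would replace the `125e-6` value under `lr_scheduler` in the config.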
### 2.4 Test

To evaluate a trained model, where `PATH_OF_WEIGHT` is the path to the model weights:
```sh
python tools/main.py --config-file configs/nafnet_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```

## 3 Results

Denoising:

| Model | Dataset | PSNR/SSIM |
|---|---|---|
| NAFNet | SIDD Val | 43.1468 / 0.9563 |

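The config's metrics use `crop_border: 4` and `test_y_channel: True`. The sketch below shows what those two options mean, assuming PaddleGAN follows the common BasicSR convention: 4 pixels are trimmed from every border, and PSNR is measured on the ITU-R BT.601 luma (Y) channel only, with a peak value of 255. It takes 8-bit RGB images as nested lists of `(r, g, b)` tuples and is an illustration, not PaddleGAN's metric code.

```python
import math


def rgb_to_y(pixel):
    """BT.601 luma (range 16..235) of an 8-bit RGB pixel."""
    r, g, b = pixel
    return (65.481 * r + 128.553 * g + 24.966 * b) / 255.0 + 16.0


def psnr_y(img1, img2, crop_border=4):
    """PSNR on the Y channel between two same-sized 8-bit RGB images
    given as rows of (r, g, b) tuples, after cropping `crop_border` pixels."""
    if crop_border:
        img1 = [row[crop_border:-crop_border] for row in img1[crop_border:-crop_border]]
        img2 = [row[crop_border:-crop_border] for row in img2[crop_border:-crop_border]]
    sq_errors = [(rgb_to_y(p) - rgb_to_y(q)) ** 2
                 for row1, row2 in zip(img1, img2)
                 for p, q in zip(row1, row2)]
    if not sq_errors:
        raise ValueError("image is too small for the crop border")
    mse = sum(sq_errors) / len(sq_errors)
    if mse == 0:
        return float("inf")  # identical images
    return 10 * math.log10(255.0 ** 2 / mse)


if __name__ == "__main__":
    clean = [[(100, 100, 100)] * 10 for _ in range(10)]
    noisy = [[(110, 110, 110)] * 10 for _ in range(10)]
    print(round(psnr_y(clean, noisy), 2))
```

Because the Y-channel and border-crop choices change the score, PSNR values are only comparable when computed with the same settings.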
## 4 Download

| Model | Link |
|---|---|
| NAFNet | [NAFNet_Denoising](https://paddlegan.bj.bcebos.com/models/NAFNet_Denoising.pdparams) |

# References

- [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676)

```
@article{chen_simple_nodate,
title = {Simple {Baselines} for {Image} {Restoration}},
abstract = {Although there have been significant advances in the field of image restoration recently, the system complexity of the state-of-the-art (SOTA) methods is increasing as well, which may hinder the convenient analysis and comparison of methods. In this paper, we propose a simple baseline that exceeds the SOTA methods and is computationally efficient. To further simplify the baseline, we reveal that the nonlinear activation functions, e.g. Sigmoid, ReLU, GELU, Softmax, etc. are not necessary: they could be replaced by multiplication or removed. Thus, we derive a Nonlinear Activation Free Network, namely NAFNet, from the baseline. SOTA results are achieved on various challenging benchmarks, e.g. 33.69 dB PSNR on GoPro (for image deblurring), exceeding the previous SOTA 0.38 dB with only 8.4\% of its computational costs; 40.30 dB PSNR on SIDD (for image denoising), exceeding the previous SOTA 0.28 dB with less than half of its computational costs. The code and the pretrained models will be released at github.com/megvii-research/NAFNet.},
language = {en},
author = {Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
pages = {17}
}
```
@@ -0,0 +1,87 @@
[English](../../en_US/tutorials/nafnet.md) | Chinese

# NAFNet: Simple Baselines for Image Restoration

## 1 Introduction

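NAFNet proposes an ultra-simple baseline that is not only computationally efficient but also outperforms the previous SOTA methods. Further simplifying this baseline yields NAFNet, which removes the nonlinear activation units and improves performance further. The proposed approach reaches new SOTA performance on both SIDD denoising and GoPro deblurring while greatly reducing computation. The network design and its features are shown in the figure below: a UNet with skip connections serves as the overall architecture, the Transformer module in the Restormer block is modified and its activation functions are removed, a simpler and more effective SimpleGate design is adopted, and a simpler channel attention mechanism is used.

![NAFNet](https://ai-studio-static-online.cdn.bcebos.com/699b87449c7e495f8655ae5ac8bc0eb77bed4d9cd828451e8939ddbc5732a704)

For a more detailed introduction to the model, see the original paper [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676). PaddleGAN currently provides weights for the denoising task.
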
## 2 How to Use

### 2.1 Quick Start

After installing `PaddleGAN`, go to the `PaddleGAN` directory and run the following command to generate the restored image `./output_dir/Denoising/image_name.png`:

```sh
python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE}
```

Here `PATH_OF_IMAGE` is the path of the image you want to denoise, or of the folder containing the images. To use your own model weights, run the following command, where `PATH_OF_MODEL` is the path to the model weights:

```sh
python applications/tools/nafnet_denoising.py --images_path ${PATH_OF_IMAGE} --weight_path ${PATH_OF_MODEL}
```

### 2.2 Data Preparation

The denoising training data is SIDD, an image denoising dataset containing 30,000 noisy images captured under 10 different lighting conditions. It can be downloaded via the [training dataset download](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) and [test dataset download](https://drive.google.com/drive/folders/1S44fHXaVxAYW3KLNxK41NYCnyX9S79su) links.
After downloading, extract it into the `data` directory. After extraction, the data is organized as follows:

```sh
SIDD
├── train
│   ├── input
│   └── target
└── val
    ├── input
    └── target
```

Users can also use the [SIDD data](https://aistudio.baidu.com/aistudio/datasetdetail/149460) on AI Studio, but need to rename the folders `input_crops` and `gt_crops` to `input` and `target`.

### 2.3 Training
This example uses the denoising training data. To train other tasks, replace the dataset and modify the config file.

```sh
python -u tools/main.py --config-file configs/nafnet_denoising.yaml
```

### 2.4 Testing

Test the model:
```sh
python tools/main.py --config-file configs/nafnet_denoising.yaml --evaluate-only --load ${PATH_OF_WEIGHT}
```

## 3 Results

Denoising:

| Model | Dataset | PSNR/SSIM |
|---|---|---|
| NAFNet | SIDD Val | 43.1468 / 0.9563 |

## 4 Model Download

| Model | Download link |
|---|---|
| NAFNet | [NAFNet_Denoising](https://paddlegan.bj.bcebos.com/models/NAFNet_Denoising.pdparams) |

# References

- [Simple Baselines for Image Restoration](https://arxiv.org/pdf/2204.04676)

```
@article{chen_simple_nodate,
title = {Simple {Baselines} for {Image} {Restoration}},
abstract = {Although there have been significant advances in the field of image restoration recently, the system complexity of the state-of-the-art (SOTA) methods is increasing as well, which may hinder the convenient analysis and comparison of methods. In this paper, we propose a simple baseline that exceeds the SOTA methods and is computationally efficient. To further simplify the baseline, we reveal that the nonlinear activation functions, e.g. Sigmoid, ReLU, GELU, Softmax, etc. are not necessary: they could be replaced by multiplication or removed. Thus, we derive a Nonlinear Activation Free Network, namely NAFNet, from the baseline. SOTA results are achieved on various challenging benchmarks, e.g. 33.69 dB PSNR on GoPro (for image deblurring), exceeding the previous SOTA 0.38 dB with only 8.4\% of its computational costs; 40.30 dB PSNR on SIDD (for image denoising), exceeding the previous SOTA 0.28 dB with less than half of its computational costs. The code and the pretrained models will be released at github.com/megvii-research/NAFNet.},
language = {en},
author = {Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
pages = {17}
}
```