fix sr docs and add div2k process script (PaddlePaddle#154)
LielinJiang authored Jan 29, 2021
1 parent 95e5f4f commit 9d9f69f
Showing 9 changed files with 367 additions and 64 deletions.
4 changes: 2 additions & 2 deletions configs/drn_psnr_x4_div2k.yaml
@@ -53,8 +53,8 @@ dataset:
       keys: [image, image, image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
8 changes: 4 additions & 4 deletions configs/esrgan_psnr_x4_div2k.yaml
@@ -18,8 +18,8 @@ model:
 dataset:
   train:
     name: SRDataset
-    gt_folder: data/DIV2K/DIV2K_train_HR_sub
-    lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4_sub
+    gt_folder: data/DIV2K/DIV2K_train_HR
+    lq_folder: data/DIV2K/DIV2K_train_LR_bicubic/X4
     num_workers: 4
     batch_size: 16
     scale: 4
@@ -49,8 +49,8 @@ dataset:
       keys: [image, image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
4 changes: 2 additions & 2 deletions configs/esrgan_x4_div2k.yaml
@@ -65,8 +65,8 @@ dataset:
       keys: [image, image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
4 changes: 2 additions & 2 deletions configs/lesrcnn_psnr_x4_div2k.yaml
@@ -45,8 +45,8 @@ dataset:
       keys: [image, image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
4 changes: 2 additions & 2 deletions configs/realsr_bicubic_noise_x4_df2k.yaml
@@ -69,8 +69,8 @@ dataset:
       keys: [image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
4 changes: 2 additions & 2 deletions configs/realsr_kernel_noise_x4_dped.yaml
@@ -69,8 +69,8 @@ dataset:
       keys: [image]
   test:
     name: SRDataset
-    gt_folder: data/DIV2K/val_set14/Set14
-    lq_folder: data/DIV2K/val_set14/Set14_bicLRx4
+    gt_folder: data/Set14/GTmod12
+    lq_folder: data/Set14/LRbicx4
     scale: 4
     preprocess:
       - name: LoadImageFromFile
290 changes: 290 additions & 0 deletions data/process_div2k_data.py
@@ -0,0 +1,290 @@
import os
import re
import sys
import cv2
import argparse
import numpy as np
import os.path as osp

from time import time
from multiprocessing import Pool
from shutil import get_terminal_size
from ppgan.datasets.base_dataset import scandir


class Timer:
    """A flexible Timer class."""
    def __init__(self, start=True, print_tmpl=None):
        self._is_running = False
        self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
        if start:
            self.start()

    @property
    def is_running(self):
        """bool: indicate whether the timer is running."""
        return self._is_running

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, type, value, traceback):
        print(self.print_tmpl.format(self.since_last_check()))
        self._is_running = False

    def start(self):
        """Start the timer."""
        if not self._is_running:
            self._t_start = time()
            self._is_running = True
        self._t_last = time()

    def since_start(self):
        """Total time since the timer was started.

        Returns (float): Time in seconds.
        """
        if not self._is_running:
            raise ValueError('timer is not running')
        self._t_last = time()
        return self._t_last - self._t_start

    def since_last_check(self):
        """Time since the last check.

        Either :func:`since_start` or :func:`since_last_check` is a checking
        operation.

        Returns (float): Time in seconds.
        """
        if not self._is_running:
            raise ValueError('timer is not running')
        dur = time() - self._t_last
        self._t_last = time()
        return dur
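
# Editorial usage sketch (not part of the committed file): the Timer above
# starts on construction (start=True) and also works as a context manager,
# printing the elapsed time on exit. The helper name is hypothetical.
def _timer_usage_example():
    timer = Timer()                       # starts immediately
    total = timer.since_start()           # seconds since start()
    delta = timer.since_last_check()      # seconds since the previous check
    with Timer(print_tmpl='took {:.3f}s'):
        pass                              # elapsed time is printed on exit
    return total, delta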


class ProgressBar:
    """A progress bar which can print the progress."""
    def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
        self.task_num = task_num
        self.bar_width = bar_width
        self.completed = 0
        self.file = file
        if start:
            self.start()

    @property
    def terminal_width(self):
        width, _ = get_terminal_size()
        return width

    def start(self):
        if self.task_num > 0:
            self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
                            'elapsed: 0s, ETA:')
        else:
            self.file.write('completed: 0, elapsed: 0s')
        self.file.flush()
        self.timer = Timer()

    def update(self, num_tasks=1):
        assert num_tasks > 0
        self.completed += num_tasks
        elapsed = self.timer.since_start()
        if elapsed > 0:
            fps = self.completed / elapsed
        else:
            fps = float('inf')
        if self.task_num > 0:
            percentage = self.completed / float(self.task_num)
            eta = int(elapsed * (1 - percentage) / percentage + 0.5)
            msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
                  f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
                  f'ETA: {eta:5}s'

            bar_width = min(self.bar_width,
                            int(self.terminal_width - len(msg)) + 2,
                            int(self.terminal_width * 0.6))
            bar_width = max(2, bar_width)
            mark_width = int(bar_width * percentage)
            bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
            self.file.write(msg.format(bar_chars))
        else:
            self.file.write(
                f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
                f' {fps:.1f} tasks/s')
        self.file.flush()
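
# Editorial usage sketch (not part of the committed file): the ProgressBar
# above is driven by calling update() once per finished task, exactly as the
# pool callback in extract_subimages() does below. The helper name is
# hypothetical.
def _progress_bar_usage_example(items):
    bar = ProgressBar(task_num=len(items))
    for _ in items:
        # ... process one item here ...
        bar.update()   # advances the count and refreshes the task/s and ETA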


def main_extract_subimages(args):
    """A multiprocessing tool to crop large images into sub-images for
    faster IO. It is used for the DIV2K dataset.

    args (argparse.Namespace): Arguments parsed by parse_args(). They contain:
        n_thread (int): Number of worker processes.
        compression_level (int): cv2.IMWRITE_PNG_COMPRESSION from 0 to 9.
            A higher value means a smaller file size and a longer compression
            time. Use 0 for faster CPU decompression. Default: 3, same as cv2.
        input_folder (str): Path to the input folder.
        save_folder (str): Path to the save folder.
        crop_size (int): Crop size.
        step (int): Step of the overlapped sliding window.
        thresh_size (int): Threshold size. Patches whose size is smaller
            than thresh_size will be dropped.

    Usage:
        Run this script once; it processes the four DIV2K folders in turn:
            DIV2K_train_HR
            DIV2K_train_LR_bicubic/X2
            DIV2K_train_LR_bicubic/X3
            DIV2K_train_LR_bicubic/X4
        After processing, each *_sub folder should contain the same number of
        sub-images. Remember to adjust the opt configuration according to
        your settings.
    """

    opt = {}
    opt['n_thread'] = args.n_thread
    opt['compression_level'] = args.compression_level

    # HR images
    opt['input_folder'] = osp.join(args.data_root, 'DIV2K_train_HR')
    opt['save_folder'] = osp.join(args.data_root, 'DIV2K_train_HR_sub')
    opt['crop_size'] = args.crop_size
    opt['step'] = args.step
    opt['thresh_size'] = args.thresh_size
    extract_subimages(opt)

    # LR images: scale the crop parameters down by the corresponding factor
    for scale in [2, 3, 4]:
        opt['input_folder'] = osp.join(args.data_root,
                                       f'DIV2K_train_LR_bicubic/X{scale}')
        opt['save_folder'] = osp.join(args.data_root,
                                      f'DIV2K_train_LR_bicubic/X{scale}_sub')
        opt['crop_size'] = args.crop_size // scale
        opt['step'] = args.step // scale
        opt['thresh_size'] = args.thresh_size // scale
        extract_subimages(opt)
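
# Editorial note (not part of the committed file): with the default arguments
# (--crop-size 480 --step 240 --thresh-size 0), the loop above derives the LR
# crop parameters per scale as
#     X2: crop_size 240, step 120, thresh_size 0
#     X3: crop_size 160, step  80, thresh_size 0
#     X4: crop_size 120, step  60, thresh_size 0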


def extract_subimages(opt):
    """Crop images to subimages.

    Args:
        opt (dict): Configuration dict. It contains:
            input_folder (str): Path to the input folder.
            save_folder (str): Path to the save folder.
            n_thread (int): Thread number.
    """
    input_folder = opt['input_folder']
    save_folder = opt['save_folder']
    if not osp.exists(save_folder):
        os.makedirs(save_folder)
        print(f'mkdir {save_folder} ...')
    else:
        print(f'Folder {save_folder} already exists. Exit.')
        sys.exit(1)

    img_list = list(scandir(input_folder))
    img_list = [osp.join(input_folder, v) for v in img_list]

    prog_bar = ProgressBar(len(img_list))
    pool = Pool(opt['n_thread'])
    for path in img_list:
        pool.apply_async(worker,
                         args=(path, opt),
                         callback=lambda arg: prog_bar.update())
    pool.close()
    pool.join()
    print('All processes done.')
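
# Editorial usage sketch (not part of the committed file): extract_subimages()
# can also be driven directly with an opt dict; the paths below are
# placeholders and the helper name is hypothetical.
def _extract_subimages_example():
    extract_subimages({
        'input_folder': 'data/DIV2K/DIV2K_train_HR',
        'save_folder': 'data/DIV2K/DIV2K_train_HR_sub',
        'crop_size': 480,
        'step': 240,
        'thresh_size': 0,
        'n_thread': 20,
        'compression_level': 3,
    })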


def worker(path, opt):
    """Worker for each process.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            crop_size (int): Crop size.
            step (int): Step of the overlapped sliding window.
            thresh_size (int): Threshold size. Patches whose size is smaller
                than thresh_size will be dropped.
            save_folder (str): Path to the save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        process_info (str): Process information displayed in the progress bar.
    """
    crop_size = opt['crop_size']
    step = opt['step']
    thresh_size = opt['thresh_size']
    img_name, extension = osp.splitext(osp.basename(path))

    # remove the x2, x3, x4 and x8 in the filename for DIV2K
    img_name = re.sub('x[2348]', '', img_name)

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)

    if img.ndim == 2 or img.ndim == 3:
        h, w = img.shape[:2]
    else:
        raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')

    # sliding-window offsets; append a final offset when the leftover border
    # is larger than thresh_size so the image edge is still covered
    h_space = np.arange(0, h - crop_size + 1, step)
    if h - (h_space[-1] + crop_size) > thresh_size:
        h_space = np.append(h_space, h - crop_size)
    w_space = np.arange(0, w - crop_size + 1, step)
    if w - (w_space[-1] + crop_size) > thresh_size:
        w_space = np.append(w_space, w - crop_size)

    index = 0
    for x in h_space:
        for y in w_space:
            index += 1
            cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
            cv2.imwrite(
                osp.join(opt['save_folder'],
                         f'{img_name}_s{index:03d}{extension}'), cropped_img,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    process_info = f'Processing {img_name} ...'
    return process_info
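
# Editorial worked example (not part of the committed file): for a hypothetical
# 1200x1200 HR image with crop_size=480, step=240 and thresh_size=0, the grid
# above gives h_space = w_space = [0, 240, 480, 720]; the leftover border is
# 1200 - (720 + 480) = 0, which is not greater than thresh_size, so no extra
# offset is appended and the image yields 4 x 4 = 16 sub-images named
# <img_name>_s001<ext> ... <img_name>_s016<ext>.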


def parse_args():
    parser = argparse.ArgumentParser(
        description='Prepare DIV2K dataset',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # required so that osp.join() never receives None as the root path
    parser.add_argument('--data-root', required=True, help='dataset root')
    # type=int keeps command-line values usable in the integer arithmetic above
    parser.add_argument('--crop-size',
                        nargs='?',
                        default=480,
                        type=int,
                        help='cropped size for HR images')
    parser.add_argument('--step',
                        nargs='?',
                        default=240,
                        type=int,
                        help='step size for HR images')
    parser.add_argument('--thresh-size',
                        nargs='?',
                        default=0,
                        type=int,
                        help='threshold size for HR images')
    parser.add_argument('--compression-level',
                        nargs='?',
                        default=3,
                        type=int,
                        help='compression level when saving png images')
    parser.add_argument('--n-thread',
                        nargs='?',
                        default=20,
                        type=int,
                        help='number of worker processes for multiprocessing')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    # extract subimages
    main_extract_subimages(args)
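
For reference, a minimal invocation sketch, assuming the DIV2K archives have
already been unpacked under data/DIV2K so that DIV2K_train_HR and
DIV2K_train_LR_bicubic/X2..X4 exist (values mirror the defaults in parse_args
above; the programmatic call is an illustrative equivalent, not part of the
commit):

    # Command line:
    #   python data/process_div2k_data.py --data-root data/DIV2K \
    #       --crop-size 480 --step 240 --thresh-size 0 --n-thread 20
    # Programmatic equivalent:
    from argparse import Namespace
    args = Namespace(data_root='data/DIV2K', crop_size=480, step=240,
                     thresh_size=0, compression_level=3, n_thread=20)
    main_extract_subimages(args)

The sub-images are written next to their sources as DIV2K_train_HR_sub and
DIV2K_train_LR_bicubic/X{2,3,4}_sub.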