forked from PaddlePaddle/PaddleGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix sr docs and add div2k process script (PaddlePaddle#154)
- Loading branch information
1 parent
95e5f4f
commit 9d9f69f
Showing
9 changed files
with
367 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,290 @@ | ||
import os | ||
import re | ||
import sys | ||
import cv2 | ||
import argparse | ||
import numpy as np | ||
import os.path as osp | ||
|
||
from time import time | ||
from multiprocessing import Pool | ||
from shutil import get_terminal_size | ||
from ppgan.datasets.base_dataset import scandir | ||
|
||
|
||
class Timer:
    """Context-manager-aware timer for measuring wall-clock intervals."""
    def __init__(self, start=True, print_tmpl=None):
        # Template used by __exit__ to report the last measured interval.
        self.print_tmpl = print_tmpl or '{:.3f}'
        self._is_running = False
        if start:
            self.start()

    @property
    def is_running(self):
        """bool: whether the timer has been started and not stopped."""
        return self._is_running

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Report the time since the last check, then stop the timer.
        print(self.print_tmpl.format(self.since_last_check()))
        self._is_running = False

    def start(self):
        """Start (or restart) the timer and refresh the last-check mark."""
        if not self._is_running:
            self._t_start = time()
            self._is_running = True
        self._t_last = time()

    def since_start(self):
        """Return seconds elapsed since the timer was started.

        This counts as a "check" for :func:`since_last_check`.

        Raises:
            ValueError: If the timer is not running.
        """
        if not self._is_running:
            raise ValueError('timer is not running')
        self._t_last = time()
        return self._t_last - self._t_start

    def since_last_check(self):
        """Return seconds elapsed since the previous check.

        Either :func:`since_start` or :func:`since_last_check` counts as
        a check.

        Raises:
            ValueError: If the timer is not running.
        """
        if not self._is_running:
            raise ValueError('timer is not running')
        duration = time() - self._t_last
        self._t_last = time()
        return duration
|
||
|
||
class ProgressBar:
    """Console progress bar, advanced manually via :meth:`update`."""
    def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
        self.task_num = task_num
        self.bar_width = bar_width
        self.completed = 0
        self.file = file
        if start:
            self.start()

    @property
    def terminal_width(self):
        """Current terminal width in columns."""
        return get_terminal_size()[0]

    def start(self):
        """Print the initial (empty) state and start the internal timer."""
        if self.task_num > 0:
            blanks = ' ' * self.bar_width
            self.file.write(f'[{blanks}] 0/{self.task_num}, '
                            'elapsed: 0s, ETA:')
        else:
            # Unknown total: only a running completion counter is shown.
            self.file.write('completed: 0, elapsed: 0s')
        self.file.flush()
        self.timer = Timer()

    def update(self, num_tasks=1):
        """Mark ``num_tasks`` more tasks as completed and redraw the bar."""
        assert num_tasks > 0
        self.completed += num_tasks
        elapsed = self.timer.since_start()
        fps = self.completed / elapsed if elapsed > 0 else float('inf')
        if self.task_num > 0:
            fraction = self.completed / float(self.task_num)
            # Estimated time remaining, rounded to the nearest second.
            eta = int(elapsed * (1 - fraction) / fraction + 0.5)
            msg = (f'\r[{{}}] {self.completed}/{self.task_num}, '
                   f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, '
                   f'ETA: {eta:5}s')
            # Shrink the bar to fit the terminal, but never below 2 chars.
            width = min(self.bar_width,
                        int(self.terminal_width - len(msg)) + 2,
                        int(self.terminal_width * 0.6))
            width = max(2, width)
            filled = int(width * fraction)
            bar_chars = '>' * filled + ' ' * (width - filled)
            self.file.write(msg.format(bar_chars))
        else:
            self.file.write(
                f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
                f' {fps:.1f} tasks/s')
        self.file.flush()
|
||
|
||
def main_extract_subimages(args):
    """Crop the DIV2K training sets into sub-images for faster IO.

    Runs :func:`extract_subimages` over the HR folder and each bicubic LR
    folder (X2/X3/X4) under ``args.data_root``; crop, step and threshold
    sizes are divided by the LR scale so patches stay aligned with the HR
    crops.

    Args:
        args: Parsed CLI options providing ``data_root``, ``crop_size``,
            ``step``, ``thresh_size``, ``compression_level`` and
            ``n_thread``.

    Usage:
        Typically, four folders are processed for the DIV2K dataset:
            DIV2K_train_HR
            DIV2K_train_LR_bicubic/X2
            DIV2K_train_LR_bicubic/X3
            DIV2K_train_LR_bicubic/X4
        After processing, each ``*_sub`` folder should contain the same
        number of sub-images.
    """
    # Options shared by every extraction pass.
    base_opt = {
        'n_thread': args.n_thread,
        'compression_level': args.compression_level,
    }

    # (sub-folder, scale) pairs: the HR set first, then the LR sets.
    jobs = [('DIV2K_train_HR', 1)]
    jobs += [(f'DIV2K_train_LR_bicubic/X{s}', s) for s in (2, 3, 4)]

    for folder, scale in jobs:
        opt = dict(base_opt)
        opt['input_folder'] = osp.join(args.data_root, folder)
        opt['save_folder'] = osp.join(args.data_root, folder + '_sub')
        if scale == 1:
            # HR pass: use the CLI sizes unchanged.
            opt['crop_size'] = args.crop_size
            opt['step'] = args.step
            opt['thresh_size'] = args.thresh_size
        else:
            # LR pass: shrink all sizes by the bicubic scale factor.
            opt['crop_size'] = args.crop_size // scale
            opt['step'] = args.step // scale
            opt['thresh_size'] = args.thresh_size // scale
        extract_subimages(opt)
|
||
|
||
def extract_subimages(opt):
    """Crop every image in one folder into sub-images, in parallel.

    Args:
        opt (dict): Configuration dict containing at least
            ``input_folder``, ``save_folder`` and ``n_thread``; the
            remaining keys are consumed by :func:`worker`.
    """
    src_dir = opt['input_folder']
    dst_dir = opt['save_folder']

    # Refuse to overwrite an existing output folder.
    if osp.exists(dst_dir):
        print(f'Folder {dst_dir} already exists. Exit.')
        sys.exit(1)
    os.makedirs(dst_dir)
    print(f'mkdir {dst_dir} ...')

    paths = [osp.join(src_dir, name) for name in scandir(src_dir)]

    prog_bar = ProgressBar(len(paths))
    pool = Pool(opt['n_thread'])
    for img_path in paths:
        # The callback runs in the parent process, so updating the
        # progress bar here is safe.
        pool.apply_async(worker,
                         args=(img_path, opt),
                         callback=lambda _: prog_bar.update())
    pool.close()
    pool.join()
    print('All processes done.')
|
||
|
||
def worker(path, opt):
    """Crop one image into overlapping sub-images and save them as PNGs.

    Args:
        path (str): Image path.
        opt (dict): Configuration dict. It contains:
            crop_size (int): Crop size.
            step (int): Step for overlapped sliding window.
            thresh_size (int): Threshold size. Patches whose size is smaller
                than thresh_size will be dropped.
            save_folder (str): Path to save folder.
            compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.

    Returns:
        process_info (str): Process information displayed in progress bar.

    Raises:
        ValueError: If the image cannot be read, has an unexpected number
            of dimensions, or is smaller than the crop size.
    """
    crop_size = opt['crop_size']
    step = opt['step']
    thresh_size = opt['thresh_size']
    img_name, extension = osp.splitext(osp.basename(path))

    # remove the x2, x3, x4 and x8 in the filename for DIV2K
    img_name = re.sub('x[2348]', '', img_name)

    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread returns None instead of raising on unreadable files;
        # fail loudly with the offending path rather than an AttributeError.
        raise ValueError(f'Failed to read image: {path}')

    if img.ndim == 2 or img.ndim == 3:
        h, w = img.shape[:2]
    else:
        raise ValueError(f'Image ndim should be 2 or 3, but got {img.ndim}')

    if h < crop_size or w < crop_size:
        # Without this guard an undersized image makes h_space/w_space
        # empty and h_space[-1] below raises an opaque IndexError (which
        # apply_async would silently swallow).
        raise ValueError(
            f'Image {path} ({h}x{w}) is smaller than crop size {crop_size}')

    # Sliding-window start positions; append one final window flush with
    # the border when the leftover strip is larger than thresh_size.
    h_space = np.arange(0, h - crop_size + 1, step)
    if h - (h_space[-1] + crop_size) > thresh_size:
        h_space = np.append(h_space, h - crop_size)
    w_space = np.arange(0, w - crop_size + 1, step)
    if w - (w_space[-1] + crop_size) > thresh_size:
        w_space = np.append(w_space, w - crop_size)

    index = 0
    for x in h_space:
        for y in w_space:
            index += 1
            cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
            cv2.imwrite(
                osp.join(opt['save_folder'],
                         f'{img_name}_s{index:03d}{extension}'), cropped_img,
                [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
    process_info = f'Processing {img_name} ...'
    return process_info
|
||
|
||
def parse_args():
    """Parse command-line options for DIV2K preprocessing.

    Returns:
        argparse.Namespace: Parsed arguments with integer-typed sizes.
    """
    parser = argparse.ArgumentParser(
        description='Prepare DIV2K dataset',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # required=True: without a data root every later osp.join(None, ...)
    # would fail with an opaque TypeError.
    parser.add_argument('--data-root', required=True, help='dataset root')
    # type=int on all numeric options: argparse delivers CLI values as str
    # by default, which would break the integer arithmetic downstream
    # (e.g. args.crop_size // scale).
    parser.add_argument('--crop-size',
                        nargs='?',
                        type=int,
                        default=480,
                        help='cropped size for HR images')
    parser.add_argument('--step',
                        nargs='?',
                        type=int,
                        default=240,
                        help='step size for HR images')
    parser.add_argument('--thresh-size',
                        nargs='?',
                        type=int,
                        default=0,
                        help='threshold size for HR images')
    parser.add_argument('--compression-level',
                        nargs='?',
                        type=int,
                        default=3,
                        help='compression level when save png images')
    parser.add_argument('--n-thread',
                        nargs='?',
                        type=int,
                        default=20,
                        help='thread number when using multiprocessing')

    args = parser.parse_args()
    return args
|
||
|
||
if __name__ == '__main__':
    # Entry point: parse CLI options and crop all DIV2K sets into
    # sub-images.
    main_extract_subimages(parse_args())
Oops, something went wrong.