[TIPC] Fix Benchmark (PaddlePaddle#684)
* Fix amp bug for distributed train & Support amp for edvr

* Change log save dir

* Change data parallel

* Invoke CI

* Re-invoke CI

* Add requirement install in prepare.sh
Birdylx authored Aug 30, 2022
1 parent 6716ae5 commit 0541ace
Showing 7 changed files with 102 additions and 31 deletions.
56 changes: 32 additions & 24 deletions ppgan/engine/trainer.py
@@ -103,9 +103,6 @@ def __init__(self, cfg):

# build model
self.model = build_model(cfg.model)
# multiple gpus prepare
if ParallelEnv().nranks > 1:
self.distributed_data_parallel()

# build metrics
self.metrics = None
@@ -121,10 +118,6 @@ def __init__(self, cfg):
import visualdl
self.vdl_logger = visualdl.LogWriter(logdir=cfg.output_dir)

# evaluate only
if not cfg.is_train:
return

# build train dataloader
self.train_dataloader = build_dataloader(cfg.dataset.train)
self.iters_per_epoch = len(self.train_dataloader)
@@ -139,6 +132,17 @@ def __init__(self, cfg):
self.optimizers = self.model.setup_optimizers(self.lr_schedulers,
cfg.optimizer)

# setup amp train
self.scaler = self.setup_amp_train() if self.cfg.amp else None

# multiple gpus prepare
if ParallelEnv().nranks > 1:
self.distributed_data_parallel()

# evaluate only
if not cfg.is_train:
return

self.epochs = cfg.get('epochs', None)
if self.epochs:
self.total_iters = self.epochs * self.iters_per_epoch
@@ -159,6 +163,26 @@ def __init__(self, cfg):
self.model.set_total_iter(self.total_iters)
self.profiler_options = cfg.profiler_options

def setup_amp_train(self):
""" decerate model, optimizer and return a GradScaler """

self.logger.info('use AMP to train. AMP level = {}'.format(
self.cfg.amp_level))
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
# need to decorate model and optim if amp_level == 'O2'
if self.cfg.amp_level == 'O2':
nets, optimizers = list(self.model.nets.values()), list(
self.optimizers.values())
nets, optimizers = paddle.amp.decorate(models=nets,
optimizers=optimizers,
level='O2',
save_dtype='float32')
for i, (k, _) in enumerate(self.model.nets.items()):
self.model.nets[k] = nets[i]
for i, (k, _) in enumerate(self.optimizers.items()):
self.optimizers[k] = optimizers[i]
return scaler

def distributed_data_parallel(self):
paddle.distributed.init_parallel_env()
find_unused_parameters = self.cfg.get('find_unused_parameters', False)
@@ -183,22 +207,6 @@ def train(self):

iter_loader = IterLoader(self.train_dataloader)

# use amp
if self.cfg.amp:
self.logger.info('use AMP to train. AMP level = {}'.format(
self.cfg.amp_level))
assert self.cfg.model.name == 'MultiStageVSRModel', "AMP only support msvsr model"
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
# need to decorate model and optim if amp_level == 'O2'
if self.cfg.amp_level == 'O2':
# msvsr has only one generator and one optimizer
self.model.nets['generator'], self.optimizers[
'optim'] = paddle.amp.decorate(
models=self.model.nets['generator'],
optimizers=self.optimizers['optim'],
level='O2',
save_dtype='float32')

# set model.is_train = True
self.model.setup_train_mode(is_train=True)
while self.current_iter < (self.total_iters + 1):
@@ -215,7 +223,7 @@ def train(self):
self.model.setup_input(data)

if self.cfg.amp:
-                self.model.train_iter_amp(self.optimizers, scaler,
+                self.model.train_iter_amp(self.optimizers, self.scaler,
self.cfg.amp_level) # amp train
else:
self.model.train_iter(self.optimizers) # norm train
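Note on the reordering above: the multi-GPU wrap and the evaluation-only early return were moved after optimizer and AMP setup, so that under O2 paddle.amp.decorate runs before paddle.DataParallel wraps the networks. A minimal sketch of the resulting order, with a toy model standing in for the trainer's networks:

import paddle

model = paddle.nn.Linear(8, 8)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

# 1) AMP setup first: O2 decorates both model and optimizer
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
model, optimizer = paddle.amp.decorate(models=model,
                                       optimizers=optimizer,
                                       level='O2',
                                       save_dtype='float32')

# 2) DataParallel wrap afterwards, mirroring the reordered __init__
if paddle.distributed.ParallelEnv().nranks > 1:
    paddle.distributed.init_parallel_env()
    model = paddle.DataParallel(model)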
30 changes: 30 additions & 0 deletions ppgan/models/edvr_model.py
@@ -27,6 +27,7 @@ class EDVRModel(BaseSRModel):
Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks.
"""

def __init__(self, generator, tsa_iter, pixel_criterion=None):
"""Initialize the EDVR class.
@@ -74,8 +75,37 @@ def train_iter(self, optims=None):
optims['optim'].step()
self.current_iter += 1

# amp train with brute force implementation
def train_iter_amp(self, optims=None, scaler=None, amp_level='O1'):
optims['optim'].clear_grad()
if self.tsa_iter:
if self.current_iter == 1:
print('Only train TSA module for', self.tsa_iter, 'iters.')
for name, param in self.nets['generator'].named_parameters():
if 'TSAModule' not in name:
param.trainable = False
elif self.current_iter == self.tsa_iter + 1:
print('Train all the parameters.')
for param in self.nets['generator'].parameters():
param.trainable = True

# put loss computation in amp context
with paddle.amp.auto_cast(enable=True, level=amp_level):
self.output = self.nets['generator'](self.lq)
self.visual_items['output'] = self.output
# pixel loss
loss_pixel = self.pixel_criterion(self.output, self.gt)
self.losses['loss_pixel'] = loss_pixel

scaled_loss = scaler.scale(loss_pixel)
scaled_loss.backward()
scaler.minimize(optims['optim'], scaled_loss)

self.current_iter += 1


def init_edvr_weight(net):

def reset_func(m):
if hasattr(m, 'weight') and (not isinstance(
m, (nn.BatchNorm, nn.BatchNorm2D))) and (
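The new EDVR train_iter_amp follows the usual Paddle AMP step: forward pass and loss inside paddle.amp.auto_cast, then scale, backward, and minimize through the GradScaler. A self-contained sketch of the same pattern, with a toy network in place of the EDVR generator and pixel criterion:

import paddle

net = paddle.nn.Linear(4, 4)
optim = paddle.optimizer.Adam(parameters=net.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
criterion = paddle.nn.L1Loss()

x = paddle.randn([2, 4])
gt = paddle.randn([2, 4])

optim.clear_grad()
with paddle.amp.auto_cast(enable=True, level='O1'):
    out = net(x)               # forward under autocast
    loss = criterion(out, gt)  # loss computed in the same context

scaled = scaler.scale(loss)    # scale to avoid fp16 gradient underflow
scaled.backward()
scaler.minimize(optim, scaled) # unscale, step, and update the loss scale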
6 changes: 3 additions & 3 deletions ppgan/models/msvsr_model.py
@@ -97,7 +97,7 @@ def train_iter(self, optims=None):

self.current_iter += 1

-    # amp train with brute force implementation, maybe decorator can simplify this
+    # amp train with brute force implementation
def train_iter_amp(self, optims=None, scaler=None, amp_level='O1'):
optims['optim'].clear_grad()
if self.fix_iter:
Expand Down Expand Up @@ -131,9 +131,9 @@ def train_iter_amp(self, optims=None, scaler=None, amp_level='O1'):

self.loss = sum(_value for _key, _value in self.losses.items()
if 'loss_pix' in _key)
-        scaled_loss = scaler.scale(self.loss)
-        self.losses['loss'] = self.loss
+        self.losses['loss'] = self.loss
+
+        scaled_loss = scaler.scale(self.loss)
scaled_loss.backward()
scaler.minimize(optims['optim'], scaled_loss)

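The msvsr change above reorders the same step: the unscaled loss is stored for logging before the scaled copy is created, since only the backward pass should see the scaled value. Sketch with illustrative names:

import paddle

scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
losses = {'loss_pix': paddle.to_tensor(0.5)}

loss = sum(v for k, v in losses.items() if 'loss_pix' in k)
losses['loss'] = loss             # logged value stays unscaled

scaled_loss = scaler.scale(loss)  # created once, just before backward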
32 changes: 32 additions & 0 deletions test_tipc/configs/edvr/train_amp_infer_python.txt
@@ -0,0 +1,32 @@
===========================train_params===========================
model_name:edvr
python:python3.7
gpu_list:0
##
auto_cast:null
total_iters:lite_train_lite_infer=100
output_dir:./output/
dataset.train.batch_size:lite_train_lite_infer=4
pretrained_model:null
train_model_name:basicvsr_reds*/*checkpoint.pdparams
train_infer_img_dir:./data/basicvsr_reds/test
null:null
##
trainer:amp_train
amp_train:tools/main.py --amp --amp_level O2 -c configs/edvr_m_wo_tsa.yaml --seed 123 -o log_config.interval=5
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================train_benchmark_params==========================
batch_size:4|64
fp_items:fp32
total_iters:100
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_cudnn_exhaustive_search=1
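
The TIPC config above uses the flat key:value format consumed by the test_tipc shell scripts, which split each line on ':' (func_parser_key/func_parser_value). For illustration only, a hypothetical Python equivalent of that parsing:

def parse_tipc_config(path):
    """Toy parser: split on the first ':' so values containing ':' survive."""
    params = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            # skip blank lines, '##' spacers and '====...====' section banners
            if not line or line.startswith(('#', '=')):
                continue
            key, _, value = line.partition(':')
            params[key] = value
    return params

cfg = parse_tipc_config('test_tipc/configs/edvr/train_amp_infer_python.txt')
print(cfg['trainer'])    # amp_train
print(cfg['amp_train'])  # tools/main.py --amp --amp_level O2 ...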
6 changes: 3 additions & 3 deletions test_tipc/configs/msvsr/train_infer_python.txt
@@ -13,7 +13,7 @@ train_infer_img_dir:./data/msvsr_reds/test
null:null
##
trainer:norm_train
-norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o log_config.interval=2 snapshot_config.interval=50 dataset.train.dataset.num_frames=15
+norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o log_config.interval=1 snapshot_config.interval=5 dataset.train.dataset.num_frames=15
pact_train:null
fpgm_train:null
distill_train:null
@@ -50,10 +50,10 @@ null:null
--benchmark:True
null:null
===========================train_benchmark_params==========================
-batch_size:4
+batch_size:2|4
fp_items:fp32
total_iters:60
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
-flags:null
+flags:FLAGS_cudnn_exhaustive_search=1
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[2,3,180,320]}]
1 change: 1 addition & 0 deletions test_tipc/prepare.sh
@@ -31,6 +31,7 @@ model_name=$(func_parser_value "${lines[1]}")
trainer_list=$(func_parser_value "${lines[14]}")

if [ ${MODE} = "benchmark_train" ];then
pip install -r requirements.txt
MODE="lite_train_lite_infer"
fi

2 changes: 1 addition & 1 deletion test_tipc/test_train_inference_python.sh
@@ -271,7 +271,7 @@ else
set_export_weight="${save_log}/${train_model_name}"
set_export_weight_path=$( echo ${set_export_weight})
set_save_infer_key="${save_infer_key} ${save_infer_path}"
-        export_cmd="${python} ${run_export} ${set_export_weight_path} ${set_save_infer_key}"
+        export_cmd="${python} ${run_export} ${set_export_weight_path} ${set_save_infer_key} > ${save_log}_export.log 2>&1"
eval "$export_cmd"
status_check $? "${export_cmd}" "${status_log}"

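The export command now captures its output: the appended > ${save_log}_export.log 2>&1 redirects both stdout and stderr into a per-run log file. A rough Python equivalent of that shell redirection, with hypothetical names:

import subprocess

# mirrors `cmd > export.log 2>&1`: stderr is merged into stdout, both go to the file
with open('export.log', 'w') as log:
    subprocess.run(['echo', 'export done'], stdout=log, stderr=subprocess.STDOUT)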
