Skip to content

Commit

Permalink
fix tipc (PaddlePaddle#517)
Browse files Browse the repository at this point in the history
* fix tipc

* add cyclegan eval
  • Loading branch information
lzzyzlbb authored Dec 8, 2021
1 parent 79a6419 commit ad631f4
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 9 deletions.
11 changes: 8 additions & 3 deletions configs/cyclegan_horse2zebra.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ log_config:
snapshot_config:
interval: 5

export_model:
- {name: 'netG_A', inputs_num: 1}
- {name: 'netG_B', inputs_num: 1}
validate:
interval: 30000
save_img: false
metrics:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8

10 changes: 10 additions & 0 deletions ppgan/models/cycle_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,13 @@ def train_iter(self, optimizers=None):
self.backward_D_B()
# update D_A and D_B's weights
optimizers['optimD'].step()


def test_iter(self, metrics=None):
    """Run one evaluation forward pass and update the given metrics.

    The generators are put in eval mode for the duration of the step and are
    switched back to train mode afterwards, even if the step raises.

    Args:
        metrics (dict | None): container whose values are metric objects;
            each value's ``update(preds, targets)`` is called with
            ``self.fake_B`` and ``self.real_B``. When ``None``, the forward
            pass still runs but no metric is updated.
    """
    # netG_A is required (its output fake_B feeds the metrics). netG_B is
    # toggled too when present because forward() presumably also runs it for
    # the reverse/cycle branch -- TODO confirm against forward().
    generators = [self.nets['netG_A']]
    if 'netG_B' in self.nets:
        generators.append(self.nets['netG_B'])

    for net in generators:
        net.eval()
    try:
        # forward() must run inside no_grad as well: calling it outside the
        # context records an autograd graph that evaluation never uses.
        with paddle.no_grad():
            self.forward()
            if metrics is not None:
                for metric in metrics.values():
                    metric.update(self.fake_B, self.real_B)
    finally:
        # Always restore training mode so a failed evaluation step does not
        # leave the generators in eval mode for subsequent training iterations.
        for net in generators:
            net.train()
8 changes: 4 additions & 4 deletions test_tipc/configs/msvsr/train_infer_python.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ train_infer_img_dir:./data/msvsr_reds/test
null:null
##
trainer:norm_train
norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o dataset.train.dataset.num_clips=2 dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5
norm_train:tools/main.py -c configs/msvsr_reds.yaml --seed 123 -o dataset.train.dataset.num_clips=2 dataset.train.num_workers=0 log_config.interval=1 snapshot_config.interval=5 dataset.train.dataset.number_frames=2
pact_train:null
fpgm_train:null
distill_train:null
Expand All @@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
--output_dir:./output/
load:null
norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,4,3,180,320" --load
norm_export:tools/export_model.py -c configs/msvsr_reds.yaml --inputs_size="1,2,3,180,320" --load
quant_export:null
fpgm_export:null
distill_export:null
Expand All @@ -37,7 +37,7 @@ inference_dir:multistagevsrmodel_generator
train_model:./inference/msvsr/multistagevsrmodel_generator
infer_export:null
infer_quant:False
inference:tools/inference.py --model_type msvsr -c configs/msvsr_reds.yaml --seed 123 -o dataset.test.num_clips=2 dataset.test.number_frames=4 --output_path test_tipc/output/
inference:tools/inference.py --model_type msvsr -c configs/msvsr_reds.yaml --seed 123 -o dataset.test.num_clips=2 dataset.test.number_frames=2 --output_path test_tipc/output/
--device:gpu
null:null
null:null
Expand All @@ -48,4 +48,4 @@ null:null
null:null
null:null
--benchmark:True
null:null
null:null
4 changes: 2 additions & 2 deletions test_tipc/results/python_msvsr_results_fp32.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Metric psnr: 27.3670
Metric ssim: 0.8021
Metric psnr: 24.3250
Metric ssim: 0.6497

0 comments on commit ad631f4

Please sign in to comment.