Add static model and inference of stylegan2, fom, basicvsr (PaddlePaddle#491)

* Update benchmark.yaml

* Update benchmark.yaml

* add static model and inference of fom, basicvsr, stylegan2

* add static model and inference of fom, basicvsr, stylegan2

* fix basicvsr dataset for small datasets

* fix basicvsr dataset for small datasets

* fix basicvsr dataset for small datasets

* fix basicvsr dataset for small datasets
lzzyzlbb authored Nov 22, 2021
1 parent 6e3dad3 commit 2ab96cb
Showing 11 changed files with 271 additions and 143 deletions.
2 changes: 1 addition & 1 deletion benchmark/benchmark.yaml
@@ -12,7 +12,7 @@ FOMM:
fp_item: fp32
bs_item: 8 16
epochs: 1
log_interval: 11
log_interval: 1

esrgan:
dataset_web: https://paddlegan.bj.bcebos.com/datasets/DIV2KandSet14.tar
4 changes: 4 additions & 0 deletions configs/basicvsr_reds.yaml
@@ -38,6 +38,7 @@ dataset:
use_rot: True
scale: 4
val_partition: REDS4
num_clips: 270

test:
name: SRREDSMultipleGTDataset
@@ -90,3 +91,6 @@ log_config:

snapshot_config:
interval: 5000

export_model:
- {name: 'generator', inputs_num: 1}
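Note on the new export_model entries: each item names a sub-network to export for static-graph inference and says how many inputs its forward takes. Below is a minimal sketch of how such an entry could drive a paddle.jit.save call; the helper name, shapes, and dtype are illustrative assumptions, not the exact export code PaddleGAN runs from this config.

```python
import paddle
from paddle.static import InputSpec

# Hedged sketch: export one sub-network named in an `export_model` entry.
# `net` stands in for the trained generator; the shape and dtype below are
# illustrative, not the exact specs PaddleGAN builds from this config.
def export_one(net, name, inputs_num, output_dir="inference"):
    spec = [InputSpec(shape=[None, 3, None, None], dtype="float32")
            for _ in range(inputs_num)]
    paddle.jit.save(net, f"{output_dir}/{name}", input_spec=spec)
```

paddle.jit.save writes a `<name>.pdmodel` / `<name>.pdiparams` pair that can be reloaded with paddle.jit.load or consumed by the Paddle Inference API.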
4 changes: 4 additions & 0 deletions configs/cyclegan_horse2zebra.yaml
@@ -119,3 +119,7 @@ log_config:

snapshot_config:
interval: 5

export_model:
- {name: 'netG_A', inputs_num: 1}
- {name: 'netG_B', inputs_num: 1}
3 changes: 3 additions & 0 deletions configs/firstorder_vox_256.yaml
@@ -123,3 +123,6 @@ snapshot_config:

optimizer:
name: Adam

export_model:
- {}
3 changes: 3 additions & 0 deletions configs/pix2pix_facades.yaml
@@ -115,3 +115,6 @@ validate:
fid: # metric name, can be arbitrary
name: FID
batch_size: 8

export_model:
- {name: 'netG', inputs_num: 1}
10 changes: 7 additions & 3 deletions ppgan/datasets/sr_reds_multiple_gt_dataset.py
@@ -55,7 +55,8 @@ def __init__(self,
use_rot=False,
scale=4,
val_partition='REDS4',
batch_size=4):
batch_size=4,
num_clips=270):
super(SRREDSMultipleGTDataset, self).__init__()
self.mode = mode
self.fileroot = str(lq_folder)
@@ -69,6 +70,7 @@ def __init__(self,
self.scale = scale
self.val_partition = val_partition
self.batch_size = batch_size
self.num_clips = num_clips  # number of training clips (LQ/GT sequence pairs)
self.data_infos = self.load_annotations()

def __getitem__(self, idx):
@@ -93,7 +95,7 @@ def load_annotations(self):
dict: Returned dict for LQ and GT pairs.
"""
# generate keys
keys = [f'{i:03d}' for i in range(0, 270)]
keys = [f'{i:03d}' for i in range(0, self.num_clips)]

if self.val_partition == 'REDS4':
val_partition = ['000', '011', '015', '020']
@@ -170,7 +172,9 @@ def get_sample_data(self,
gt_list = rlt[number_frames:]

# stack LQ images to NHWC, N is the frame number
frame_list = [v.transpose(2, 0, 1).astype('float32') for v in frame_list]
frame_list = [
v.transpose(2, 0, 1).astype('float32') for v in frame_list
]
gt_list = [v.transpose(2, 0, 1).astype('float32') for v in gt_list]

img_LQs = np.stack(frame_list, axis=0)
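The num_clips plumbing above replaces the hard-coded 270 REDS clips, which is what lets a trimmed-down dataset train without referencing missing sequence folders. A small illustration of the resulting key generation, assuming the usual REDS4 handling where training excludes the four validation clips:

```python
# Illustration only: clip keys for a small dataset with 6 sequence folders.
num_clips = 6
val_partition = ['000', '011', '015', '020']  # REDS4 validation clips

keys = [f'{i:03d}' for i in range(0, num_clips)]
train_keys = [k for k in keys if k not in val_partition]

print(keys)        # ['000', '001', '002', '003', '004', '005']
print(train_keys)  # ['001', '002', '003', '004', '005']
```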
17 changes: 10 additions & 7 deletions ppgan/models/generators/basicvsr.py
@@ -160,7 +160,9 @@ def flow_warp(x,
Returns:
Tensor: Warped image or feature map.
"""
if x.shape[-2:] != flow.shape[1:3]:
x_h, x_w = x.shape[-2:]
flow_h, flow_w = flow.shape[1:3]
if x_h != flow_h or x_w != flow_w:
raise ValueError(f'The spatial sizes of input ({x.shape[-2:]}) and '
f'flow ({flow.shape[1:3]}) are not the same.')
_, _, h, w = x.shape
@@ -293,7 +295,7 @@ def compute_flow(self, ref, supp):
supp = supp[::-1]

# flow computation
flow = paddle.to_tensor(np.zeros([n, 2, h // 32, w // 32], 'float32'))
flow = paddle.zeros([n, 2, h // 32, w // 32])

# level=0
flow_up = flow
@@ -555,6 +557,7 @@ def forward(self, lrs):
"""

n, t, c, h, w = lrs.shape
t = paddle.to_tensor(t)
assert h >= 64 and w >= 64, (
'The height and width of inputs should be at least 64, '
f'but got {h} and {w}.')
@@ -567,19 +570,18 @@ def forward(self, lrs):

# backward-time propagation
outputs = []
feat_prop = paddle.to_tensor(
np.zeros([n, self.mid_channels, h, w], 'float32'))
feat_prop = paddle.zeros([n, self.mid_channels, h, w])
for i in range(t - 1, -1, -1):
if i < t - 1: # no warping required for the last timestep
flow = flows_backward[:, i, :, :, :]
feat_prop = flow_warp(feat_prop, flow.transpose([0, 2, 3, 1]))
flow1 = flows_backward[:, i, :, :, :]
feat_prop = flow_warp(feat_prop, flow1.transpose([0, 2, 3, 1]))

feat_prop = paddle.concat([lrs[:, i, :, :, :], feat_prop], axis=1)
feat_prop = self.backward_resblocks(feat_prop)

outputs.append(feat_prop)
outputs = outputs[::-1]

# forward-time propagation and upsampling
feat_prop = paddle.zeros_like(feat_prop)
for i in range(0, t):
@@ -610,6 +612,7 @@ def forward(self, lrs):

class SecondOrderDeformableAlignment(nn.Layer):
"""Second-order deformable alignment module.
Args:
in_channels (int): Same as nn.Conv2d.
out_channels (int): Same as nn.Conv2d.
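Several of the BasicVSR edits above (paddle.zeros in place of paddle.to_tensor(np.zeros(...)), the unpacked height/width comparison in flow_warp) keep the forward pass traceable when it is converted to a static graph for export. A minimal sketch of the underlying pattern, under the assumption that input dimensions may be unknown (None) at export time; the layer and output path are illustrative only:

```python
import paddle
import paddle.nn as nn
from paddle.static import InputSpec

class ZeroBuffer(nn.Layer):
    # Illustrative layer: allocate a buffer whose shape follows the input,
    # the way BasicVSR allocates feat_prop and the initial flow.
    def forward(self, x):
        n, c, h, w = x.shape
        # paddle.zeros accepts shape entries that turn into tensors after
        # dynamic-to-static conversion; a numpy allocation would need
        # concrete Python ints and break for InputSpec dims left as None.
        buf = paddle.zeros([n, c, h, w])
        return x + buf

# Export with dynamic batch and spatial dimensions (illustrative path).
paddle.jit.save(ZeroBuffer(), "inference/zero_buffer",
                input_spec=[InputSpec(shape=[None, 3, None, None],
                                      dtype="float32")])
```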
62 changes: 38 additions & 24 deletions ppgan/models/generators/generator_styleganv2.py
@@ -32,9 +32,9 @@ class PixelNorm(nn.Layer):
def __init__(self):
super().__init__()

def forward(self, input):
return input * paddle.rsqrt(
paddle.mean(input * input, 1, keepdim=True) + 1e-8)
def forward(self, inputs):
return inputs * paddle.rsqrt(
paddle.mean(inputs * inputs, 1, keepdim=True) + 1e-8)


class ModulatedConv2D(nn.Layer):
@@ -93,8 +93,8 @@ def __repr__(self):
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
f"upsample={self.upsample}, downsample={self.downsample})")

def forward(self, input, style):
batch, in_channel, height, width = input.shape
def forward(self, inputs, style):
batch, in_channel, height, width = inputs.shape

style = self.modulation(style).reshape((batch, 1, in_channel, 1, 1))
weight = self.scale * self.weight * style
@@ -107,13 +107,13 @@ def forward(self, input, style):
self.kernel_size, self.kernel_size))

if self.upsample:
input = input.reshape((1, batch * in_channel, height, width))
inputs = inputs.reshape((1, batch * in_channel, height, width))
weight = weight.reshape((batch, self.out_channel, in_channel,
self.kernel_size, self.kernel_size))
weight = weight.transpose((0, 2, 1, 3, 4)).reshape(
(batch * in_channel, self.out_channel, self.kernel_size,
self.kernel_size))
out = F.conv2d_transpose(input,
out = F.conv2d_transpose(inputs,
weight,
padding=0,
stride=2,
@@ -123,16 +123,16 @@ def forward(self, input, style):
out = self.blur(out)

elif self.downsample:
input = self.blur(input)
_, _, height, width = input.shape
input = input.reshape((1, batch * in_channel, height, width))
out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
inputs = self.blur(inputs)
_, _, height, width = inputs.shape
inputs = inputs.reshape((1, batch * in_channel, height, width))
out = F.conv2d(inputs, weight, padding=0, stride=2, groups=batch)
_, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width))

else:
input = input.reshape((1, batch * in_channel, height, width))
out = F.conv2d(input, weight, padding=self.padding, groups=batch)
inputs = inputs.reshape((1, batch * in_channel, height, width))
out = F.conv2d(inputs, weight, padding=self.padding, groups=batch)
_, _, height, width = out.shape
out = out.reshape((batch, self.out_channel, height, width))

@@ -165,8 +165,8 @@ def __init__(self, channel, size=4):
(1, channel, size, size),
default_initializer=nn.initializer.Normal())

def forward(self, input):
batch = input.shape[0]
def forward(self, inputs):
batch = inputs.shape[0]
out = self.input.tile((batch, 1, 1, 1))

return out
@@ -198,8 +198,8 @@ def __init__(self,
self.activate = FusedLeakyReLU(out_channel *
2 if is_concat else out_channel)

def forward(self, input, style, noise=None):
out = self.conv(input, style)
def forward(self, inputs, style, noise=None):
out = self.conv(inputs, style)
out = self.noise(out, noise=noise)
out = self.activate(out)

@@ -225,8 +225,8 @@ def __init__(self,
self.bias = self.create_parameter((1, 3, 1, 1),
nn.initializer.Constant(0.0))

def forward(self, input, style, skip=None):
out = self.conv(input, style)
def forward(self, inputs, style, skip=None):
out = self.conv(inputs, style)
out = out + self.bias

if skip is not None:
@@ -349,15 +349,28 @@ def mean_latent(self, n_latent):

return latent

def get_latent(self, input):
return self.style(input)
def get_latent(self, inputs):
return self.style(inputs)

def get_mean_style(self):
mean_style = None
with paddle.no_grad():
for i in range(10):
style = self.mean_latent(1024)
if mean_style is None:
mean_style = style
else:
mean_style += style

mean_style /= 10
return mean_style

def forward(
self,
styles,
return_latents=False,
inject_index=None,
truncation=1,
truncation=1.0,
truncation_latent=None,
input_is_latent=False,
noise=None,
@@ -375,9 +388,10 @@ def forward(
for i in range(self.num_layers)
]

if truncation < 1:
if truncation < 1.0:
style_t = []

if truncation_latent is None:
truncation_latent = self.get_mean_style()
for style in styles:
style_t.append(truncation_latent + truncation *
(style - truncation_latent))
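For context on get_mean_style and the truncation handling above: with truncation < 1.0 each style vector is pulled toward a mean latent, trading sample diversity for fidelity. The arithmetic, sketched with stand-in tensors (512 is assumed as the style dimension, matching the [1, 1, 512] export shape used in styleganv2_model.py below):

```python
import paddle

# Stand-ins for generator.get_mean_style() and generator.get_latent(z).
mean_style = paddle.randn([1, 512])
w = paddle.randn([1, 512])

truncation = 0.7  # psi < 1.0 moves w toward the mean latent
w_trunc = mean_style + truncation * (w - mean_style)
```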
25 changes: 24 additions & 1 deletion ppgan/models/styleganv2_model.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import random
import paddle
@@ -24,6 +25,7 @@
from ..solver import build_lr_scheduler, build_optimizer



def r1_penalty(real_pred, real_img):
"""
R1 regularization for discriminator. The core idea is to
@@ -195,7 +197,6 @@ def make_noise(self, batch, num_noise):
noises = []
for _ in range(num_noise):
noises.append(paddle.randn([batch, self.num_style_feat]))

return noises

def mixing_noise(self, batch, prob):
@@ -294,3 +295,25 @@ def test_iter(self, metrics=None):
metric.update(fake_img, self.real_img)
self.nets['gen_ema'].train()

class InferGenerator(paddle.nn.Layer):
def set_generator(self, generator):
self.generator = generator

def forward(self, style, truncation):
truncation_latent = self.generator.get_mean_style()
out = self.generator(styles=style,
truncation=truncation,
truncation_latent=truncation_latent)
return out[0]

def export_model(self,
export_model=None,
output_dir=None,
inputs_size=[[1, 1, 512], [1, 1]]):
infer_generator = self.InferGenerator()
infer_generator.set_generator(self.nets['gen'])
style = paddle.rand(shape=inputs_size[0], dtype='float32')
truncation = paddle.rand(shape=inputs_size[1], dtype='float32')
paddle.jit.save(infer_generator,
os.path.join(output_dir, "stylegan2model_gen"),
input_spec=[style, truncation])
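A hedged usage sketch for the exported generator: the paddle.jit.save call above writes stylegan2model_gen.pdmodel and stylegan2model_gen.pdiparams under output_dir, and the reloaded static graph takes the same two inputs (a style tensor and a truncation value). The path prefix and the 0.7 truncation are example values:

```python
import paddle

# Path prefix must match the save call above (output_dir is a placeholder).
generator = paddle.jit.load("output_dir/stylegan2model_gen")

style = paddle.rand([1, 1, 512], dtype="float32")
truncation = paddle.full([1, 1], 0.7, dtype="float32")

fake_img = generator(style, truncation)  # out[0] of the original forward
print(fake_img.shape)
```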