Add training of first order motion #256
@@ -0,0 +1,250 @@
import logging
add copyright
ppgan/datasets/firstorder_dataset.py
import tqdm
from imageio import imread, mimread, imwrite
from paddle.io import Dataset
from sklearn.model_selection import train_test_split
There's no need to pull in the sklearn package here, is there?
If users train on their own data, the data needs to be randomly split into train and test sets, doesn't it?
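A random train/test split of video names can be done with the standard library alone, which would let the dataset drop the sklearn dependency. This is a minimal sketch, assuming the dataset is a list of video names. `split_videos` and its `test_size`/`seed` parameters are hypothetical and only loosely mirror `train_test_split`.

```python
import random
from typing import List, Sequence, Tuple


def split_videos(videos: Sequence[str], test_size: float = 0.2,
                 seed: int = 0) -> Tuple[List[str], List[str]]:
    """Randomly split video names into (train, test) lists.

    Hypothetical stand-in for sklearn's train_test_split.
    """
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be in (0, 1)")
    shuffled = list(videos)
    # A dedicated Random instance keeps the split reproducible
    # without touching the global random state.
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, int(round(len(shuffled) * test_size)))
    return shuffled[n_test:], shuffled[:n_test]


train, test = split_videos([f"video_{i}" for i in range(10)], test_size=0.2)
print(len(train), len(test))  # 8 2
```

Because the seed is fixed, rerunning training on the same folder yields the same split.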
ppgan/datasets/firstorder_dataset.py
from paddle.io import Dataset
from sklearn.model_selection import train_test_split
from skimage.transform import resize
from skimage import io, img_as_float32
Is skimage strictly required? If so, it needs to be added as a dependency in https://github.com/PaddlePaddle/PaddleGAN/blob/develop/requirements.txt
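Two of the skimage helpers used in this PR (`img_as_float32` here, and `skimage.draw.circle` in the model file) have short numpy equivalents. This is a sketch under stated assumptions: `to_float32` only handles uint8 input, and `disk_mask` returns a boolean mask rather than the `(rr, cc)` coordinate arrays that skimage's drawing functions return. Both function names are hypothetical.

```python
import numpy as np


def to_float32(image: np.ndarray) -> np.ndarray:
    """Convert a uint8 image to float32 in [0, 1].

    Matches img_as_float32 for uint8 input; other dtypes are rejected
    in this sketch instead of being rescaled.
    """
    if image.dtype != np.uint8:
        raise TypeError("this sketch only handles uint8 images")
    return image.astype(np.float32) / 255.0


def disk_mask(center, radius, shape) -> np.ndarray:
    """Boolean mask of pixels strictly inside a circle of `radius`."""
    rows, cols = np.ogrid[:shape[0], :shape[1]]
    r0, c0 = center
    return (rows - r0) ** 2 + (cols - c0) ** 2 < radius ** 2


mask = disk_mask((2, 2), 1.5, (5, 5))
rr, cc = np.nonzero(mask)  # coordinates, as skimage would return them
```

For `resize`, the frames could go through `paddle.nn.functional.interpolate` or an image library that ppgan already depends on. That choice is left open here.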

@DATASETS.register()
class FirstOrderDataset(Dataset):
    def __init__(self, **cfg):
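add comments. Which fields does cfg contain? I think spelling them out as explicit arguments would be better; otherwise it's not clear which parameters need to be passed in.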
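Isn't the usual rule that a function shouldn't take more than 6 input arguments? There are 9 here, so wouldn't spelling them out be a bad idea?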
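One way to satisfy both comments is to keep a single config argument but give it explicit, documented fields with a dataclass. With this approach, the accepted keys are visible and a typo fails loudly. This is only a sketch. The field names below are hypothetical, loosely modeled on the arguments of the original first-order-motion `FramesDataset`, and they are not the nine keys this PR actually reads.

```python
from dataclasses import dataclass, field, fields
from typing import Optional, Tuple


@dataclass
class FirstOrderDatasetConfig:
    # Hypothetical fields; replace them with the keys the dataset really uses.
    dataroot: str
    phase: str = "train"
    frame_shape: Tuple[int, int, int] = (256, 256, 3)
    id_sampling: bool = False
    random_seed: int = 0
    pairs_list: Optional[str] = None
    augmentation_params: dict = field(default_factory=dict)

    @classmethod
    def from_dict(cls, cfg: dict) -> "FirstOrderDatasetConfig":
        """Build from a YAML-style dict, rejecting unknown keys."""
        known = {f.name for f in fields(cls)}
        unknown = set(cfg) - known
        if unknown:
            raise KeyError(f"unknown config keys: {sorted(unknown)}")
        return cls(**cfg)


cfg = FirstOrderDatasetConfig.from_dict({"dataroot": "data/fashion", "phase": "test"})
print(cfg.phase, cfg.frame_shape)  # test (256, 256, 3)
```

`__init__(self, **cfg)` could then start with `self.cfg = FirstOrderDatasetConfig.from_dict(cfg)`. This keeps the registry-style call site unchanged.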
ppgan/datasets/firstorder_dataset.py
        out['name'] = video_name
        return out

    def getSample(self, idx):
getSample -> get_sample
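After the rename, a docstring could say what `get_sample` returns. In the reference first-order-motion implementation, a training item is a pair of frames from the same video: a source frame and a driving frame, with indices drawn with replacement and sorted. This is a minimal sketch of that sampling step, assuming the frames are already loaded into a sequence. `sample_frame_pair` is hypothetical, and this `get_sample` is not the PR's code.

```python
import random
from typing import Sequence, Tuple


def sample_frame_pair(num_frames: int, rng: random.Random) -> Tuple[int, int]:
    """Pick two sorted frame indices from one video.

    Sampling is with replacement, so source and driving may coincide.
    """
    if num_frames < 1:
        raise ValueError("video has no frames")
    a, b = rng.randrange(num_frames), rng.randrange(num_frames)
    return (a, b) if a <= b else (b, a)


def get_sample(frames: Sequence, rng: random.Random) -> dict:
    """Return one training item: a source frame and a driving frame."""
    src, drv = sample_frame_pair(len(frames), rng)
    return {"source": frames[src], "driving": frames[drv]}
```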

class DownBlock2d(nn.Layer):
    """
    Simple block for processing video (encoder).
add comments for input arguments
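One possible docstring layout for the input arguments is shown below on a small hypothetical helper that computes a DownBlock2d-style output shape. The helper assumes the block is a stride-1 convolution followed by 2x2 average pooling, which is how the original first-order-motion hourglass encoder builds it (`kernel_size=3, padding=1` by default). Check these defaults against the PR before copying them.

```python
from typing import Tuple


def down_block_output_shape(in_shape: Tuple[int, int, int, int],
                            out_features: int,
                            kernel_size: int = 3,
                            padding: int = 1) -> Tuple[int, int, int, int]:
    """Output shape of a DownBlock2d-style encoder block.

    Args:
        in_shape: input shape as (N, C_in, H, W).
        out_features: number of output channels of the convolution.
        kernel_size: convolution kernel size (stride 1 assumed).
        padding: zero padding on each side of the convolution.

    Returns:
        (N, out_features, H_out, W_out) after the convolution and a
        2x2 average pool with stride 2.
    """
    n, _, h, w = in_shape
    # Stride-1 convolution: H' = H + 2 * padding - kernel_size + 1.
    h = h + 2 * padding - kernel_size + 1
    w = w + 2 * padding - kernel_size + 1
    # 2x2 average pooling with stride 2 floors odd sizes.
    return n, out_features, h // 2, w // 2


print(down_block_output_shape((1, 3, 256, 256), 64))  # (1, 64, 128, 128)
```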
        return out


class Discriminator(nn.Layer):
add comments
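Comments on `Discriminator` could also say how its output is consumed. In the reference first-order-motion training code, the discriminator's prediction maps are used with least-squares GAN objectives. Below is a numpy sketch of those two losses with hypothetical function names. The actual model computes them on paddle tensors, once per scale, and weights each term.

```python
import numpy as np


def lsgan_generator_loss(pred_fake: np.ndarray) -> float:
    """Generator wants the discriminator to output 1 on generated frames."""
    return float(((1.0 - pred_fake) ** 2).mean())


def lsgan_discriminator_loss(pred_real: np.ndarray,
                             pred_fake: np.ndarray) -> float:
    """Discriminator pushes real predictions toward 1 and fake ones toward 0."""
    return float(((1.0 - pred_real) ** 2 + pred_fake ** 2).mean())
```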
@@ -0,0 +1,282 @@
import paddle
add copyright
ppgan/models/firstorder_model.py
import paddle.nn as nn
import numpy as np
from skimage.draw import circle
import matplotlib.pyplot as plt
matplotlib should be lazily imported; otherwise all of ppgan will have to depend on matplotlib.
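A lazy import moves the `import` statement into the function that needs it. That way, `import ppgan` never touches matplotlib. This sketch adds a small hypothetical helper, `lazy_import`, that also turns a missing package into a clear error message. The visualizer function is likewise hypothetical and uses only standard matplotlib calls.

```python
import importlib
from types import ModuleType


def lazy_import(name: str, hint: str = "") -> ModuleType:
    """Import a module only when it is first needed."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        msg = f"{name} is required for this feature"
        if hint:
            msg += f"; install it with `{hint}`"
        raise ImportError(msg) from exc


def save_keypoint_plot(path, image, keypoints):
    """Hypothetical visualizer: matplotlib is imported only when called."""
    plt = lazy_import("matplotlib.pyplot", "pip install matplotlib")
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.scatter(keypoints[:, 0], keypoints[:, 1], s=10)
    fig.savefig(path)
    plt.close(fig)
```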
        return loss_values, generated


class Vgg19(nn.Layer):
Vgg19 -> VGG19
can this reuse criterion/perceptual_loss.py?
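Whether reuse works depends on what criterion/perceptual_loss.py exposes. As far as the reference first-order-motion implementation goes, its perceptual loss loops over every image-pyramid scale and every chosen VGG19 layer. For each pair it adds the layer's weight times the mean absolute feature difference. Below is a framework-agnostic numpy sketch. Here `features` stands in for a VGG19 feature extractor that returns a list of maps; it is a hypothetical callable, and the existing module's API is not assumed.

```python
from typing import Callable, Dict, List, Sequence

import numpy as np


def pyramid_perceptual_loss(
        generated: Dict[float, np.ndarray],
        real: Dict[float, np.ndarray],
        features: Callable[[np.ndarray], List[np.ndarray]],
        layer_weights: Sequence[float]) -> float:
    """Weighted L1 distance between feature maps at every pyramid scale."""
    total = 0.0
    for scale in generated:
        fx = features(generated[scale])
        fy = features(real[scale])
        # One weighted term per feature layer.
        for w, a, b in zip(layer_weights, fx, fy):
            total += w * float(np.abs(a - b).mean())
    return total
```

If the existing perceptual loss already returns per-layer features, only the outer loop over scales would need to live in the first-order model.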
            kernel_size=kernel_size)
        if sn:
            self.conv = spectral_norm(self.conv)
        #if sn:
delete
ppgan/modules/first_order.py
@@ -285,6 +319,19 @@ def forward(self, input):

        out = F.pad(input, [self.ka, self.kb, self.ka, self.kb])
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        out = F.interpolate(out, scale_factor=[self.scale, self.scale])
        out.stop_gradient = False
        # Newer versions of PyTorch have a bug that affects the convergence of this model
clean the code
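The forward pass in the diff pads by `ka`/`kb`, blurs with a fixed Gaussian kernel, and then interpolates. The diff does not show how that kernel is built. Assuming it follows the reference first-order-motion `AntiAliasInterpolation2d`, the construction can be sketched in numpy as below. This is a reconstruction for illustration, not the PR's code, and `anti_alias_kernel` is a hypothetical name.

```python
import numpy as np


def anti_alias_kernel(scale: float):
    """Gaussian blur kernel and padding used before downsampling by `scale`.

    The blur widens as the scale shrinks: sigma = (1 / scale - 1) / 2.
    """
    if not 0.0 < scale < 1.0:
        raise ValueError("scale must be in (0, 1) for downsampling")
    sigma = (1.0 / scale - 1.0) / 2.0
    kernel_size = 2 * round(sigma * 4) + 1  # always odd
    ka = kernel_size // 2
    # Kept for parity with the reference; equals ka for odd sizes.
    kb = ka - 1 if kernel_size % 2 == 0 else ka
    coords = np.arange(kernel_size, dtype=np.float64)
    mean = (kernel_size - 1) / 2.0
    g = np.exp(-((coords - mean) ** 2) / (2.0 * sigma ** 2))
    kernel = np.outer(g, g)  # separable 2D Gaussian
    return kernel / kernel.sum(), ka, kb


kernel, ka, kb = anti_alias_kernel(0.25)
print(kernel.shape, ka, kb)  # (13, 13) 6 6
```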
… rid of sklearn, skimage
Fix #253