Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add training of first order motion #256

Merged
merged 6 commits into from
Apr 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions applications/tools/first-order-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@
type=str,
default='sfd',
help="face detector to be used, can choose s3fd or blazeface")
parser.add_argument("--multi_person",
dest="multi_person",
action="store_true",
default=False,
help="whether there is only one person in the image or not")

parser.set_defaults(relative=False)
parser.set_defaults(adapt_scale=False)
Expand All @@ -72,7 +77,6 @@

if args.cpu:
paddle.set_device('cpu')

predictor = FirstOrderPredictor(output=args.output,
filename=args.filename,
weight_path=args.weight_path,
Expand All @@ -82,5 +86,6 @@
find_best_frame=args.find_best_frame,
best_frame=args.best_frame,
ratio=args.ratio,
face_detector=args.face_detector)
face_detector=args.face_detector,
multi_person=args.multi_person)
predictor.run(args.source_image, args.driving_video)
106 changes: 106 additions & 0 deletions configs/firstorder_fashion.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
epochs: 150
output_dir: output_dir

model:
name: FirstOrderModel
common_params:
num_kp: 10
num_channels: 3
estimate_jacobian: True
generator:
name: FirstOrderGenerator
kp_detector_cfg:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25
num_blocks: 5
generator_cfg:
block_expansion: 64
max_features: 512
num_down_blocks: 2
num_bottleneck_blocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 64
max_features: 1024
num_blocks: 5
scale_factor: 0.25
discriminator:
name: FirstOrderDiscriminator
discriminator_cfg:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
train_params:
scales: [1, 0.5, 0.25, 0.125]
transform_params:
sigma_affine: 0.05
sigma_tps: 0.005
points_tps: 5
loss_weights:
generator_gan: 1
discriminator_gan: 1
feature_matching: [10, 10, 10, 10]
perceptual: [10, 10, 10, 10, 10]
equivariance_value: 10
equivariance_jacobian: 10

optimizer:
name: Adam

lr_scheduler:
epoch_milestones: [187500, 281250]
lr_generator: 2.0e-4
lr_discriminator: 2.0e-4
lr_kp_detector: 2.0e-4

dataset:
train:
name: FirstOrderDataset
phase: train
dataroot: data/first_order/fashion/
num_repeats: 50
time_flip: True
batch_size: 8
id_sampling: False
frame_shape: [ 256, 256, 3 ]
process_time: False
create_frames_folder: False
num_workers: 4
max_dataset_size: inf
direction: BtoA
input_nc: 3
output_nc: 3
transforms:
- name: PairedRandomHorizontalFlip
prob: 0.5
keys: [image, image]
- name: PairedColorJitter
brightness: 0.1
contrast: 0.1
saturation: 0.1
hue: 0.1
keys: [image, image]
test:
name: FirstOrderDataset
dataroot: data/first_order/fashion/
phase: test
batch_size: 1
num_workers: 1
time_flip: False
id_sampling: False
create_frames_folder: False
frame_shape: [ 256, 256, 3 ]

log_config:
interval: 10
visiual_interval: 10

snapshot_config:
interval: 10

validate:
interval: 31250
163 changes: 114 additions & 49 deletions ppgan/apps/first_order_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import imageio
import numpy as np
from tqdm import tqdm
from skimage import img_as_ubyte
from skimage.transform import resize
from scipy.spatial import ConvexHull

import paddle
Expand All @@ -47,37 +45,41 @@ def __init__(self,
best_frame=None,
ratio=1.0,
filename='result.mp4',
face_detector='sfd'):
face_detector='sfd',
multi_person=False):
if config is not None and isinstance(config, str):
self.cfg = yaml.load(config, Loader=yaml.SafeLoader)
with open(config) as f:
self.cfg = yaml.load(f, Loader=yaml.SafeLoader)
elif isinstance(config, dict):
self.cfg = config
elif config is None:
self.cfg = {
'model_params': {
'model': {
'common_params': {
'num_kp': 10,
'num_channels': 3,
'estimate_jacobian': True
},
'kp_detector_params': {
'temperature': 0.1,
'block_expansion': 32,
'max_features': 1024,
'scale_factor': 0.25,
'num_blocks': 5
},
'generator_params': {
'block_expansion': 64,
'max_features': 512,
'num_down_blocks': 2,
'num_bottleneck_blocks': 6,
'estimate_occlusion_map': True,
'dense_motion_params': {
'block_expansion': 64,
'generator': {
'kp_detector_cfg': {
'temperature': 0.1,
'block_expansion': 32,
'max_features': 1024,
'num_blocks': 5,
'scale_factor': 0.25
'scale_factor': 0.25,
'num_blocks': 5
},
'generator_cfg': {
'block_expansion': 64,
'max_features': 512,
'num_down_blocks': 2,
'num_bottleneck_blocks': 6,
'estimate_occlusion_map': True,
'dense_motion_params': {
'block_expansion': 64,
'max_features': 1024,
'num_blocks': 5,
'scale_factor': 0.25
}
}
}
}
Expand All @@ -99,28 +101,10 @@ def __init__(self,
self.face_detector = face_detector
self.generator, self.kp_detector = self.load_checkpoints(
self.cfg, self.weight_path)
self.multi_person = multi_person

def run(self, source_image, driving_video):
source_image = imageio.imread(source_image)
bboxes = self.extract_bbox(source_image.copy())
reader = imageio.get_reader(driving_video)
fps = reader.get_meta_data()['fps']
driving_video = []
try:
for im in reader:
driving_video.append(im)
except RuntimeError:
pass
reader.close()

driving_video = [
resize(frame, (256, 256))[..., :3] for frame in driving_video
]
results = []
for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = resize(face_image, (256, 256))

def get_prediction(face_image):
if self.find_best_frame or self.best_frame is not None:
i = self.best_frame if self.best_frame is not None else self.find_best_frame_func(
source_image, driving_video)
Expand Down Expand Up @@ -152,15 +136,60 @@ def run(self, source_image, driving_video):
self.kp_detector,
relative=self.relative,
adapt_movement_scale=self.adapt_scale)
return predictions

source_image = imageio.imread(source_image)
reader = imageio.get_reader(driving_video)
fps = reader.get_meta_data()['fps']
driving_video = []
try:
for im in reader:
driving_video.append(im)
except RuntimeError:
pass
reader.close()

driving_video = [
cv2.resize(frame, (256, 256)) / 255.0 for frame in driving_video
]
results = []

# for single person
if not self.multi_person:
h, w, _ = source_image.shape
source_image = cv2.resize(source_image, (256, 256)) / 255.0
predictions = get_prediction(source_image)
imageio.mimsave(os.path.join(self.output, self.filename), [
cv2.resize((frame * 255.0).astype('uint8'), (h, w))
for frame in predictions
])
return

bboxes = self.extract_bbox(source_image.copy())
print(str(len(bboxes)) + " persons have been detected")
if len(bboxes) <= 1:
h, w, _ = source_image.shape
source_image = cv2.resize(source_image, (256, 256)) / 255.0
predictions = get_prediction(source_image)
imageio.mimsave(os.path.join(self.output, self.filename), [
cv2.resize((frame * 255.0).astype('uint8'), (h, w))
for frame in predictions
])
return

# for multi person
for rec in bboxes:
face_image = source_image.copy()[rec[1]:rec[3], rec[0]:rec[2]]
face_image = cv2.resize(face_image, (256, 256)) / 255.0
predictions = get_prediction(face_image)
results.append({'rec': rec, 'predict': predictions})

out_frame = []

for i in range(len(driving_video)):
frame = source_image.copy()
for result in results:
x1, y1, x2, y2 = result['rec']
x1, y1, x2, y2, _ = result['rec']
h = y2 - y1
w = x2 - x1
out = result['predict'][i] * 255.0
Expand All @@ -185,11 +214,12 @@ def run(self, source_image, driving_video):
def load_checkpoints(self, config, checkpoint_path):

generator = OcclusionAwareGenerator(
**config['model_params']['generator_params'],
**config['model_params']['common_params'])
**config['model']['generator']['generator_cfg'],
**config['model']['common_params'])

kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
**config['model_params']['common_params'])
kp_detector = KPDetector(
**config['model']['generator']['kp_detector_cfg'],
**config['model']['common_params'])

checkpoint = paddle.load(self.weight_path)
generator.set_state_dict(checkpoint['generator'])
Expand Down Expand Up @@ -269,7 +299,11 @@ def extract_bbox(self, image):

frame = [image]
predictions = detector.get_detections_for_image(np.array(frame))
person_num = len(predictions)
if person_num == 0:
return np.array([])
results = []
face_boxs = []
h, w, _ = image.shape
for rect in predictions:
bh = rect[3] - rect[1]
Expand All @@ -281,6 +315,37 @@ def extract_bbox(self, image):
x1 = max(0, cx - int(0.8 * margin))
y2 = min(h, cy + margin)
x2 = min(w, cx + int(0.8 * margin))
results.append([x1, y1, x2, y2])
boxes = np.array(results)
area = (y2 - y1) * (x2 - x1)
results.append([x1, y1, x2, y2, area])
# if a person has more than one bbox, keep the largest one
# maybe greedy will be better?
sorted(results, key=lambda area: area[4], reverse=True)
results_box = [results[0]]
for i in range(1, person_num):
num = len(results_box)
add_person = True
for j in range(num):
pre_person = results_box[j]
iou = self.IOU(pre_person[0], pre_person[1], pre_person[2],
pre_person[3], pre_person[4], results[i][0],
results[i][1], results[i][2], results[i][3],
results[i][4])
if iou > 0.5:
add_person = False
break
if add_person:
results_box.append(results[i])
boxes = np.array(results_box)
return boxes

def IOU(self, ax1, ay1, ax2, ay2, sa, bx1, by1, bx2, by2, sb):
    """Return the intersection-over-union of two axis-aligned boxes.

    Box A is given by corners ``(ax1, ay1)``-``(ax2, ay2)`` and box B by
    ``(bx1, by1)``-``(bx2, by2)``. The areas are not recomputed here; the
    caller passes them in as ``sa`` and ``sb``.

    Args:
        ax1, ay1, ax2, ay2: corner coordinates of box A.
        sa: area of box A as computed by the caller.
        bx1, by1, bx2, by2: corner coordinates of box B.
        sb: area of box B as computed by the caller.

    Returns:
        float: intersection area divided by union area, or ``0.0`` when
        the boxes do not overlap or the union area is not positive.
    """
    # Extent of the overlap rectangle along each axis; a negative extent
    # means the boxes are disjoint on that axis.
    inter_w = min(ax2, bx2) - max(ax1, bx1)
    inter_h = min(ay2, by2) - max(ay1, by1)
    if inter_w < 0 or inter_h < 0:
        return 0.0

    inter_area = inter_w * inter_h
    union_area = sa + sb - inter_area
    # Degenerate (zero-area) boxes would otherwise divide by zero.
    if union_area <= 0:
        return 0.0
    return 1.0 * inter_area / union_area
1 change: 1 addition & 0 deletions ppgan/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
from .wav2lip_dataset import Wav2LipDataset
from .starganv2_dataset import StarGANv2Dataset
from .edvr_dataset import REDSDataset
from .firstorder_dataset import FirstOrderDataset
Loading