Skip to content

Commit

Permalink
Add MidDaSv2 in ppgan.apps (PaddlePaddle#118)
Browse files Browse the repository at this point in the history
* Add MidDaSv2 in ppgan.apps
* remove ppgan/apps/midas/run.py
  • Loading branch information
qingqing01 authored Dec 11, 2020
1 parent 2cc72be commit 77b8bac
Show file tree
Hide file tree
Showing 11 changed files with 771 additions and 1 deletion.
44 changes: 44 additions & 0 deletions docs/zh_CN/apis/apps.md
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,47 @@ ppgan.apps.AnimeGANPredictor(output_path='output_dir',weight_path=None,use_adjus
> ```
> **返回值:**
> > - anime_image(numpy.ndarray): 返回风格化后的景色图像

## ppgan.apps.MiDaSPredictor

```python
ppgan.apps.MiDaSPredictor(output=None, weight_path=None)
```

> 单目深度估计模型MiDaSv2, 参考 https://github.com/intel-isl/MiDaS, 论文是 Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer , 论文链接: https://arxiv.org/abs/1907.01341v3
> **示例**
>
> ```python
> from ppgan.apps import MiDaSPredictor
> # if set output, will write depth pfm and png file in output/MiDaS
> model = MiDaSPredictor()
> prediction = model.run()
> ```
>
> 深度图彩色显示:
>
> ```python
> import numpy as np
> import PIL.Image as Image
> import matplotlib as mpl
> import matplotlib.cm as cm
>
> vmax = np.percentile(prediction, 95)
> normalizer = mpl.colors.Normalize(vmin=prediction.min(), vmax=vmax)
> mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
> colormapped_im = (mapper.to_rgba(prediction)[:, :, :3] * 255).astype(np.uint8)
> im = Image.fromarray(colormapped_im)
> im.save('test_disp.jpeg')
> ```
>
> **参数:**
>
> > - output (str): 输出路径,如果是None,则不保存pfm和png的深度图文件。
> > - weight_path (str): 指定模型路径,默认是None,则会自动下载内置的已经训练好的模型。
> **返回值:**
> > - prediction (numpy.ndarray): 返回预测结果。
> > - pfm_f (str): 如果设置output路径,返回pfm文件保存路径。
> > - png_f (str): 如果设置output路径,返回png文件保存路径。
1 change: 1 addition & 0 deletions ppgan/apps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from .first_order_predictor import FirstOrderPredictor
from .face_parse_predictor import FaceParsePredictor
from .animegan_predictor import AnimeGANPredictor
from .midas_predictor import MiDaSPredictor
2 changes: 1 addition & 1 deletion ppgan/apps/animegan_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

class AnimeGANPredictor(BasePredictor):
def __init__(self,
output_path='output_dir',
output_path='output',
weight_path=None,
use_adjust_brightness=True):
self.output_path = output_path
Expand Down
12 changes: 12 additions & 0 deletions ppgan/apps/midas/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
## Monocular Depth Estimation


The implementation of MiDaSv2 refers to https://github.com/intel-isl/MiDaS.


@article{Ranftl2020,
author = {Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun},
title = {Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer},
journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
year = {2020},
}
Empty file added ppgan/apps/midas/__init__.py
Empty file.
164 changes: 164 additions & 0 deletions ppgan/apps/midas/blocks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Refer https://github.com/intel-isl/MiDaS

import paddle
import paddle.nn as nn


def _make_encoder(backbone,
                  features,
                  use_pretrained,
                  groups=1,
                  expand=False,
                  exportable=True):
    """Build the (pretrained encoder, scratch head) pair for MiDaS.

    Args:
        backbone (str): encoder name; only "resnext101_wsl" is implemented.
        features (int): base channel count for the scratch layers.
        use_pretrained (bool): forwarded to the backbone factory.
        groups (int): conv groups for the scratch layers.
        expand (bool): if True, scratch channels grow per level (x1/2/4/8).
        exportable (bool): unused; kept for interface compatibility.

    Returns:
        tuple: (pretrained backbone Layer, scratch Layer).

    Raises:
        NotImplementedError: if ``backbone`` is not supported.
    """
    if backbone == "resnext101_wsl":
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048],
                                features,
                                groups=groups,
                                expand=expand)
        return pretrained, scratch
    # The original used print + `assert False`: under `python -O` the assert
    # is stripped and the fall-through `return` raises a confusing NameError
    # on the undefined locals. Raise an explicit, catchable error instead.
    raise NotImplementedError(f"Backbone '{backbone}' not implemented")


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    """Create the four 3x3 "scratch" reassembly convolutions.

    One conv per encoder stage, mapping ``in_shape[i]`` channels to
    ``out_shape`` channels (or ``out_shape * (1, 2, 4, 8)[i]`` when
    ``expand`` is True), attached as ``scratch.layer{i}_rn``.
    """
    scratch = nn.Layer()

    # Preserve the original's literal `expand == True` truthiness check.
    multipliers = (1, 2, 4, 8) if expand == True else (1, 1, 1, 1)

    for level, (in_ch, mult) in enumerate(zip(in_shape, multipliers), start=1):
        conv = nn.Conv2D(in_ch,
                         out_shape * mult,
                         kernel_size=3,
                         stride=1,
                         padding=1,
                         bias_attr=False,
                         groups=groups)
        setattr(scratch, f"layer{level}_rn", conv)

    return scratch


def _make_resnet_backbone(resnet):
    """Repackage a ResNet-style model into four stage layers for MiDaS."""
    backbone = nn.Layer()

    # Stage 1 bundles the stem (conv/bn/relu/maxpool) with the first block.
    backbone.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
                                    resnet.maxpool, resnet.layer1)
    backbone.layer2 = resnet.layer2
    backbone.layer3 = resnet.layer3
    backbone.layer4 = resnet.layer4

    return backbone


def _make_pretrained_resnext101_wsl(use_pretrained):
    """Build the ResNeXt101 32x8d WSL backbone wrapped for MiDaS.

    NOTE(review): ``use_pretrained`` is accepted but not forwarded to the
    factory; weights appear to be loaded later via MidasNet.load — confirm.
    """
    from .resnext import resnext101_32x8d_wsl
    model = resnext101_32x8d_wsl()
    return _make_resnet_backbone(model)


class ResidualConvUnit(nn.Layer):
    """Pre-activation residual block: relu(x) + conv2(relu(conv1(relu(x))))."""

    def __init__(self, features):
        """Init.

        Args:
            features (int): channel count for both 3x3 convolutions.
        """
        super().__init__()

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias_attr=True)
        self.conv1 = nn.Conv2D(features, features, **conv_kwargs)
        self.conv2 = nn.Conv2D(features, features, **conv_kwargs)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the residual unit; output shape matches the input.

        Args:
            x (tensor): input features.

        Returns:
            tensor: output features.
        """
        activated = self.relu(x)
        residual = self.conv2(self.relu(self.conv1(activated)))
        # Note the skip connection adds the *activated* input, as upstream.
        return residual + activated


class FeatureFusionBlock(nn.Layer):
    """Fuse a decoder path with an optional skip connection, then upsample 2x."""

    def __init__(self, features):
        """Init.

        Args:
            features (int): channel count of the residual conv units.
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Fuse one or two feature maps.

        Args:
            *xs: one tensor (decoder path only) or two tensors
                (decoder path, encoder skip connection).

        Returns:
            tensor: fused features, bilinearly upsampled by a factor of 2.
        """
        fused = xs[0]
        if len(xs) == 2:
            fused = fused + self.resConfUnit1(xs[1])

        fused = self.resConfUnit2(fused)
        return nn.functional.interpolate(fused,
                                         scale_factor=2,
                                         mode="bilinear",
                                         align_corners=True)
92 changes: 92 additions & 0 deletions ppgan/apps/midas/midas_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Refer https://github.com/intel-isl/MiDaS
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
"""
import numpy as np
import paddle
import paddle.nn as nn

from .blocks import FeatureFusionBlock, _make_encoder


class BaseModel(paddle.nn.Layer):
    """Base Layer that can restore its parameters from a checkpoint file."""

    def load(self, path):
        """Load model parameters from file.

        Args:
            path (str): checkpoint file path.
        """
        state = paddle.load(path)
        self.set_dict(state)


class MidasNet(BaseModel):
    """Network for monocular depth estimation (MiDaS v2, ResNeXt101-WSL encoder)."""

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            non_negative (bool, optional): Clamp the output to >= 0 with a
                final ReLU. Defaults to True.
        """
        print("Loading weights: ", path)

        super(MidasNet, self).__init__()

        use_pretrained = False if path is None else True

        self.pretrained, self.scratch = _make_encoder(
            backbone="resnext101_wsl",
            features=features,
            use_pretrained=use_pretrained)

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        # BUG FIX: the original both placed `nn.ReLU() if non_negative else
        # nn.Identity()` in this list AND appended a second nn.ReLU() when
        # non_negative was True. The trailing ReLU was redundant (ReLU is
        # idempotent) and is removed here, matching the upstream MiDaS head.
        # The parameterized Conv2D layers keep their Sequential indices
        # (0, 2, 4), so previously saved weights still load.
        self.scratch.output_conv = nn.Sequential(
            nn.Conv2D(features, 128, kernel_size=3, stride=1, padding=1),
            nn.Upsample(scale_factor=2, mode="bilinear"),
            nn.Conv2D(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2D(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU() if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image batch, NCHW — presumably 3-channel;
                confirm against the preprocessing in the predictor).

        Returns:
            tensor: predicted depth with the channel dim squeezed, (N, H, W).
        """
        # Encoder stages.
        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        # Reassembly convs bring every stage to a common channel count.
        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        # Top-down refinement, fusing each skip connection on the way up.
        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        # (N, 1, H, W) -> (N, H, W).
        return paddle.squeeze(out, axis=1)
Loading

0 comments on commit 77b8bac

Please sign in to comment.