#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
import copy
import os
import torch
import numpy as np
import trimesh
from scipy.spatial.transform import Rotation
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
from mast3r.model_decoder import AsymmetricMASt3R
from mast3r.utils.misc import hash_md5
import mast3r.utils.path_to_dust3r # noqa
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images
from dust3r.utils.device import to_numpy
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
from dust3r.demo import get_args_parser as dust3r_get_args_parser
torch.backends.cuda.matmul.allow_tf32 = True  # enable TF32 matmuls (Ampere+ GPUs, PyTorch >= 1.12)
batch_size = 1
def get_args_parser():
parser = dust3r_get_args_parser()
parser.add_argument('--output_dir', type=str, default='output', help='Output directory')
parser.add_argument('--input_images', nargs='+', required=True, help='Input image files')
parser.add_argument('--optim_level', choices=['coarse', 'refine', 'refine+depth'], default='refine', help='Optimization level')
parser.add_argument('--lr1', type=float, default=0.07, help='Coarse learning rate')
parser.add_argument('--niter1', type=int, default=500, help='Number of coarse iterations')
parser.add_argument('--lr2', type=float, default=0.014, help='Fine learning rate')
parser.add_argument('--niter2', type=int, default=200, help='Number of fine iterations')
parser.add_argument('--scenegraph_type', choices=['complete', 'swin', 'logwin', 'oneref'], default='complete', help='Scene graph type')
actions = parser._actions
for action in actions:
if action.dest == 'model_name':
action.choices = ["MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"]
    parser.prog = 'mast3r ONNX decoder export'
return parser
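# Example invocation (image paths are illustrative):
#   python export_mast3r_onnx_decoder.py --input_images img1.jpg img2.jpg --output_dir output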
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
cam_color=None, as_pointcloud=False,
transparent_cams=False, silent=False):
assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
pts3d = to_numpy(pts3d)
imgs = to_numpy(imgs)
focals = to_numpy(focals)
cams2world = to_numpy(cams2world)
scene = trimesh.Scene()
# full pointcloud
if as_pointcloud:
pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)])
col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
scene.add_geometry(pct)
else:
meshes = []
for i in range(len(imgs)):
meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i].reshape(imgs[i].shape), mask[i]))
mesh = trimesh.Trimesh(**cat_meshes(meshes))
scene.add_geometry(mesh)
# add each camera
for i, pose_c2w in enumerate(cams2world):
if isinstance(cam_color, list):
camera_edge_color = cam_color[i]
else:
camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
add_scene_cam(scene, pose_c2w, camera_edge_color,
None if transparent_cams else imgs[i], focals[i],
imsize=imgs[i].shape[1::-1], screen_width=cam_size)
rot = np.eye(4)
rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
outfile = os.path.join(outdir, 'scene.glb')
if not silent:
print('(exporting 3D scene to', outfile, ')')
scene.export(file_obj=outfile)
return outfile
def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
"""
    extract a 3D model (.glb file) from a reconstructed scene
"""
if scene is None:
return None
# get optimized values from scene
rgbimg = scene.imgs
focals = scene.get_focals().cpu()
cams2world = scene.get_im_poses().cpu()
# 3D pointcloud from depthmap, poses and intrinsics
if TSDF_thresh > 0:
tsdf = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh)
pts3d, _, confs = to_numpy(tsdf.get_dense_pts3d(clean_depth=clean_depth))
else:
pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=clean_depth))
msk = to_numpy([c > min_conf_thr for c in confs])
return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
transparent_cams=transparent_cams, cam_size=cam_size, silent=silent)
def get_reconstructed_scene(outdir, model, device, silent, image_size, filelist, optim_level, lr1, niter1, lr2, niter2,
min_conf_thr, matching_conf_thr, as_pointcloud, mask_sky, clean_depth, transparent_cams,
cam_size, scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics,
**kw):
"""
    from a list of images, run mast3r inference and the sparse global aligner,
    then export the scene with get_3D_model_from_scene
"""
imgs = load_images(filelist, size=image_size, verbose=not silent)
if len(imgs) == 1:
imgs = [imgs[0], copy.deepcopy(imgs[0])]
imgs[1]['idx'] = 1
filelist = [filelist[0], filelist[0] + '_2']
scene_graph_params = [scenegraph_type]
if scenegraph_type in ["swin", "logwin"]:
scene_graph_params.append(str(winsize))
elif scenegraph_type == "oneref":
scene_graph_params.append(str(refid))
if scenegraph_type in ["swin", "logwin"] and not win_cyclic:
scene_graph_params.append('noncyclic')
scene_graph = '-'.join(scene_graph_params)
pairs = make_pairs(imgs, scene_graph=scene_graph, prefilter=None, symmetrize=True)
if optim_level == 'coarse':
niter2 = 0
# Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
scene = sparse_global_alignment(filelist, pairs, os.path.join(outdir, 'cache'),
model, lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
matching_conf_thr=matching_conf_thr, **kw)
outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
clean_depth, transparent_cams, cam_size, TSDF_thresh)
return scene, outfile
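# Full reconstruction pipeline, kept for reference: the __main__ block below
# only exports the decoder to ONNX and never calls main().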
def main(args):
model = AsymmetricMASt3R.from_pretrained(args.weights or "naver/" + args.model_name).to(args.device)
chkpt_tag = hash_md5(args.weights or "naver/" + args.model_name)
output_dir = os.path.join(args.output_dir, chkpt_tag)
os.makedirs(output_dir, exist_ok=True)
scene, outfile = get_reconstructed_scene(
output_dir, model, args.device, args.silent, args.image_size,
args.input_images, 'refine', 0.07, 500, 0.014, 200, 1.5, 5.0,
True, False, True, False, 0.2, 'complete', 1, False, 0, 0.0, False
)
print(f"3D model saved to: {outfile}")
if __name__ == '__main__':
parser = get_args_parser()
args = parser.parse_args()
torch_model = AsymmetricMASt3R.from_pretrained(args.weights or "naver/" + args.model_name).to(args.device)
print(torch_model)
#"AsymmetricMASt3R(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12, two_confs=True, desc_conf_mode=('exp', 0, inf))" \
chkpt_tag = hash_md5(args.weights or "naver/" + args.model_name)
output_dir = os.path.join(args.output_dir, chkpt_tag)
os.makedirs(output_dir, exist_ok=True)
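    # Cached decoder inputs, presumably dumped with torch.save() from a prior
    # MASt3R forward pass on an image pair: feat*/pos* appear to be the per-image
    # encoder tokens and token positions, shape* the true image shapes. The
    # 'input/' directory must be populated before running this script.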
    feat1 = torch.load('input/feat1.pth', map_location=args.device)
    feat2 = torch.load('input/feat2.pth', map_location=args.device)
    pos1 = torch.load('input/pos1.pth', map_location=args.device)
    pos2 = torch.load('input/pos2.pth', map_location=args.device)
    shape1 = torch.load('input/shape1.pth', map_location=args.device)
    shape2 = torch.load('input/shape2.pth', map_location=args.device)
    # map_location keeps the cached tensors on the export device; renamed from
    # 'input' to avoid shadowing the builtin
    example_inputs = (feat1, feat2, pos1, pos2, shape1, shape2)
    # output = torch_model(feat1, feat2, pos1, pos2, shape1, shape2)
    torch_model.eval()  # disable dropout/stochastic layers before tracing
    torch.onnx.export(
        torch_model,
        example_inputs,
        os.path.join(output_dir, 'mast3r_decoder_params.onnx'),
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        verbose=True,
    )
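    # Optional sanity check of the exported graph: a minimal sketch, assuming the
    # optional onnxruntime package is installed and that tracing preserved all six
    # inputs in the order of example_inputs (skipped otherwise).
    try:
        import onnxruntime as ort
        sess = ort.InferenceSession(os.path.join(output_dir, 'mast3r_decoder_params.onnx'),
                                    providers=['CPUExecutionProvider'])
        ort_inputs = {meta.name: tensor.cpu().numpy()
                      for meta, tensor in zip(sess.get_inputs(), example_inputs)}
        ort_outputs = sess.run(None, ort_inputs)
        print(f"onnxruntime check passed: {len(ort_outputs)} output tensor(s)")
    except ImportError:
        print("onnxruntime not installed; skipping the exported-model check")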