def forward(self, input):
center = input[:, -1, :].unsqueeze(1)
delta_x = input[:, :, 0:3] - center[:, :, 0:3] # (B, npoint, 3), normalized coordinates
for case in switch(self.dataset):
if case('3DMatch'):
z_axis = cm.cal_Z_axis(delta_x, ref_point=input[:, -1, :3])
z_axis = cm.l2_norm(z_axis, axis=1)
R = cm.RodsRotatFormula(z_axis, torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(z_axis.shape[0], 1))
delta_x = torch.matmul(delta_x, R)
break
if case('KITTI'):
break