Skip to content

Commit

Permalink
Refactored the keypoint postprocessing logic to reduce duplicated code and
added a new operating mode for single-instance prediction.
Browse files Browse the repository at this point in the history

1) Refactored the _postprocess_keypoints_for_class_and_image function so that
it can be reused by both the single-class and multi-class keypoint tasks.
2) Removed the mod ("%") operator to make the model compatible with WASM.

PiperOrigin-RevId: 345468250
  • Loading branch information
Yu-hui Chen authored and TF Object Detection Team committed Dec 3, 2020
1 parent e3f8ea2 commit ca0eb4a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 136 deletions.
167 changes: 39 additions & 128 deletions research/object_detection/meta_architectures/center_net_meta_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,9 +942,12 @@ def row_col_channel_indices_from_flattened_indices(indices, num_cols,
indices.
"""
# Avoid using mod operator to make the ops more easy to be compatible with
# different environments, e.g. WASM.
row_indices = (indices // num_channels) // num_cols
col_indices = (indices // num_channels) % num_cols
channel_indices = indices % num_channels
col_indices = (indices // num_channels) - row_indices * num_cols
channel_indices_temp = indices // num_channels
channel_indices = indices - channel_indices_temp * num_channels

return row_indices, col_indices, channel_indices

Expand Down Expand Up @@ -2925,10 +2928,7 @@ def postprocess(self, prediction_dict, true_image_shapes, **params):
# keypoint, we fall back to a simpler postprocessing function which uses
# the ops that are supported by tf.lite on GPU.
if len(self._kp_params_dict) == 1 and self._num_classes == 1:
# keypoints, keypoint_scores = self._postprocess_keypoints_simple(
# prediction_dict, classes, y_indices, x_indices,
# boxes_strided, num_detections)
keypoints, keypoint_scores = self._postprocess_keypoints_simple(
keypoints, keypoint_scores = self._postprocess_keypoints_single_class(
prediction_dict, classes, y_indices, x_indices,
boxes_strided, num_detections)
# The map_fn used to clip out of frame keypoints creates issues when
Expand All @@ -2939,7 +2939,7 @@ def postprocess(self, prediction_dict, true_image_shapes, **params):
keypoints, keypoint_scores, self._stride, true_image_shapes,
clip_out_of_frame_keypoints=False))
else:
keypoints, keypoint_scores = self._postprocess_keypoints(
keypoints, keypoint_scores = self._postprocess_keypoints_multi_class(
prediction_dict, classes, y_indices, x_indices,
boxes_strided, num_detections)
keypoints, keypoint_scores = (
Expand Down Expand Up @@ -3014,10 +3014,18 @@ def _postprocess_embeddings(self, prediction_dict, y_indices, x_indices):

return embeddings

def _postprocess_keypoints(self, prediction_dict, classes, y_indices,
x_indices, boxes, num_detections):
def _postprocess_keypoints_multi_class(self, prediction_dict, classes,
y_indices, x_indices, boxes,
num_detections):
"""Performs postprocessing on keypoint predictions.
This is the most general keypoint postprocessing function which supports
multiple keypoint tasks (e.g. human and dog keypoints) and multiple object
detection classes. Note that it is the most expensive postprocessing logics
and is currently not tf.lite/tf.js compatible. See
_postprocess_keypoints_single_class if you plan to export the model in more
portable format.
Args:
prediction_dict: a dictionary holding predicted tensors, returned from the
predict() method. This dictionary should contain keypoint prediction
Expand Down Expand Up @@ -3060,11 +3068,15 @@ def _postprocess_keypoints(self, prediction_dict, classes, y_indices,
classes, num_detections, ex_ind, kp_params.class_id)
num_ind = _get_shape(instance_inds, 1)

def true_fn(
keypoint_heatmap, keypoint_offsets, keypoint_regression,
classes, y_indices, x_indices, boxes, instance_inds,
ex_ind, kp_params):
def true_fn(keypoint_heatmap, keypoint_offsets, keypoint_regression,
classes, y_indices, x_indices, boxes, instance_inds, ex_ind,
kp_params):
"""Logics to execute when instance_inds is not an empty set."""
# Gather the feature map locations corresponding to the object class.
y_indices_for_kpt_class = tf.gather(y_indices, instance_inds, axis=1)
x_indices_for_kpt_class = tf.gather(x_indices, instance_inds, axis=1)
boxes_for_kpt_class = tf.gather(boxes, instance_inds, axis=1)

# Postprocess keypoints and scores for class and single image. Shapes
# are [1, num_instances_i, num_keypoints_i, 2] and
# [1, num_instances_i, num_keypoints_i], respectively. Note that
Expand All @@ -3073,15 +3085,17 @@ def true_fn(
kpt_coords_for_class, kpt_scores_for_class = (
self._postprocess_keypoints_for_class_and_image(
keypoint_heatmap, keypoint_offsets, keypoint_regression,
classes, y_indices, x_indices, boxes, instance_inds,
ex_ind, kp_params))
classes, y_indices_for_kpt_class, x_indices_for_kpt_class,
boxes_for_kpt_class, ex_ind, kp_params))

# Expand keypoint dimension (with padding) so that coordinates and
# scores have shape [1, num_instances_i, num_total_keypoints, 2] and
# [1, num_instances_i, num_total_keypoints], respectively.
kpts_coords_for_class_padded, kpt_scores_for_class_padded = (
_pad_to_full_keypoint_dim(
kpt_coords_for_class, kpt_scores_for_class,
kp_params.keypoint_indices, total_num_keypoints))
_pad_to_full_keypoint_dim(kpt_coords_for_class,
kpt_scores_for_class,
kp_params.keypoint_indices,
total_num_keypoints))
return kpts_coords_for_class_padded, kpt_scores_for_class_padded

def false_fn():
Expand Down Expand Up @@ -3135,9 +3149,10 @@ def false_fn():

return keypoints, keypoint_scores

def _postprocess_keypoints_simple(self, prediction_dict, classes, y_indices,
x_indices, boxes, num_detections):
"""Performs postprocessing on keypoint predictions (one class only).
def _postprocess_keypoints_single_class(self, prediction_dict, classes,
y_indices, x_indices, boxes,
num_detections):
"""Performs postprocessing on keypoint predictions (single class only).
This function handles the special case of keypoint task that the model
predicts only one class of the bounding box/keypoint (e.g. person). By the
Expand Down Expand Up @@ -3186,9 +3201,9 @@ def _postprocess_keypoints_simple(self, prediction_dict, classes, y_indices,
# are [1, max_detections, num_keypoints, 2] and
# [1, max_detections, num_keypoints], respectively.
kpt_coords_for_class, kpt_scores_for_class = (
self._postprocess_keypoints_for_class_and_image_simple(
keypoint_heatmap, keypoint_offsets, keypoint_regression,
classes, y_indices, x_indices, boxes, ex_ind, kp_params))
self._postprocess_keypoints_for_class_and_image(
keypoint_heatmap, keypoint_offsets, keypoint_regression, classes,
y_indices, x_indices, boxes, ex_ind, kp_params))

kpt_coords_for_example_list.append(kpt_coords_for_class)
kpt_scores_for_example_list.append(kpt_scores_for_class)
Expand Down Expand Up @@ -3233,114 +3248,10 @@ def _get_instance_indices(self, classes, num_detections, batch_index,
return tf.cast(instance_inds, tf.int32)

def _postprocess_keypoints_for_class_and_image(
self, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes,
y_indices, x_indices, boxes, indices_with_kpt_class, batch_index,
kp_params):
"""Postprocess keypoints for a single image and class.
This function performs the following postprocessing operations on a single
image and single keypoint class:
- Converts keypoints scores to range [0, 1] with sigmoid.
- Determines the detections that correspond to the specified keypoint class.
- Gathers the regressed keypoints at the detection (i.e. box) centers.
- Gathers keypoint candidates from the keypoint heatmaps.
- Snaps regressed keypoints to nearby keypoint candidates.
Args:
keypoint_heatmap: A [batch_size, height, width, num_keypoints] float32
tensor with keypoint heatmaps.
keypoint_offsets: A [batch_size, height, width, 2] float32 tensor with
local offsets to keypoint centers.
keypoint_regression: A [batch_size, height, width, 2 * num_keypoints]
float32 tensor with regressed offsets to all keypoints.
classes: A [batch_size, max_detections] int tensor with class indices for
all detected objects.
y_indices: A [batch_size, max_detections] int tensor with y indices for
all object centers.
x_indices: A [batch_size, max_detections] int tensor with x indices for
all object centers.
boxes: A [batch_size, max_detections, 4] float32 tensor with detected
boxes in the output (strided) frame.
indices_with_kpt_class: A [num_instances] int tensor where each element
indicates the instance location within the `classes` tensor. This is
useful to associate the refined keypoints with the original detections
(i.e. boxes)
batch_index: An integer specifying the index for an example in the batch.
kp_params: A `KeypointEstimationParams` object with parameters for a
single keypoint class.
Returns:
A tuple of
refined_keypoints: A [1, num_instances, num_keypoints, 2] float32 tensor
with refined keypoints for a single class in a single image, expressed
in the output (strided) coordinate frame. Note that `num_instances` is a
dynamic dimension, and corresponds to the number of valid detections
for the specific class.
refined_scores: A [1, num_instances, num_keypoints] float32 tensor with
keypoint scores.
"""
keypoint_indices = kp_params.keypoint_indices
num_keypoints = len(keypoint_indices)

keypoint_heatmap = tf.nn.sigmoid(
keypoint_heatmap[batch_index:batch_index+1, ...])
keypoint_offsets = keypoint_offsets[batch_index:batch_index+1, ...]
keypoint_regression = keypoint_regression[batch_index:batch_index+1, ...]
y_indices = y_indices[batch_index:batch_index+1, ...]
x_indices = x_indices[batch_index:batch_index+1, ...]
boxes_slice = boxes[batch_index:batch_index+1, ...]

# Gather the feature map locations corresponding to the object class.
y_indices_for_kpt_class = tf.gather(y_indices, indices_with_kpt_class,
axis=1)
x_indices_for_kpt_class = tf.gather(x_indices, indices_with_kpt_class,
axis=1)
boxes_for_kpt_class = tf.gather(boxes_slice, indices_with_kpt_class, axis=1)

# Gather the regressed keypoints. Final tensor has shape
# [1, num_instances, num_keypoints, 2].
regressed_keypoints_for_objects = regressed_keypoints_at_object_centers(
keypoint_regression, y_indices_for_kpt_class, x_indices_for_kpt_class)
regressed_keypoints_for_objects = tf.reshape(
regressed_keypoints_for_objects, [1, -1, num_keypoints, 2])

# Get the candidate keypoints and scores.
# The shape of keypoint_candidates and keypoint_scores is:
# [1, num_candidates_per_keypoint, num_keypoints, 2] and
# [1, num_candidates_per_keypoint, num_keypoints], respectively.
keypoint_candidates, keypoint_scores, num_keypoint_candidates = (
prediction_tensors_to_keypoint_candidates(
keypoint_heatmap, keypoint_offsets,
keypoint_score_threshold=(
kp_params.keypoint_candidate_score_threshold),
max_pool_kernel_size=kp_params.peak_max_pool_kernel_size,
max_candidates=kp_params.num_candidates_per_keypoint))

# Get the refined keypoints and scores, of shape
# [1, num_instances, num_keypoints, 2] and
# [1, num_instances, num_keypoints], respectively.
refined_keypoints, refined_scores = refine_keypoints(
regressed_keypoints=regressed_keypoints_for_objects,
keypoint_candidates=keypoint_candidates,
keypoint_scores=keypoint_scores,
num_keypoint_candidates=num_keypoint_candidates,
bboxes=boxes_for_kpt_class,
unmatched_keypoint_score=kp_params.unmatched_keypoint_score,
box_scale=kp_params.box_scale,
candidate_search_scale=kp_params.candidate_search_scale,
candidate_ranking_mode=kp_params.candidate_ranking_mode)

return refined_keypoints, refined_scores

def _postprocess_keypoints_for_class_and_image_simple(
self, keypoint_heatmap, keypoint_offsets, keypoint_regression, classes,
y_indices, x_indices, boxes, batch_index, kp_params):
"""Postprocess keypoints for a single image and class.
This function is similar to "_postprocess_keypoints_for_class_and_image"
except that it assumes there is only one class of bounding box/keypoint to
be handled. The function is tf.lite compatible.
Args:
keypoint_heatmap: A [batch_size, height, width, num_keypoints] float32
tensor with keypoint heatmaps.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1207,13 +1207,13 @@ def graph_fn():
_NUM_FC_LAYERS = 1


def get_fake_center_params():
def get_fake_center_params(max_box_predictions=5):
"""Returns the fake object center parameter namedtuple."""
return cnma.ObjectCenterParams(
classification_loss=losses.WeightedSigmoidClassificationLoss(),
object_center_loss_weight=1.0,
min_box_overlap_iou=1.0,
max_box_predictions=5,
max_box_predictions=max_box_predictions,
use_labeled_classes=False)


Expand All @@ -1225,7 +1225,7 @@ def get_fake_od_params():
scale_loss_weight=0.1)


def get_fake_kp_params():
def get_fake_kp_params(num_candidates_per_keypoint=100):
"""Returns the fake keypoint estimation parameter namedtuple."""
return cnma.KeypointEstimationParams(
task_name=_TASK_NAME,
Expand All @@ -1234,7 +1234,8 @@ def get_fake_kp_params():
keypoint_std_dev=[0.00001] * len(_KEYPOINT_INDICES),
classification_loss=losses.WeightedSigmoidClassificationLoss(),
localization_loss=losses.L1LocalizationLoss(),
keypoint_candidate_score_threshold=0.1)
keypoint_candidate_score_threshold=0.1,
num_candidates_per_keypoint=num_candidates_per_keypoint)


def get_fake_mask_params():
Expand Down Expand Up @@ -1277,7 +1278,9 @@ def get_fake_temporal_offset_params():
task_loss_weight=1.0)


def build_center_net_meta_arch(build_resnet=False, num_classes=_NUM_CLASSES):
def build_center_net_meta_arch(build_resnet=False,
num_classes=_NUM_CLASSES,
max_box_predictions=5):
"""Builds the CenterNet meta architecture."""
if build_resnet:
feature_extractor = (
Expand All @@ -1297,15 +1300,18 @@ def build_center_net_meta_arch(build_resnet=False, num_classes=_NUM_CLASSES):
pad_to_max_dimesnion=True)

if num_classes == 1:
num_candidates_per_keypoint = 100 if max_box_predictions > 1 else 1
return cnma.CenterNetMetaArch(
is_training=True,
add_summaries=False,
num_classes=num_classes,
feature_extractor=feature_extractor,
image_resizer_fn=image_resizer_fn,
object_center_params=get_fake_center_params(),
object_center_params=get_fake_center_params(max_box_predictions),
object_detection_params=get_fake_od_params(),
keypoint_params_dict={_TASK_NAME: get_fake_kp_params()})
keypoint_params_dict={
_TASK_NAME: get_fake_kp_params(num_candidates_per_keypoint)
})
else:
return cnma.CenterNetMetaArch(
is_training=True,
Expand Down Expand Up @@ -1726,7 +1732,7 @@ def graph_fn():
detections['detection_surface_coords'][0, 0, :, :],
np.zeros_like(detections['detection_surface_coords'][0, 0, :, :]))

def test_postprocess_simple(self):
def test_postprocess_single_class(self):
"""Test the postprocess function."""
model = build_center_net_meta_arch(num_classes=1)
max_detection = model._center_params.max_box_predictions
Expand Down

0 comments on commit ca0eb4a

Please sign in to comment.