Skip to content

Commit

Permalink
Code cleanup.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 283837279
  • Loading branch information
yeqingli authored and tensorflower-gardener committed Dec 4, 2019
1 parent 5b25005 commit 91a1ce9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 31 deletions.
29 changes: 29 additions & 0 deletions official/vision/detection/evaluation/coco_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,35 @@
from official.vision.detection.utils import class_utils


class MetricWrapper(object):
  """Keras-metric-style adapter around a numpy-based COCO evaluator.

  The wrapped evaluator only works on numpy arrays, so this class deliberately
  does not subclass `tf.keras.layers.Layer` or `tf.keras.metrics.Metric`; it
  just exposes the `update_state` / `result` / `reset_states` method names and
  forwards them to the evaluator's `update` / `evaluate` / `reset`.

  Args:
    evaluator: object providing `update(predictions, groundtruths)`,
      `evaluate()` and `reset()`.
  """

  def __init__(self, evaluator):
    self._evaluator = evaluator

  @staticmethod
  def _merge_tuple_values(named_arrays):
    """Returns a new dict in which every tuple value is `np.concatenate`d.

    Non-tuple values are copied through unchanged; key order is preserved.
    """
    return {
        name: np.concatenate(value) if isinstance(value, tuple) else value
        for name, value in named_arrays.items()
    }

  def update_state(self, y_true, y_pred):
    """Converts one batch to numpy and feeds it to the wrapped evaluator.

    Args:
      y_true: nested structure of tensors (anything with `.numpy()`); after
        conversion its top level must be a dict, since `.items()` is used.
      y_pred: same requirements as `y_true`, holding the model outputs.
    """
    def to_numpy(tensor):
      return tensor.numpy()

    numpy_labels = tf.nest.map_structure(to_numpy, y_true)
    numpy_outputs = tf.nest.map_structure(to_numpy, y_pred)
    # Arguments are evaluated left to right, so predictions are merged before
    # groundtruths.
    self._evaluator.update(
        self._merge_tuple_values(numpy_outputs),
        self._merge_tuple_values(numpy_labels))

  def result(self):
    """Returns the value produced by the wrapped evaluator's `evaluate()`."""
    return self._evaluator.evaluate()

  def reset_states(self):
    """Resets the wrapped evaluator and returns whatever `reset()` returns."""
    return self._evaluator.reset()


class COCOEvaluator(object):
"""COCO evaluation metric class."""

Expand Down
2 changes: 1 addition & 1 deletion official/vision/detection/evaluation/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ def evaluator_generator(params):
else:
raise ValueError('Evaluator %s is not supported.' % params.type)

return evaluator
return coco_evaluator.MetricWrapper(evaluator)
36 changes: 6 additions & 30 deletions official/vision/detection/modeling/retinanet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,6 @@
from official.vision.detection.ops import postprocess_ops


class COCOMetrics(object):
  """Keras-metric-style wrapper that builds and drives a COCO evaluator.

  The evaluator is created from `params.eval` via
  `eval_factory.evaluator_generator`. It operates on numpy arrays, which is
  why this class does not inherit from `tf.keras.layers.Layer` or
  `tf.keras.metrics.Metric`.

  Args:
    params: config object whose `eval` attribute is passed to
      `eval_factory.evaluator_generator`.
  """

  def __init__(self, params):
    self._evaluator = eval_factory.evaluator_generator(params.eval)

  def update_state(self, y_true, y_pred):
    """Converts a batch to numpy, joins tuple values, and updates the evaluator.

    Args:
      y_true: nested structure of tensors whose top level (after `.numpy()`
        conversion) is a dict of groundtruths.
      y_pred: nested structure of tensors whose top level (after `.numpy()`
        conversion) is a dict of predictions.
    """
    labels = tf.nest.map_structure(lambda t: t.numpy(), y_true)
    outputs = tf.nest.map_structure(lambda t: t.numpy(), y_pred)

    merged = []
    # Outputs are processed first, then labels, matching the argument order
    # of `self._evaluator.update(predictions, groundtruths)`.
    for structure in (outputs, labels):
      merged.append({
          key: (np.concatenate(val) if isinstance(val, tuple) else val)
          for key, val in structure.items()
      })
    predictions, groundtruths = merged
    self._evaluator.update(predictions, groundtruths)

  def result(self):
    """Returns the result of the underlying evaluator's `evaluate()` call."""
    return self._evaluator.evaluate()

  def reset_states(self):
    """Delegates to the underlying evaluator's `reset()` and returns its value."""
    return self._evaluator.reset()


class RetinanetModel(base_model.Model):
"""RetinaNet model function."""

Expand Down Expand Up @@ -97,6 +68,11 @@ def __init__(self, params):
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)

def build_outputs(self, inputs, mode):
# If the input image is transposed (from NHWC to HWCN), we need to revert it
# back to the original shape before it's used in the computation.
if self._transpose_input:
inputs = tf.transpose(inputs, [3, 0, 1, 2])

backbone_features = self._backbone_fn(
inputs, is_training=(mode == mode_keys.TRAIN))
fpn_features = self._fpn_fn(
Expand Down Expand Up @@ -192,4 +168,4 @@ def post_processing(self, labels, outputs):
return labels, outputs

def eval_metrics(self):
# Returns the metric object used to evaluate this model.
# NOTE(review): this page renders the diff without +/- markers. The first
# return below is presumably the removed (pre-commit) line, since the commit
# deletes COCOMetrics. The second return is its replacement. After the commit
# the method returns eval_factory.evaluator_generator(self._params.eval), and
# the factory now wraps its evaluator in coco_evaluator.MetricWrapper.
return COCOMetrics(self._params)
return eval_factory.evaluator_generator(self._params.eval)

0 comments on commit 91a1ce9

Please sign in to comment.