Skip to content

Commit

Permalink
internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 422665603
  • Loading branch information
yuanliangzhe authored and tensorflower-gardener committed Jan 18, 2022
1 parent 27fb855 commit f13be76
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
36 changes: 29 additions & 7 deletions official/projects/movinet/modeling/movinet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,13 @@ def __init__(
# Move backbone after super() call so Keras is happy
self._backbone = backbone

def _build_network(
def _build_backbone(
self,
backbone: tf.keras.Model,
input_specs: Mapping[str, tf.keras.layers.InputSpec],
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras
str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]:
"""Builds the model network.
) -> Tuple[Mapping[str, Any], Any, Any]:
"""Builds the backbone network and gets states and endpoints.
Args:
backbone: the model backbone.
Expand All @@ -104,9 +103,9 @@ def _build_network(
layer, will overwrite the contents of the buffer(s).
Returns:
Inputs and outputs as a tuple. Inputs are expected to be a dict with
base input and states. Outputs are expected to be a dict of endpoints
and (optionally) output states.
inputs: a dict of input specs.
endpoints: a dict of model endpoints.
states: a dict of model states.
"""
state_specs = state_specs if state_specs is not None else {}

Expand Down Expand Up @@ -145,7 +144,30 @@ def _build_network(
mismatched_shapes))
else:
endpoints, states = backbone(inputs)
return inputs, endpoints, states

def _build_network(
self,
backbone: tf.keras.Model,
input_specs: Mapping[str, tf.keras.layers.InputSpec],
state_specs: Optional[Mapping[str, tf.keras.layers.InputSpec]] = None,
) -> Tuple[Mapping[str, tf.keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras
str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]:
"""Builds the model network.
Args:
backbone: the model backbone.
input_specs: the model input spec to use.
state_specs: a dict of states such that, if any of the keys match for a
layer, will overwrite the contents of the buffer(s).
Returns:
Inputs and outputs as a tuple. Inputs are expected to be a dict with
base input and states. Outputs are expected to be a dict of endpoints
and (optionally) output states.
"""
inputs, endpoints, states = self._build_backbone(
backbone=backbone, input_specs=input_specs, state_specs=state_specs)
x = endpoints['head']

x = movinet_layers.ClassifierHead(
Expand Down
3 changes: 2 additions & 1 deletion official/projects/movinet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
# Import movinet libraries to register the backbone and model into tf.vision
# model garden factory.
# pylint: disable=unused-import
# The following are the necessary imports.
from official.projects.movinet.google.configs import movinet_google
from official.projects.movinet.google.modeling import movinet_model_google
from official.projects.movinet.modeling import movinet
from official.projects.movinet.modeling import movinet_model
# pylint: enable=unused-import
Expand Down

0 comments on commit f13be76

Please sign in to comment.