Skip to content

Commit

Permalink
Internal change
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 285974579
  • Loading branch information
chr1sj0nes authored and tensorflower-gardener committed Dec 17, 2019
1 parent a19f2f8 commit 7a69f96
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions official/transformer/v2/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import values
from official.transformer.utils import tokenizer

_EXTRA_DECODE_LENGTH = 100
Expand Down Expand Up @@ -144,12 +143,15 @@ def _step_fn(inputs):
text = np.reshape(text, [num_replicas, local_batch_size, -1])
# Add tag to the input of each replica with the reordering logic after
# outputs, to ensure the output order matches the input order.
text = [
[tf.convert_to_tensor(tag), tf.convert_to_tensor(per_replica_text)]
for tag, per_replica_text in enumerate(text)
]
# pylint: disable=protected-access
text = values.PerReplica(distribution_strategy.extended._device_map, text)
text = tf.constant(text)

@tf.function
def text_as_per_replica():
  """Returns this replica's id together with its slice of the input text.

  Intended to run once per replica via
  `distribution_strategy.experimental_run_v2`. The replica id is returned
  alongside the data as a tag, so that after outputs are gathered from all
  replicas the caller can restore the original input order.
  """
  replica_id = tf.distribute.get_replica_context().replica_id_in_sync_group
  # NOTE(review): `text` is captured from the enclosing scope. It appears to be
  # a constant built from an array reshaped to
  # [num_replicas, local_batch_size, -1], so indexing by replica id selects
  # this replica's rows. Confirm against the enclosing function.
  return replica_id, text[replica_id]

text = distribution_strategy.experimental_run_v2(text_as_per_replica)
outputs = distribution_strategy.experimental_local_results(
predict_step(text))
tags, unordered_val_outputs = outputs[0]
Expand Down

0 comments on commit 7a69f96

Please sign in to comment.