Refactor: Move transformer Estimator-only code to r1 folder.
PiperOrigin-RevId: 286103069
saberkun authored and tensorflower-gardener committed Dec 18, 2019
1 parent 0612e19 commit 2f6c5a5
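
Note: beyond moving files into official/r1, the substantive API change in the hunks below is the switch from the deprecated tf.contrib.tpu namespace to tf.estimator.tpu, which exposes the same TPU Estimator classes in TensorFlow 1.14+. An illustrative sketch of the mapping (not part of the commit; the model_dir path is hypothetical):

import tensorflow as tf  # TF 1.14/1.15

# Each tf.contrib.tpu symbol touched in this commit has a drop-in replacement:
#   tf.contrib.tpu.TPUConfig        -> tf.estimator.tpu.TPUConfig
#   tf.contrib.tpu.RunConfig        -> tf.estimator.tpu.RunConfig
#   tf.contrib.tpu.TPUEstimator     -> tf.estimator.tpu.TPUEstimator
#   tf.contrib.tpu.TPUEstimatorSpec -> tf.estimator.tpu.TPUEstimatorSpec
run_config = tf.estimator.tpu.RunConfig(
    model_dir="/tmp/transformer_model",
    tpu_config=tf.estimator.tpu.TPUConfig(iterations_per_loop=100))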
Showing 14 changed files with 517 additions and 1,001 deletions.
official/r1/transformer/README.md: 376 additions, 0 deletions

Large diffs are not rendered by default.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -16,7 +16,7 @@
 
 import tensorflow as tf
 
-from official.transformer.utils import schedule
+from official.r1.transformer import schedule
 
 
 class ScheduleBaseTester(tf.test.TestCase):
@@ -24,10 +24,10 @@
 
 import tensorflow as tf  # pylint: disable=g-bad-import-order
 
-from official.transformer.model import attention_layer
+from official.r1.transformer import attention_layer
+from official.r1.transformer import embedding_layer
+from official.r1.transformer import ffn_layer
 from official.transformer.model import beam_search
-from official.transformer.model import embedding_layer
-from official.transformer.model import ffn_layer
 from official.transformer.model import model_utils
 from official.transformer.utils.tokenizer import EOS_ID
 
@@ -34,13 +34,13 @@
 
 from official.r1.utils import export
 from official.r1.utils import tpu as tpu_util
+from official.r1.transformer import translate
+from official.r1.transformer import transformer
+from official.r1.transformer import dataset
+from official.r1.transformer import schedule
 from official.transformer import compute_bleu
-from official.transformer import translate
 from official.transformer.model import model_params
-from official.transformer.model import transformer
-from official.transformer.utils import dataset
 from official.transformer.utils import metrics
-from official.transformer.utils import schedule
 from official.transformer.utils import tokenizer
 from official.utils.flags import core as flags_core
 from official.utils.logs import hooks_helper
@@ -115,8 +115,10 @@ def model_fn(features, labels, mode, params):
       metric_fn = lambda logits, labels: (
           metrics.get_eval_metrics(logits, labels, params=params))
       eval_metrics = (metric_fn, [logits, labels])
-      return tf.contrib.tpu.TPUEstimatorSpec(
-          mode=mode, loss=loss, predictions={"predictions": logits},
+      return tf.estimator.tpu.TPUEstimatorSpec(
+          mode=mode,
+          loss=loss,
+          predictions={"predictions": logits},
           eval_metrics=eval_metrics)
     return tf.estimator.EstimatorSpec(
         mode=mode, loss=loss, predictions={"predictions": logits},
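
Unlike tf.estimator.EstimatorSpec, which reports metrics through eval_metric_ops, TPUEstimatorSpec takes eval_metrics as a (metric_fn, tensors) pair: the tensors are gathered from the TPU shards and metric_fn is run over them on the host CPU. A minimal self-contained sketch of that contract (illustrative only; the placeholder tensors and the accuracy metric are assumptions, not code from this repo):

import tensorflow as tf  # TF 1.15

def metric_fn(logits, labels):
  # Runs on the host CPU over tensors gathered from all TPU shards and
  # returns the usual name -> (value_op, update_op) metric dict.
  return {"accuracy": tf.metrics.accuracy(
      labels=labels, predictions=tf.argmax(logits, axis=-1))}

# Placeholder tensors standing in for real model outputs.
logits = tf.zeros([8, 10])
labels = tf.zeros([8], dtype=tf.int64)

spec = tf.estimator.tpu.TPUEstimatorSpec(
    mode=tf.estimator.ModeKeys.EVAL,
    loss=tf.constant(0.0),
    eval_metrics=(metric_fn, [logits, labels]))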
@@ -128,12 +130,14 @@ def model_fn(features, labels, mode, params):
     # in TensorBoard.
     metric_dict["minibatch_loss"] = loss
     if params["use_tpu"]:
-      return tf.contrib.tpu.TPUEstimatorSpec(
-          mode=mode, loss=loss, train_op=train_op,
+      return tf.estimator.tpu.TPUEstimatorSpec(
+          mode=mode,
+          loss=loss,
+          train_op=train_op,
           host_call=tpu_util.construct_scalar_host_call(
-              metric_dict=metric_dict, model_dir=params["model_dir"],
-              prefix="training/")
-      )
+              metric_dict=metric_dict,
+              model_dir=params["model_dir"],
+              prefix="training/"))
     record_scalars(metric_dict)
     return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
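
host_call is specific to TPUEstimatorSpec: summary ops cannot run inside the TPU training loop, so scalar tensors are shipped back to the host each iteration and a function writes them as summaries there. construct_scalar_host_call is this repo's helper (official/r1/utils/tpu.py); the sketch below shows the general pattern such a helper implements, not the actual implementation (make_host_call and its details are assumptions):

import tensorflow as tf  # TF 1.15

def make_host_call(metric_dict, model_dir):
  """Builds a (fn, tensors) pair that writes scalar summaries on the host."""
  names = list(metric_dict.keys())
  # host_call tensors must be at least rank 1, so scalars are reshaped to [1].
  tensors = [tf.reshape(tf.cast(metric_dict[name], tf.float32), [1])
             for name in names]
  global_step = tf.reshape(tf.train.get_or_create_global_step(), [1])

  def host_call_fn(step, *scalars):
    # Executed on the host CPU once per iterations_per_loop.
    writer = tf.contrib.summary.create_file_writer(model_dir)
    with writer.as_default(), tf.contrib.summary.always_record_summaries():
      for name, scalar in zip(names, scalars):
        tf.contrib.summary.scalar(name, scalar[0], step=step[0])
      return tf.contrib.summary.all_summary_ops()

  return host_call_fn, [global_step] + tensors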

@@ -342,6 +346,7 @@ def run_loop(
         steps=schedule_manager.single_iteration_train_steps,
         hooks=train_hooks)
 
+    eval_results = None
     eval_results = estimator.evaluate(
         input_fn=dataset.eval_input_fn,
         steps=schedule_manager.single_iteration_eval_steps)
@@ -534,25 +539,26 @@ def construct_estimator(flags_obj, params, schedule_manager):
         project=flags_obj.tpu_gcp_project
     )
 
-    tpu_config = tf.contrib.tpu.TPUConfig(
+    tpu_config = tf.estimator.tpu.TPUConfig(
         iterations_per_loop=schedule_manager.single_iteration_train_steps,
         num_shards=flags_obj.num_tpu_shards)
 
-    run_config = tf.contrib.tpu.RunConfig(
+    run_config = tf.estimator.tpu.RunConfig(
         cluster=tpu_cluster_resolver,
         model_dir=flags_obj.model_dir,
         session_config=tf.ConfigProto(
             allow_soft_placement=True, log_device_placement=True),
         tpu_config=tpu_config)
 
-    return tf.contrib.tpu.TPUEstimator(
+    return tf.estimator.tpu.TPUEstimator(
         model_fn=model_fn,
         use_tpu=params["use_tpu"] and flags_obj.tpu != tpu_util.LOCAL,
         train_batch_size=schedule_manager.batch_size,
         eval_batch_size=schedule_manager.batch_size,
         params={
             # TPUEstimator needs to populate batch_size itself due to sharding.
-            key: value for key, value in params.items() if key != "batch_size"},
+            key: value for key, value in params.items() if key != "batch_size"
+        },
         config=run_config)
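
The comprehension at the end exists because TPUEstimator populates params["batch_size"] itself with the per-shard batch size derived from train_batch_size, and it raises an error at construction time if the caller's params already contain that key. A toy construction sketch (toy_model_fn and the hyperparameters are hypothetical, not from this repo):

import tensorflow as tf  # TF 1.15

def toy_model_fn(features, labels, mode, params):
  # params["batch_size"] arrives here injected by TPUEstimator.
  del labels  # unused in this toy example
  weight = tf.get_variable("weight", [], initializer=tf.zeros_initializer())
  loss = tf.reduce_mean(tf.square(features - weight))
  train_op = tf.train.GradientDescentOptimizer(0.01).minimize(
      loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.tpu.TPUEstimatorSpec(mode=mode, loss=loss,
                                           train_op=train_op)

estimator = tf.estimator.tpu.TPUEstimator(
    model_fn=toy_model_fn,
    use_tpu=False,  # TPUEstimator can also run on CPU/GPU for testing
    train_batch_size=64,
    eval_batch_size=64,
    params={"hidden_size": 512},  # no "batch_size" key, per the rule above
    config=tf.estimator.tpu.RunConfig(model_dir="/tmp/toy_transformer"))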


File renamed without changes.