Add an exponential decay learning rate schedule with warmup.
PiperOrigin-RevId: 344743049
pkulzc authored and TF Object Detection Team committed Nov 30, 2020
1 parent 067d35f commit 3a6079c
Showing 2 changed files with 92 additions and 12 deletions.
research/object_detection/utils/learning_schedules.py (82 changes: 70 additions & 12 deletions)
@@ -23,6 +23,14 @@
import tensorflow.compat.v1 as tf


def _learning_rate_return_value(eager_decay_rate):
  """Helper function to return proper learning rate based on tf version."""
  if tf.executing_eagerly():
    return eager_decay_rate
  else:
    return eager_decay_rate()


def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
@@ -76,10 +84,65 @@ def eager_decay_rate():
        tf.constant(burnin_learning_rate),
        post_burnin_learning_rate), min_learning_rate, name='learning_rate')

-  if tf.executing_eagerly():
-    return eager_decay_rate
-  else:
-    return eager_decay_rate()
+  return _learning_rate_return_value(eager_decay_rate)


def exponential_decay_with_warmup(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  warmup_learning_rate=0.0,
                                  warmup_steps=0,
                                  min_learning_rate=0.0,
                                  staircase=True):
  """Exponential decay schedule with warm up period.
  Args:
    global_step: int tensor representing global step.
    learning_rate_base: base learning rate.
    learning_rate_decay_steps: steps to take between decaying the learning rate.
      Note that this includes the number of burn-in steps.
    learning_rate_decay_factor: multiplicative factor by which to decay learning
      rate.
    warmup_learning_rate: initial learning rate during warmup period.
    warmup_steps: number of steps to use warmup learning rate.
    min_learning_rate: the minimum learning rate.
    staircase: whether use staircase decay.
  Returns:
    If executing eagerly:
      returns a no-arg callable that outputs the (scalar)
      float tensor learning rate given the current value of global_step.
    If in a graph:
      immediately returns a (scalar) float tensor representing learning rate.
  """

  def eager_decay_rate():
    """Callable to compute the learning rate."""
    post_warmup_learning_rate = tf.train.exponential_decay(
        learning_rate_base,
        global_step - warmup_steps,
        learning_rate_decay_steps,
        learning_rate_decay_factor,
        staircase=staircase)
    if callable(post_warmup_learning_rate):
      post_warmup_learning_rate = post_warmup_learning_rate()

    if learning_rate_base < warmup_learning_rate:
      raise ValueError('learning_rate_base must be larger or equal to '
                       'warmup_learning_rate.')
    slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
    warmup_rate = slope * tf.cast(global_step,
                                  tf.float32) + warmup_learning_rate
    learning_rate = tf.where(
        tf.less(tf.cast(global_step, tf.int32), tf.constant(warmup_steps)),
        warmup_rate,
        tf.maximum(post_warmup_learning_rate, min_learning_rate),
        name='learning_rate')

    return learning_rate

  return _learning_rate_return_value(eager_decay_rate)


def cosine_decay_with_warmup(global_step,
@@ -142,10 +205,7 @@ def eager_decay_rate():
    return tf.where(global_step > total_steps, 0.0, learning_rate,
                    name='learning_rate')

-  if tf.executing_eagerly():
-    return eager_decay_rate
-  else:
-    return eager_decay_rate()
+  return _learning_rate_return_value(eager_decay_rate)


def manual_stepping(global_step, boundaries, rates, warmup=False):
@@ -212,7 +272,5 @@ def eager_decay_rate():
        [0] * num_boundaries))
    return tf.reduce_sum(rates * tf.one_hot(rate_index, depth=num_boundaries),
                         name='learning_rate')
-  if tf.executing_eagerly():
-    return eager_decay_rate
-  else:
-    return eager_decay_rate()
+
+  return _learning_rate_return_value(eager_decay_rate)
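
A minimal usage sketch (not part of this commit) of how the new schedule could be wired into a TF1-style training setup. The import path is inferred from the file location, and the optimizer and hyperparameter values are illustrative assumptions. Per the docstring above, in graph mode the call returns a scalar learning-rate tensor, while in eager mode it returns a no-arg callable.

import tensorflow.compat.v1 as tf
from object_detection.utils import learning_schedules

global_step = tf.train.get_or_create_global_step()
# Warm up linearly from 0.001 to 0.04 over the first 2000 steps, then decay
# by a factor of 0.95 every 20000 steps, never dropping below 1e-6.
learning_rate = learning_schedules.exponential_decay_with_warmup(
    global_step,
    learning_rate_base=0.04,
    learning_rate_decay_steps=20000,
    learning_rate_decay_factor=0.95,
    warmup_learning_rate=0.001,
    warmup_steps=2000,
    min_learning_rate=1e-6)
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)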
research/object_detection/utils/learning_schedules_test.py (22 changes: 22 additions & 0 deletions)
@@ -50,6 +50,28 @@ def graph_fn(global_step):
    exp_rates = [.5, .5, 1, 1, 1, .1, .1, .1, .05]
    self.assertAllClose(output_rates, exp_rates, rtol=1e-4)

  def testExponentialDecayWithWarmup(self):
    def graph_fn(global_step):
      learning_rate_base = 1.0
      learning_rate_decay_steps = 3
      learning_rate_decay_factor = .1
      warmup_learning_rate = .5
      warmup_steps = 2
      min_learning_rate = .05
      learning_rate = learning_schedules.exponential_decay_with_warmup(
          global_step, learning_rate_base, learning_rate_decay_steps,
          learning_rate_decay_factor, warmup_learning_rate, warmup_steps,
          min_learning_rate)
      assert learning_rate.op.name.endswith('learning_rate')
      return (learning_rate,)

    output_rates = [
        self.execute(graph_fn, [np.array(i).astype(np.int64)]) for i in range(9)
    ]

    exp_rates = [.5, .75, 1, 1, 1, .1, .1, .1, .05]
    self.assertAllClose(output_rates, exp_rates, rtol=1e-4)

  def testCosineDecayWithWarmup(self):
    def graph_fn(global_step):
      learning_rate_base = 1.0
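For reference, the expected rates in the new test follow directly from the schedule's definition: with warmup_learning_rate = .5, learning_rate_base = 1.0, and warmup_steps = 2, the warmup slope is (1.0 - .5) / 2 = .25, giving rates of .5 and .75 at steps 0 and 1; from step 2 onward the staircase decay multiplies the base by .1 every 3 steps (1, 1, 1, then .1, .1, .1), and at step 8 the decayed value of .01 is clipped up to min_learning_rate = .05.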
