forked from tensorflow/minigo
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlr_finder.py
More file actions
171 lines (138 loc) · 6.56 KB
/
Copy pathlr_finder.py
File metadata and controls
171 lines (138 loc) · 6.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from matplotlib import pyplot as plt
import math
from keras.callbacks import LambdaCallback
import keras.backend as K
import numpy as np
class LRFinder:
    """ https://github.com/surmenok/keras_lr_finder
    Plots the change of the loss function of a Keras model when the learning rate is exponentially increasing.
    See for details:
    https://towardsdatascience.com/estimating-optimal-learning-rate-for-a-deep-neural-network-ce32f2556ce0

    After each training batch the current learning rate and loss are recorded
    in ``self.lrs`` / ``self.losses`` and the learning rate is multiplied by a
    constant factor. Training stops early once the loss becomes NaN or exceeds
    4x the best loss seen. When a run finishes (or fails), the model's weights
    and learning rate are restored to their values from before the run.

    Usage:
        # model is a Keras model
        lr_finder = LRFinder(model)
        # Train a model with batch size 512 for 5 epochs
        # with learning rate growing exponentially from 0.0001 to 1
        lr_finder.find(x_train, y_train, start_lr=0.0001, end_lr=1, batch_size=512, epochs=5)
        # Plot the loss, ignore 20 batches in the beginning and 5 in the end
        lr_finder.plot_loss(n_skip_beginning=20, n_skip_end=5)
    """

    def __init__(self, model):
        self.model = model
        self.losses = []       # loss reported after each batch of the last run
        self.lrs = []          # learning rate in effect for each of those batches
        self.best_loss = 1e9   # lowest loss seen so far in the current run
        # self.lr_mult (per-batch lr growth factor) is set by find()/find_generator().

    def on_batch_end(self, batch, logs):
        """Keras batch-end hook: record lr/loss, then stop training or grow the lr.

        Parameters:
            batch - batch index passed by Keras.
            logs - metrics dict passed by Keras; must contain 'loss'.
        """
        # Log the learning rate
        lr = K.get_value(self.model.optimizer.lr)
        self.lrs.append(lr)
        # Log the loss
        loss = logs['loss']
        self.losses.append(loss)
        # Check whether the loss got too large or NaN. The first batches
        # (index <= 5) are exempt from this check.
        if batch > 5 and (math.isnan(loss) or loss > self.best_loss * 4):
            self.model.stop_training = True
            return
        if loss < self.best_loss:
            self.best_loss = loss
        # Increase the learning rate for the next batch
        lr *= self.lr_mult
        K.set_value(self.model.optimizer.lr, lr)

    def _run(self, start_lr, end_lr, num_batches, fit):
        """Run one lr sweep: reset history, train via ``fit(callback)``, restore the model.

        The lr starts at ``start_lr`` and is multiplied by ``lr_mult`` after every
        batch, so it would reach ``end_lr`` after ``num_batches`` batches.
        Weights and learning rate are restored in ``finally`` so an exception or
        interrupt during training does not leave the model modified.

        Raises:
            ValueError - if ``num_batches`` is not positive.
        """
        if num_batches <= 0:
            raise ValueError('number of batches must be positive, got %r' % (num_batches,))
        # Start every sweep with a clean history; a stale best_loss from a
        # previous run could otherwise trigger an immediate early stop.
        self.losses = []
        self.lrs = []
        self.best_loss = 1e9
        self.lr_mult = (float(end_lr) / float(start_lr)) ** (1.0 / float(num_batches))
        # Keep a copy of the weights in memory so they can be restored afterwards.
        initial_weights = self.model.get_weights()
        # Remember the original learning rate
        original_lr = K.get_value(self.model.optimizer.lr)
        # Set the initial learning rate
        K.set_value(self.model.optimizer.lr, start_lr)
        callback = LambdaCallback(on_batch_end=self.on_batch_end)
        try:
            fit(callback)
        finally:
            # Restore the weights and learning rate to their pre-fit state.
            self.model.set_weights(initial_weights)
            K.set_value(self.model.optimizer.lr, original_lr)

    def find(self, x_train, y_train, start_lr, end_lr, batch_size=64, epochs=1, **kw_fit):
        """Sweep the lr from start_lr towards end_lr while training on in-memory arrays.

        If x_train is a list (multiple model inputs), the sample count is taken from
        its first element, which is assumed to be a single input (not itself a list).
        Extra keyword arguments are forwarded to ``model.fit``.
        """
        num_samples = x_train[0].shape[0] if isinstance(x_train, list) else x_train.shape[0]
        # Keras runs ceil(N / batch_size) batches per epoch (the last one may be partial).
        num_batches = epochs * math.ceil(num_samples / batch_size)
        self._run(
            start_lr, end_lr, num_batches,
            lambda callback: self.model.fit(x_train, y_train,
                                            batch_size=batch_size, epochs=epochs,
                                            callbacks=[callback],
                                            **kw_fit))

    def find_generator(self, generator, start_lr, end_lr, epochs=1, steps_per_epoch=None, **kw_fit):
        """Sweep the lr from start_lr towards end_lr while training on ``generator``.

        steps_per_epoch: #mini-batches to use per epoch. If None, ``len(generator)``
        is used; a ValueError is raised when the generator has no usable length.
        Extra keyword arguments are forwarded to ``model.fit``.
        """
        if steps_per_epoch is None:
            try:
                steps_per_epoch = len(generator)
            except (ValueError, NotImplementedError, TypeError) as e:
                # TypeError: objects without __len__ (plain iterators/generators).
                raise ValueError('`steps_per_epoch=None` is only valid for a'
                                 ' generator based on the '
                                 '`keras.utils.Sequence`'
                                 ' class. Please specify `steps_per_epoch` '
                                 'or use the `keras.utils.Sequence` class.') from e
        self._run(
            start_lr, end_lr, epochs * steps_per_epoch,
            lambda callback: self.model.fit(generator,
                                            epochs=epochs,
                                            steps_per_epoch=steps_per_epoch,
                                            callbacks=[callback],
                                            **kw_fit))

    @staticmethod
    def _trim(values, n_skip_beginning, n_skip_end):
        """Return ``values`` without its first n_skip_beginning and last n_skip_end items.

        Unlike ``values[a:-b]`` this keeps the tail when ``n_skip_end == 0``. It
        returns an empty list (rather than wrapping around) when ``n_skip_end``
        exceeds the length.
        """
        return values[n_skip_beginning:max(0, len(values) - n_skip_end)]

    def plot_loss(self, n_skip_beginning=10, n_skip_end=5, x_scale='log'):
        """
        Plots the loss.
        Parameters:
            n_skip_beginning - number of batches to skip on the left.
            n_skip_end - number of batches to skip on the right.
            x_scale - matplotlib scale for the x axis.
        """
        plt.figure(figsize=(15, 5))
        plt.ylabel("loss")
        plt.xlabel("learning rate (log scale)")
        plt.plot(self._trim(self.lrs, n_skip_beginning, n_skip_end),
                 self._trim(self.losses, n_skip_beginning, n_skip_end))
        plt.xscale(x_scale)
        plt.show()

    def plot_loss_change(self, sma=1, n_skip_beginning=10, n_skip_end=5, y_lim=(-0.01, 0.01)):
        """
        Plots rate of change of the loss function.
        Parameters:
            sma - number of batches for simple moving average to smooth out the curve.
            n_skip_beginning - number of batches to skip on the left.
            n_skip_end - number of batches to skip on the right.
            y_lim - limits for the y axis.
        """
        derivatives = self._trim(self.get_derivatives(sma), n_skip_beginning, n_skip_end)
        lrs = self._trim(self.lrs, n_skip_beginning, n_skip_end)
        plt.ylabel("rate of loss change")
        plt.xlabel("learning rate (log scale)")
        plt.plot(lrs, derivatives)
        plt.xscale('log')
        plt.ylim(y_lim)
        plt.show()

    def get_derivatives(self, sma):
        """Per-batch loss change averaged over a window of ``sma`` batches.

        Entry i is (losses[i] - losses[i - sma]) / sma. The first ``sma``
        entries, which have no full window, are 0. The result has the same
        length as ``self.losses``.
        """
        assert sma >= 1
        n = len(self.losses)
        derivatives = [0] * min(sma, n)
        derivatives.extend((self.losses[i] - self.losses[i - sma]) / sma
                           for i in range(sma, n))
        return derivatives

    def get_best_lr(self, sma, n_skip_beginning=10, n_skip_end=5):
        """Return the recorded lr at which the smoothed loss decreased fastest.

        Raises ValueError (from np.argmin) if no batches remain after skipping.
        """
        derivatives = self._trim(self.get_derivatives(sma), n_skip_beginning, n_skip_end)
        lrs = self._trim(self.lrs, n_skip_beginning, n_skip_end)
        return lrs[int(np.argmin(derivatives))]
def test_finder(batch_size=64, num_examples=16573, shuffle_buffer=4000):
    """Manual smoke run of the LR finder on the project's self-play training data.

    Loads ``model6_epoch2.h5`` from ``myconf.MODELS_DIR`` and the ``train`` split
    from ``{myconf.EXP_HOME}/selfplay/enhance`` via the project-local
    ``run_train`` helpers. It then sweeps the learning rate from 1e-4 to 1 over
    one epoch.

    Parameters:
        batch_size - mini-batch size passed to ``ds_train.batch``.
        num_examples - number of training examples; ``num_examples // batch_size``
            is used as ``steps_per_epoch``. The default is the original hard-coded count.
        shuffle_buffer - buffer size passed to ``ds_train.shuffle``.

    Returns the LRFinder, so its recorded ``lrs``/``losses`` can be plotted
    (e.g. ``plot_loss()``) or queried (``get_best_lr``).
    """
    import myconf
    import run_train
    model = run_train.load_model(f'{myconf.MODELS_DIR}/model6_epoch2.h5')
    data_dir = f'{myconf.EXP_HOME}/selfplay/enhance'
    ds_train = run_train.load_selfplay_data(f'{data_dir}/train', 'full')
    lrf = LRFinder(model)
    lrf.find_generator(ds_train.shuffle(shuffle_buffer).batch(batch_size),
                       start_lr=0.0001, end_lr=1,
                       steps_per_epoch=num_examples // batch_size)
    return lrf