forked from tensorflow/minigo
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmcts_player.h
More file actions
294 lines (233 loc) · 10.6 KB
/
Copy pathmcts_player.h
File metadata and controls
294 lines (233 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef CC_MCTS_PLAYER_H_
#define CC_MCTS_PLAYER_H_
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/time/time.h"
#include "cc/algorithm.h"
#include "cc/constants.h"
#include "cc/game.h"
#include "cc/mcts_tree.h"
#include "cc/model/inference_cache.h"
#include "cc/model/model.h"
#include "cc/position.h"
#include "cc/random.h"
#include "cc/symmetries.h"
namespace minigo {
// Time-management helper used when thinking is time-limited rather than
// readout-limited. Declared here (rather than kept file-local) so that tests
// can call it directly.
//
// The parameters share their names with MctsPlayer::Options fields:
//   move_num:         index of the move being considered.
//   seconds_per_move: see Options::seconds_per_move.
//   time_limit:       see Options::time_limit.
//   decay_factor:     see Options::decay_factor.
// NOTE(review): presumably returns the number of seconds to spend on move
// `move_num`, following the schedule described on Options::time_limit
// (seconds_per_move per move, then exponential decay); the exact formula is
// defined in the .cc — confirm there.
float TimeRecommendation(int move_num,
                         float seconds_per_move,
                         float time_limit,
                         float decay_factor);
class MctsPlayer {
public:
struct Options {
MctsTree::Options tree;
// If inject_noise is true, the amount of noise to mix into the root.
float noise_mix = 0.25;
bool inject_noise = true;
int virtual_losses = 8;
// Random seed & stream used for random permutations.
uint64_t random_seed = Random::kUniqueSeed;
// If true, flip & rotate the board features when performing inference. The
// symmetry chosen is psuedo-randomly chosen in a deterministic way based
// on the position itself and the random_seed.
bool random_symmetry = true;
// Number of readouts to perform (ignored if seconds_per_move is non-zero).
int num_readouts = 0;
// If non-zero, the number of seconds to spend thinking about each move
// instead of using a fixed number of readouts.
float seconds_per_move = 0;
// If non-zero, the maximum amount of time to spend thinking in a game:
// we spend seconds_per_move thinking for each move for as many moves as
// possible before exponentially decaying the amount of time.
float time_limit = 0;
// If time_limit is non-zero, the decay factor used to shorten the amount
// of time spent thinking as the game progresses.
float decay_factor = 0.98;
// "Playout Cap Oscillation" as per the KataGo paper.
// If fastplay_frequency > 0, tree search is modified as follows:
// - Each move is either a "low-readout" fast move, or a full, slow move.
// The percent of fast moves corresponds to "fastplay_frequency"
// - A "fast" move will:
// - Reuse the tree
// - Not mix noise in at root
// - Only perform 'fastplay_readouts' readouts.
// - Not be used as a training target.
// - A "slow" move will:
// - Clear the tree (*not* the cache).
// - Mix in dirichlet noise
// - Perform 'num_readouts' readouts.
// - Be noted in the Game object, to be written as a training example.
float fastplay_frequency = 0;
int fastplay_readouts = 20;
// "Target pruning" adjusts the targets after reading to discard reads
// caused by 'unhelpful' noise & reflect the 'better' understanding of the
// reward distribution. "False" == no pruning will be applied.
bool target_pruning = false;
friend std::ostream& operator<<(std::ostream& ios, const Options& options);
};
// Callback invoked on each batch of leaves expanded during tree search.
using TreeSearchCallback =
std::function<void(const std::vector<const MctsNode*>&)>;
// If position is non-null, the player will be initilized with that board
// state. Otherwise, the player is initialized with an empty board with black
// to play.
MctsPlayer(std::unique_ptr<Model> model,
std::shared_ptr<InferenceCache> inference_cache, Game* game,
const Options& options);
~MctsPlayer();
void InitializeGame(const Position& position);
void NewGame();
Coord SuggestMove(int new_readouts, bool inject_noise = false);
// Plays the move at point c.
// If game is non-null, adds a new move to the game's move history and sets
// the game over state if appropriate.
bool PlayMove(Coord c, bool is_trainable = false);
// Used in eval mode to update this player's tree in response to the
// opponent's move.
// TODO(tommadams): write a new eval binary similar to concurrent_eval so
// we can delete MctsPlayer.
void PlayOpponentsMove(Coord c);
// Moves the root_ node up to its parent, popping the last move off the game
// history but preserving the game tree.
bool UndoMove();
bool ShouldResign() const;
void SetTreeSearchCallback(TreeSearchCallback cb);
// TODO(tommadams): after changing MctsTree::PlayMove() to delete all
// ancestors of the new tree root, we can simply create a new tree instead of
// needing this method.
void ClearSubtrees() { tree_->ClearSubtrees(); }
// Returns a string containing the list of all models used for inference, and
// which moves they were used for.
std::string GetModelsUsedForInference() const;
// Returns the root of the current search tree, i.e. the current board state.
// TODO(tommadams): convert all callers to player->tree().root();
const MctsNode* root() const { return tree_->root(); }
const MctsTree& tree() const { return *tree_; }
const Options& options() const { return options_; }
const std::string& name() const { return model_->name(); }
Model* model() { return model_.get(); }
uint64_t seed() const { return rnd_.seed(); }
void SetOptions(const Options& options) { options_ = options; }
void TreeSearch(int num_leaves, int max_num_reads);
// Protected methods that get exposed for testing.
protected:
MctsTree* mutable_tree() { return tree_.get(); }
private:
// State that tracks which model is used for each inference.
struct InferenceInfo {
InferenceInfo(std::string model, int first_move)
: model(std::move(model)),
first_move(first_move),
last_move(first_move) {}
// Model name returned from RunMany.
std::string model;
// Total number of times a model was used for inference.
size_t total_count = 0;
// The first move a model was used for inference.
int first_move = 0;
// The last move a model was used for inference.
// This needs to be tracked separately from first_move because the common
// case is that the model changes part-way through a tree search.
int last_move = 0;
};
// A position's canonical symmetry is the symmetry that transforms the
// canonical form of a position into its actual form. For example, one way of
// defining a canonical symmetry is that the first move must be played in the
// top-right corner. For the early moves of a game, there will not be a
// canonical symmetry defined; in these cases, GetCanonicalSymmetry returns
// symmetry::kIdentity.
symmetry::Symmetry GetCanonicalSymmetry(const MctsNode* node) const {
return node->canonical_symmetry;
}
// Returns the symmetry that should be applied to this node's position when
// performing inference. The MctsPlayer picks a symmetry using a pseudo-random
// but deterministic function so that the same MctsPlayer instance is
// guaranteed to return the same symmetry for a given position but different
// MctsPlayer instances may return different symmetries for the same position.
symmetry::Symmetry GetInferenceSymmetry(const MctsNode* node) const {
if (options_.random_symmetry) {
uint64_t bits = Random::MixBits(
node->position.stone_hash() * Random::kLargePrime + inference_mix_);
return static_cast<symmetry::Symmetry>(bits % symmetry::kNumSymmetries);
} else {
return symmetry::kIdentity;
}
}
// Inject noise into the root node.
void InjectNoise(float dirichlet_alpha);
// Expand the root node if necessary.
// In order to correctly count the number of reads performed or to inject
// noise, the root node must be expanded. The root will always be expanded
// unless this is the first time SuggestMove has been called for a game, or
// PlayMove was called without a prior call to SuggestMove, or the child nodes
// of the tree have been cleared.
void MaybeExpandRoot();
// Select up to `num_leaves` leaves to perform inference on, storing the
// selected leaves in `tree_search_inferences_`. If the player has an
// inference cache, this can cause more nodes to be added to the tree when
// the selected leaves are already in the cache. To limit this, SelectLeaves
// will stop once the root has `max_num_reads`.
//
// In some positions, the model may favor one move so heavily that it
// overcomes the effects of virtual loss. In this case, SelectLeaves may
// choose the same leaf multiple times.
void SelectLeaves(int num_leaves, int max_num_reads);
// Run inference on the contents of `inferences_` that was previously
// populated by a call to SelectLeaves, and propagate the results back up the
// tree to the root.
void ProcessLeaves();
void UpdateGame(Coord c, bool is_trainable);
std::unique_ptr<Model> model_;
std::unique_ptr<MctsTree> tree_;
Game* game_;
Random rnd_;
Options options_;
// The name of the model used for inferences. In the case of ReloadingModel,
// this is different from the model's name: the model name is the pattern used
// to match each generation of model, while the inference model name is the
// path to the actual serialized model file.
std::string inference_model_;
std::vector<InferenceInfo> inferences_;
std::shared_ptr<InferenceCache> inference_cache_;
struct TreeSearchInference {
TreeSearchInference(InferenceCache::Key cache_key,
symmetry::Symmetry canonical_sym,
symmetry::Symmetry inference_sym, MctsNode* leaf)
: cache_key(cache_key),
canonical_sym(canonical_sym),
inference_sym(inference_sym),
leaf(leaf) {}
InferenceCache::Key cache_key;
symmetry::Symmetry canonical_sym;
symmetry::Symmetry inference_sym;
MctsNode* leaf;
ModelInput input;
ModelOutput output;
};
std::vector<TreeSearchInference> tree_search_inferences_;
std::vector<const ModelInput*> input_ptrs_;
std::vector<ModelOutput*> output_ptrs_;
TreeSearchCallback tree_search_cb_ = nullptr;
// Random number combined with each Position's Zobrist hash in order to
// deterministically choose the symmetry to apply when performing inference.
const int64_t inference_mix_;
};
} // namespace minigo
#endif // CC_MCTS_PLAYER_H_