forked from tensorflow/minigo
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtf_utils.cc
More file actions
124 lines (103 loc) · 4.05 KB
/
Copy pathtf_utils.cc
File metadata and controls
124 lines (103 loc) · 4.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cc/tf_utils.h"
#include <algorithm>
#include <array>
#include <memory>
#include "cc/constants.h"
#include "cc/dual_net/dual_net.h"
#include "cc/file/path.h"
#include "cc/file/utils.h"
#include "cc/mcts_player.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/record_writer.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/file_system.h"
using tensorflow::io::RecordWriter;
using tensorflow::io::RecordWriterOptions;
namespace minigo {
namespace tf_utils {
namespace {
// Returns an array of the same length as `src` in which every element has been
// converted, one by one, to uint8_t using an ordinary numeric conversion.
// No range checking is performed; callers are expected to pass values that
// already fit in a byte.
template <typename T, size_t N>
std::array<uint8_t, N> ConvertToBytes(const std::array<T, N>& src) {
  std::array<uint8_t, N> bytes;
  std::transform(src.cbegin(), src.cend(), bytes.begin(),
                 [](const T& value) { return static_cast<uint8_t>(value); });
  return bytes;
}
// Packs the raw in-memory bytes of a contiguous container (anything exposing
// `data()`, `size()` and `value_type`, e.g. std::array) into a single entry of
// a tensorflow::Feature's bytes_list.
template <typename T>
tensorflow::Feature MakeBytesFeature(const T& data) {
  // Byte length of the whole payload: element size times element count.
  const size_t num_bytes = sizeof(typename T::value_type) * data.size();
  tensorflow::Feature feature;
  auto* bytes_list = feature.mutable_bytes_list();
  bytes_list->add_value(static_cast<const void*>(data.data()), num_bytes);
  return feature;
}
// Builds a tensorflow::Example from one training position: the board input
// features plus the MCTS search policy (pi) and outcome value.
//
// Features written:
//   "x"       - `features` converted element-wise to uint8 and stored as raw
//               bytes in a single bytes_list entry.
//   "pi"      - the raw bytes of the float policy array, as one bytes_list
//               entry.
//   "outcome" - a single-element float_list holding `outcome`.
tensorflow::Example MakeTfExample(const DualNet::BoardFeatures& features,
                                  const std::array<float, kNumMoves>& pi,
                                  float outcome) {
  tensorflow::Example example;
  auto* feature_map = example.mutable_features()->mutable_feature();

  (*feature_map)["x"] = MakeBytesFeature(ConvertToBytes(features));
  (*feature_map)["pi"] = MakeBytesFeature(pi);

  auto& outcome_feature = (*feature_map)["outcome"];
  outcome_feature.mutable_float_list()->add_value(outcome);

  return example;
}
// Writes a list of tensorflow Example protos to a zlib compressed TFRecord
// file.
void WriteTfExamples(const std::string& path,
const std::vector<tensorflow::Example>& examples) {
std::unique_ptr<tensorflow::WritableFile> file;
TF_CHECK_OK(tensorflow::Env::Default()->NewWritableFile(path, &file));
RecordWriterOptions options;
options.compression_type = RecordWriterOptions::ZLIB_COMPRESSION;
RecordWriter writer(file.get(), options);
std::string data;
for (const auto& example : examples) {
example.SerializeToString(&data);
TF_CHECK_OK(writer.WriteRecord(data));
}
TF_CHECK_OK(writer.Close());
TF_CHECK_OK(file->Close());
}
} // namespace
std::vector<tensorflow::Example> MakeExamples(const MctsPlayer& player) {
// Write the TensorFlow examples.
std::vector<tensorflow::Example> examples;
examples.reserve(player.history().size());
DualNet::BoardFeatures features;
std::vector<const Position::Stones*> recent_positions;
for (const auto& h : player.history()) {
h.node->GetMoveHistory(DualNet::kMoveHistory, &recent_positions);
DualNet::SetFeatures(recent_positions, h.node->position.to_play(),
&features);
examples.push_back(MakeTfExample(features, h.search_pi, player.result()));
}
return examples;
}
void WriteGameExamples(const std::string& output_dir,
const std::string& output_name,
const MctsPlayer& player) {
MG_CHECK(file::RecursivelyCreateDir(output_dir));
auto output_path = file::JoinPath(output_dir, output_name + ".tfrecord.zz");
auto examples = MakeExamples(player);
WriteTfExamples(output_path, examples);
}
} // namespace tf_utils
} // namespace minigo