forked from tensorflow/minigo
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
147 lines (117 loc) · 4.53 KB
/
Copy pathutils.py
File metadata and controls
147 lines (117 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Miscellaneous utilities"""
from contextlib import contextmanager
import datetime
import functools
import itertools
import logging
import operator
import os
import re
import sys
import time
from typing import Iterable, Iterator, List, Tuple

import numpy as np
def dbg(*objects, file=sys.stderr, flush=True, **kwargs):
    """Debug print: write *objects* to *file* (stderr by default) and flush.

    Any extra keyword arguments (``sep``, ``end``, ...) are forwarded to print().
    """
    print_options = {'file': file, 'flush': flush}
    print_options.update(kwargs)
    print(*objects, **print_options)
def ensure_dir_exists(directory):
    """Create *directory* (including parents) on the local filesystem if missing.

    Paths starting with 'gs://' are skipped entirely; nothing is created for
    them. A debug line is printed whenever a directory is actually created.
    """
    if directory.startswith('gs://') or os.path.exists(directory):
        return
    dbg("Making dir {}".format(directory))
    os.makedirs(directory, exist_ok=True)
def parse_game_result(result):
    """Map an SGF result string to a value target.

    Returns 1 when the string starts with 'B+' or 'b+', -1 when it starts with
    'W+' or 'w+', and 0 for anything else.
    """
    for pattern, value in ((r'[bB]\+', 1), (r'[wW]\+', -1)):
        if re.match(pattern, result):
            return value
    return 0
def format_game_summary(all_moves: List[str], result: str, first_n: int = 12, last_n: int = 2, sgf_fname=''):
    """Format a one-line, tab-separated summary of a game.

    The line contains:
      - the first ``first_n`` moves;
      - the total move count;
      - the last ``last_n`` moves (none when ``last_n <= 0``);
      - the result;
      - the SGF file's basename without its ``.sgf`` suffix.

    Every occurrence of ``pass`` in the move/result part is rendered as
    ``--``. For short games, some moves may appear in both the opening and the
    ending sections.
    """
    open_moves = all_moves[:first_n]
    # all_moves[-0:] is the whole list, so last_n == 0 must be special-cased
    # to mean "no trailing moves".
    end_moves = all_moves[-last_n:] if last_n > 0 else []
    line = '%s ..%3d .. %s \t%-6s' % (' '.join(open_moves), len(all_moves), ' '.join(end_moves), result)
    line = line.replace('pass', '--')
    short_fname = os.path.basename(sgf_fname).removesuffix('.sgf')
    return f'{line}\t{short_fname}'
def product(iterable):
    """Like sum(), but with multiplication.

    Returns the product of the items in *iterable*. An empty iterable yields 1
    (the multiplicative identity), just as ``sum([])`` yields 0, instead of
    raising TypeError.
    """
    return functools.reduce(operator.mul, iterable, 1)
def _take_n(num_things, iterable):
return list(itertools.islice(iterable, num_things))
def iter_chunks(chunk_size, iterator):
    """Yield successive lists of up to *chunk_size* items from *iterator*.

    The last chunk may be shorter. When the input length is an exact multiple
    of *chunk_size*, no empty trailing chunk is produced.
    """
    source = iter(iterator)
    # An empty list is falsy, so the loop ends once the source is exhausted.
    while chunk := list(itertools.islice(source, chunk_size)):
        yield chunk
def grouper(n, iterable: Iterable):
    """Yield successive lists of up to *n* items from *iterable* (itertools recipe).

    The final group may be shorter than *n*.

    >>> list(grouper(3, iter('ABCDEFG')))
    [['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
    >>> list(grouper(4, range(10)))
    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    """
    # Take a single shared iterator up front. Slicing a re-iterable such as a
    # list directly would restart from its first item on every call and never
    # reach the [] sentinel, looping forever.
    iterator = iter(iterable)
    return iter(lambda: list(itertools.islice(iterator, n)), [])
@contextmanager
def timer(message):
    """Context manager for timing snippets of code.

    When the body finishes normally, prints "<message>: <seconds> seconds" to
    stdout. Nothing is printed if the body raises.
    """
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock, which can jump when adjusted and yield wrong or negative
    # durations.
    tick = time.perf_counter()
    yield
    tock = time.perf_counter()
    print("%s: %.3f seconds" % (message, (tock - tick)))
@contextmanager
def logged_timer(message):
    """Context manager for timing snippets of code. Echos to logging module.

    When the body finishes normally, logs "<message>: <seconds> seconds" at
    INFO level via the root logger. Nothing is logged if the body raises.
    """
    # perf_counter is monotonic. time.time() can jump with wall-clock
    # adjustments and produce wrong or negative durations.
    tick = time.perf_counter()
    yield
    tock = time.perf_counter()
    logging.info("%s: %.3f seconds", message, (tock - tick))
def microseconds_since_midnight():
    """Return the number of microseconds elapsed since local midnight today."""
    now = datetime.datetime.now()
    seconds_today = (now.hour * 60 + now.minute) * 60 + now.second
    return seconds_today * 1000000 + now.microsecond
def soft_pick(pi: np.ndarray, temperature=1.0, softpick_topn_cutoff: int = 0) -> int:
    """Pick a move index by sampling from pi ** temperature.

    Args:
      pi: array of move weights, indexed in flattened order. The caller's
        array is not modified. The last entry is never selected: the CDF is
        normalized by its second-to-last value, which prevents passing via
        softpick.
      temperature: exponent applied to the weights before sampling.
      softpick_topn_cutoff: if > 0, weights strictly below the n-th largest
        are zeroed first (ties with the n-th largest are kept). Values larger
        than pi.size keep every weight.

    Returns:
      The flattened index of the chosen entry, as a Python int.

    Raises:
      AssertionError: if the weights of pi[:-1], after cutoff and temperature,
        sum to 1e-6 or less.
    """
    if softpick_topn_cutoff > 0:
        # Clamp so np.partition does not raise when the cutoff exceeds the size.
        k = min(softpick_topn_cutoff, pi.size)
        nth = np.partition(pi.flatten(), -k)[-k]
        # Work on a copy so the caller's array is left untouched.
        pi = pi.copy()
        pi[pi < nth] = 0
    if temperature != 1.0:
        pi = pi ** temperature
    cdf = pi.cumsum()
    if cdf[-2] > 1e-6:
        cdf /= cdf[-2]  # Prevents passing via softpick.
        selection = np.random.random()
        # side='right' returns the first index whose cdf exceeds selection, so
        # a zero-weight entry (cdf unchanged) cannot be chosen even when
        # selection is exactly 0.0.
        return int(cdf.searchsorted(selection, side='right'))
    # Explicit raise instead of `assert False`, which is stripped under -O.
    raise AssertionError(f'soft_pick {pi} failed: {cdf[-2]}')
def choose_moves_with_probs(moves_with_probs: List[Tuple], temperature=1.0, softpick_topn_cutoff: int = 0, n: int = 1):
    """Sample move(s) from a list of (move, prob, ...) tuples.

    Like soft_pick, but works on a top-moves list. If softpick_topn_cutoff > 0,
    only the first that-many entries are kept. This presumably assumes the
    list is sorted by descending probability; confirm with callers. The kept
    probabilities are raised to ``temperature``, renormalized, and used to draw
    ``n`` indices with np.random.choice.

    Returns a single move when n == 1, otherwise a list of n moves.
    """
    candidates = moves_with_probs[:softpick_topn_cutoff] if softpick_topn_cutoff > 0 else moves_with_probs
    weights = np.array([entry[1] for entry in candidates])
    if temperature != 1.0:
        weights = weights ** temperature
    weights = weights / weights.sum()
    picks = np.random.choice(len(candidates), size=n, p=weights)
    chosen = [candidates[idx][0] for idx in picks]
    if n == 1:
        return chosen[0]
    return chosen