-
Notifications
You must be signed in to change notification settings - Fork 16
/
__main__.py
32 lines (24 loc) · 1 KB
/
__main__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import argparse
import fire
import logging
import sys
from datetime import datetime
from neural_nlp import score as score_function
_logger = logging.getLogger(__name__)
parser = argparse.ArgumentParser()
parser.add_argument('--log_level', type=str, default='INFO')
FLAGS, FIRE_FLAGS = parser.parse_known_args()
logging.basicConfig(stream=sys.stdout, level=logging.getLevelName(FLAGS.log_level))
_logger.info(f"Running with args {FLAGS}, {FIRE_FLAGS}")
for ignore_logger in ['transformers.data.processors', 'botocore', 'boto3', 'urllib3', 's3transfer']:
logging.getLogger(ignore_logger).setLevel(logging.INFO)
def run(benchmark, model, layers=None, subsample=None):
    """Score a model on a benchmark, then print the score and how long scoring took.

    All four arguments are passed through unchanged, as keyword arguments, to
    ``score_function`` (imported from ``neural_nlp`` as ``score``); their accepted
    values are defined there.

    :param benchmark: forwarded as ``benchmark``.
    :param model: forwarded as ``model``.
    :param layers: forwarded as ``layers`` (``None`` by default).
    :param subsample: forwarded as ``subsample`` (``None`` by default).
    :return: ``None``; the score is printed rather than returned.
    """
    import time
    from datetime import timedelta

    # perf_counter is monotonic, so the measured duration is not skewed by
    # wall-clock adjustments (NTP, DST) during a long scoring run.
    start = time.perf_counter()
    score = score_function(model=model, layers=layers, subsample=subsample, benchmark=benchmark)
    elapsed = timedelta(seconds=time.perf_counter() - start)
    print(score)
    # Printed as a timedelta, e.g. "0:01:02.345678".
    print(f"Duration: {elapsed}")
if __name__ == '__main__':
    import warnings

    # Suppress every FutureWarning raised from this point on.
    warnings.simplefilter('ignore', FutureWarning)

    # Dispatch the leftover (non --log_level) command-line arguments with Fire. No
    # component is passed, so Fire looks up the command among this module's names
    # (e.g. ``run``).
    fire.Fire(command=FIRE_FLAGS)