#!/usr/bin/env python3

import argparse
import glob
import gzip
import json
import logging
import re
import statistics

from sklearn import linear_model


def parse_commandline():
    """Parse command line arguments"""
    epilog = """
This script can be used to compute good resource weights based on benchmark
results. The resource weights are used by cvc5 to approximate the running time
by the resources spent, multiplied by their weights.

In the first stage ("parse") this script reads the output files of a benchmark
run as generated on our cluster. The output files are expected to be named
"*.smt2/output.log" and should contain the statistics (obtained with
"--stats"). The result is a gzipped json file that contains all the relevant
information in a compact form.

In the second stage ("analyze") this script loads the gzipped json file and
uses a linear regression model to learn resource weights. The resulting
weights can be used as constants for the resource options ("--*-step=n").
Additionally, this script performs some analysis on the results to identify
outliers where the linear model performs particularly badly, i.e., where the
runtime estimation is way off.
"""
    usage = """
first stage to parse the solver output:
  %(prog)s parse <output directory>

second stage to learn resource weights:
  %(prog)s analyze
"""
    parser = argparse.ArgumentParser(
        description='export and analyze resources from statistics',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        epilog=epilog,
        usage=usage)
    parser.add_argument('command', choices=['parse', 'analyze'],
                        help='task to perform')
    parser.add_argument('basedir', default=None, nargs='?',
                        help='path of benchmark results')
    parser.add_argument('-v', '--verbose',
                        action='store_true', help='be more verbose')
    parser.add_argument('--json', default='data.json.gz',
                        help='path of json file')
    parser.add_argument('--threshold', metavar='SEC', type=int, default=1,
                        help='ignore benchmarks with a runtime below this threshold')
    parser.add_argument('--mult', type=int, default=1000,
                        help='multiply running times by this factor for regression')

    return parser.parse_args()


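# Typical workflow, matching the usage string above (the results directory
# name is hypothetical):
#   ./learn_resource_weights.py parse ./benchmark-results/
#   ./learn_resource_weights.py analyze --threshold 5

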
def load_zipped_json(filename):
    """Load data from a gzipped json file"""
    with gzip.GzipFile(filename, 'r') as fin:
        return json.loads(fin.read().decode('utf-8'))


def save_zipped_json(filename, data):
    """Store data to a gzipped json file"""
    with gzip.GzipFile(filename, 'w') as fout:
        fout.write(json.dumps(data).encode('utf-8'))


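# Round-trip sketch for the two helpers above (filename and data are
# arbitrary):
#   save_zipped_json('demo.json.gz', {'benchmark.smt2': {'time': 1.5}})
#   assert load_zipped_json('demo.json.gz') == {'benchmark.smt2': {'time': 1.5}}

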
def get_sorted_values(data):
    """Transform [['name', value], ...] to [value, ...] sorted by names"""
    return [d[1] for d in sorted(data)]


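# Worked example (hypothetical resource names): sorting by name puts
# 'CnfStep' before 'SatConflict', so
#   get_sorted_values([['SatConflict', 3], ['CnfStep', 7]]) == [7, 3]

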
def parse(args):
    if args.basedir is None:
        raise Exception('Specify basedir for parsing!')
    filename_re = re.compile(r'(.*\.smt2)/output\.log')
    resource_re = re.compile(r'resource::([^,]+), ([0-9]+)')
    result_re = re.compile(r'driver::sat/unsat, ([a-z]+)')
    totaltime_re = re.compile(r'driver::totalTime, ([0-9\.]+)')
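    # The regexes above match statistics lines that look roughly like the
    # following (the resource name and all values are made up):
    #   resource::SatConflictStep, 12345
    #   driver::sat/unsat, unsat
    #   driver::totalTime, 12.34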

    logging.info('Parsing files from {}'.format(args.basedir))
    data = {}
    failed = 0
    for file in glob.iglob('{}/**/output.log'.format(args.basedir),
                           recursive=True):
        with open(file) as infile:
            content = infile.read()
        try:
            filename = filename_re.match(file).group(1)
            r = resource_re.findall(content)
            r = [(name, int(value)) for name, value in r]
            data[filename] = {
                'resources': r,
                'result': result_re.search(content).group(1),
                'time': float(totaltime_re.search(content).group(1)),
            }
        except Exception as e:
            logging.debug('Failed to parse {}: {}'.format(file, e))
            failed += 1

    if failed > 0:
        logging.info('Failed to parse {} out of {} files'.format(
            failed, failed + len(data)))
    logging.info('Dumping data to {}'.format(args.json))
    save_zipped_json(args.json, data)


def analyze(args):
    logging.info('Loading data from {}'.format(args.json))
    data = load_zipped_json(args.json)

    logging.info('Extracting resources')
    resources = set()
    for f in data:
        for r in data[f]['resources']:
            resources.add(r[0])
    resources = sorted(resources)

    vals = {r: [] for r in resources}

    logging.info('Collecting data from {} benchmarks'.format(len(data)))
    x = []
    y = []
    for filename in data:
        d = data[filename]
        if d['time'] < args.threshold:
            continue
        x.append(get_sorted_values(d['resources']))
        y.append(d['time'] * args.mult)

        for r in d['resources']:
            vals[r[0]].append(r[1])

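    # LinearRegression fits one weight per resource so that, per benchmark,
    #   time * args.mult  ~=  sum of weight * count over all resources
    # i.e., the learned coefficients are the resource weights we are after.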
    logging.info('Training regression model')
    clf = linear_model.LinearRegression()
    # fit() returns the fitted estimator itself
    r = clf.fit(x, y)
    coeffs = zip(resources, r.coef_)
    for c in sorted(coeffs, key=lambda c: c[1]):
        minval = min(vals[c[0]])
        maxval = max(vals[c[0]])
        avgval = statistics.mean(vals[c[0]])
        medval = statistics.median(vals[c[0]])
        impact = c[1] * avgval
        logging.info(
            '{:23}-> {:15.10f}\t({} .. {:10}, avg {:9.2f}, med {:8}, '
            'impact {:7.3f})'.format(*c, minval, maxval, avgval, medval,
                                     impact))

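    # The weights printed above can be rounded and used as values for the
    # matching "--*-step=n" resource options; a hypothetical invocation:
    #   cvc5 --rewrite-step=2 --sat-conflict-step=5 benchmark.smt2
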
    logging.info('Comparing regression model with reality')
    outliers = {
        'over-estimated': [],
        'under-estimated': []
    }
    for filename in data:
        d = data[filename]
        actual = d['time']
        if actual < args.threshold:
            continue
        resvals = get_sorted_values(d['resources'])
        # predict() returns an array with one entry per input row
        predict = r.predict([resvals])[0] / args.mult
        outliers['over-estimated'].append(
            [predict / actual, predict, actual, filename])
        outliers['under-estimated'].append(
            [actual / predict, predict, actual, filename])

    for out in outliers:
        logging.info('Showing outliers for {}'.format(out))
        filtered = outliers[out]
        for vals in sorted(filtered)[-5:]:
            logging.info(
                ' -> {:6.2f} ({:6.2f}, actual {:6.2f}): {}'.format(*vals))

    # Write the sorted under-estimation ratios to a file for plotting
    with open('plot.data', 'w') as gnuplot:
        for cur, out in enumerate(sorted(outliers['under-estimated'])):
            gnuplot.write('{}\t{}\n'.format(cur, out[0]))
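    # The resulting 'plot.data' can be inspected with gnuplot, e.g.:
    #   gnuplot -e "plot 'plot.data' with lines; pause -1"

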
if __name__ == "__main__":
    logging.basicConfig(format='[%(levelname)s] %(message)s')
    args = parse_commandline()
    if args.verbose:
        logging.getLogger().setLevel(level=logging.DEBUG)
    else:
        logging.getLogger().setLevel(level=logging.INFO)
    if args.command == 'parse':
        parse(args)
    elif args.command == 'analyze':
        analyze(args)