aco: add framework for unit testing
[mesa.git] / src / amd / compiler / tests / check_output.py
1 #
2 # Copyright (c) 2020 Valve Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 import re
23 import sys
24 import os.path
25 import struct
26 import string
27 import copy
28 from math import floor
29
# ANSI color escapes used when printing results; empty strings when stdout is
# not an interactive terminal so redirected output stays free of escapes.
# sys.stdout.isatty() is used instead of os.isatty(sys.stdout.fileno()) because
# fileno() raises for stream replacements without a real file descriptor
# (e.g. io.StringIO), and sys.stdout may be None in some embedded setups.
if sys.stdout is not None and sys.stdout.isatty():
    set_red = "\033[31m"
    set_green = "\033[1;32m"
    set_normal = "\033[0m"
else:
    set_red = ''
    set_green = ''
    set_normal = ''
38
# Python source run (as the first CodeCheck) in the globals dict shared by all
# checks of a test. It defines helpers for code checks and registers the
# functions usable as `@name(args)` in patterns via `funcs`: each receives the
# raw argument text and returns pattern text matched in its place. The vector
# helpers match a register range such as `s[#x0:#x3]` and queue code checks
# that validate the range's size and alignment and then define `x`, `x0`..`x3`.
initial_code = '''
def insert_code(code):
    insert_queue.append(CodeCheck(code))

def insert_pattern(pattern, search=False, position='(inserted pattern)'):
    insert_queue.append(PatternCheck(pattern, search, position))

def vector_gpr(prefix, name, size, align):
    # Validate first: assigning {name}{size - 1} before the size check would
    # overwrite the value captured from the output and make the check vacuous.
    insert_code(f'success = {name}0 + {size - 1} == {name}{size - 1}')
    insert_code(f'success = {name}0 % {align} == 0')
    insert_code(f'{name} = {name}0')
    for i in range(size):
        insert_code(f'{name}{i} = {name}0 + {i}')
    return f'{prefix}[#{name}0:#{name}{size - 1}]'

def sgpr_vector(name, size, align):
    return vector_gpr('s', name, size, align)

funcs.update({
    's64': lambda name: vector_gpr('s', name, 2, 2),
    's96': lambda name: vector_gpr('s', name, 3, 2),
    's128': lambda name: vector_gpr('s', name, 4, 4),
    's256': lambda name: vector_gpr('s', name, 8, 4),
    's512': lambda name: vector_gpr('s', name, 16, 4),
})
for i in range(2, 14):
    # Bind the current size as a default argument; a plain closure would see
    # only the final loop value and give every vN function 13 registers.
    funcs['v%d' % (i * 32)] = lambda name, size=i: vector_gpr('v', name, size, 1)
'''
67
class Check:
    """Base class for one step of an output check.

    The check's text is stored in `data` with trailing whitespace stripped;
    subclasses implement run().
    """

    def __init__(self, data):
        self.data = data.rstrip()

    def run(self, state):
        """Run the check against `state`; the base version does nothing."""
        return None
74
class CodeCheck(Check):
    """A check made of Python code executed in the shared test globals.

    The code may set `success` to a false value to fail the check, and text it
    appends to `log` is moved into the test result's log.
    """

    def run(self, state):
        """Dedent and execute the code in `state.g`.

        Returns True if the code ran without raising and left
        `state.g['success']` truthy; otherwise appends a diagnostic to
        `state.result.log` and returns False.
        """
        lines = self.data.split('\n')
        non_blank = [l for l in lines if l.strip() != '']
        if not non_blank:
            # Nothing to execute (e.g. a bare ';' line): pass instead of
            # crashing with IndexError.
            return True
        # The first non-blank line sets the indentation removed from every
        # line, so checks may be written indented in the test source.
        first_line = non_blank[0]
        indent_amount = len(first_line) - len(first_line.lstrip())
        indent = first_line[:indent_amount]
        new_lines = []
        for line in lines:
            if line.strip() == '':
                new_lines.append('')
                continue
            if line[:indent_amount] != indent:
                state.result.log += 'unexpected indent in code check:\n'
                state.result.log += self.data + '\n'
                return False
            new_lines.append(line[indent_amount:])
        code = '\n'.join(new_lines)

        # The code comes from the test sources (trusted developer input).
        try:
            exec(code, state.g)
            state.result.log += state.g['log']
            state.g['log'] = ''
        except KeyboardInterrupt:
            # Let Ctrl-C stop the runner instead of being reported as a
            # failing check.
            raise
        except BaseException as e:
            state.result.log += 'code check raised exception:\n'
            state.result.log += code + '\n'
            state.result.log += str(e) + '\n'
            return False
        if not state.g['success']:
            state.result.log += 'code check failed:\n'
            state.result.log += code + '\n'
            return False
        return True
107
class StringStream:
    """A read cursor over a string that tracks a 1-based line/column position."""

    class Pos:
        """A 1-based (line, column) position."""

        def __init__(self):
            self.line = 1
            self.column = 1

    def __init__(self, data, name):
        self.name = name  # label used in error messages
        self.data = data
        self.offset = 0
        self.pos = StringStream.Pos()

    def reset(self):
        """Rewind the cursor to the start of the data."""
        self.offset = 0
        self.pos = StringStream.Pos()

    def peek(self, num=1):
        """Return up to `num` characters at the cursor without consuming them."""
        return self.data[self.offset:self.offset+num]

    def peek_test(self, chars):
        """Return True if a next character exists and is one of `chars`."""
        c = self.peek(1)
        return c != '' and c in chars

    def read(self, num=4294967296):
        """Consume and return up to `num` characters (by default, all the rest)."""
        res = self.peek(num)
        self.offset += len(res)
        for c in res:
            if c == '\n':
                self.pos.line += 1
                self.pos.column = 1
            else:
                self.pos.column += 1
        return res

    def get_line(self, num):
        """Return line `num` (1-based) of the data with trailing whitespace removed."""
        return self.data.split('\n')[num - 1].rstrip()

    def skip_line(self):
        """Consume the rest of the current line, including its newline."""
        while self.peek(1) not in ['\n', '']:
            self.read(1)
        self.read(1)

    def skip_whitespace(self, inc_line):
        """Consume spaces and tabs, and also newlines when `inc_line` is true."""
        chars = [' ', '\t'] + (['\n'] if inc_line else [])
        while self.peek(1) in chars:
            self.read(1)

    def get_number(self):
        """Consume and return a run of decimal digits ('' if there is none)."""
        num = ''
        # peek_test() is required: at end of data peek() returns '', and
        # `'' in string.digits` is True, so a plain membership test never
        # terminated when a number ended the data.
        while self.peek_test(string.digits):
            num += self.read(1)
        return num

    def check_identifier(self):
        """Return True if the next character can start an identifier."""
        return self.peek_test(string.ascii_letters + '_')

    def get_identifier(self):
        """Consume and return an identifier ('' if none starts at the cursor)."""
        res = ''
        if self.check_identifier():
            while self.peek_test(string.ascii_letters + string.digits + '_'):
                res += self.read(1)
        return res
170
def format_error_lines(at, line_num, column_num, ctx, line):
    """Build a two-line diagnostic pointing at a column of `line`.

    The first entry quotes `line` after a location prefix; the second is a row
    of dashes ending in a caret under the given 1-based column.
    """
    header = '%s line %d, column %d of %s: "' % (at, line_num, column_num, ctx)
    caret_row = '-' * (len(header) + column_num - 1) + '^'
    quoted = header + line + '"'
    return [quoted, caret_row]
175
class MatchResult:
    """Outcome of matching a pattern stream against the output.

    Holds the success flag, the failure message, the pattern/output positions
    recorded at the failure and, when the failure happened inside an '@'
    function expansion, that expansion's own MatchResult in `func_res`.
    """

    def __init__(self, pattern):
        self.success = True
        self.func_res = None
        self.pattern = pattern
        self.pattern_pos = StringStream.Pos()
        self.output_pos = StringStream.Pos()
        self.fail_message = ''

    def set_pos(self, pattern, output):
        """Copy the current positions of the pattern and output streams."""
        pairs = ((self.pattern_pos, pattern.pos), (self.output_pos, output.pos))
        for dst, src in pairs:
            dst.line = src.line
            dst.column = src.column

    def fail(self, msg):
        """Mark the match as failed with message `msg`."""
        self.success = False
        self.fail_message = msg

    def format_pattern_pos(self):
        """Describe where in the pattern the failure occurred.

        The top-level pattern position comes first, followed by one entry for
        each nested function expansion in the `func_res` chain.
        """
        pos = self.pattern_pos
        lines = format_error_lines('at', pos.line, pos.column, 'pattern',
                                   self.pattern.get_line(pos.line))
        nested = self.func_res
        while nested:
            pos = nested.pattern_pos
            lines += format_error_lines('in', pos.line, pos.column, nested.pattern.name,
                                        nested.pattern.get_line(pos.line))
            nested = nested.func_res
        return '\n'.join(lines)
206
def do_match(g, pattern, output, skip_lines, in_func=False):
    """Match `pattern` against `output`, binding captured values in `g`.

    Pattern syntax (a preceding backslash makes the next special character,
    including another backslash, literal):
      #name   a decimal number; bound to `name` in `g`, or compared with the
              existing value. `#_` is never bound.
      $name   a run of non-whitespace characters, bound/compared likewise.
      %name   a '%' followed by a number, bound/compared likewise.
      @f(a)   the text returned by g['funcs']['f']('a'), matched recursively.
      ' '     one or more spaces/tabs in the output.
      newline the output must move to a later line (blank lines are skipped).
    Any other character must appear verbatim in the output.

    Args:
        g: globals dict shared with code checks; captures are stored here.
        pattern: StringStream holding the pattern.
        output: StringStream holding the output; advanced past the match.
        skip_lines: if true, a failed attempt (while output remains) restores
            `g`, skips one output line and retries the whole pattern.
        in_func: true when matching an '@' expansion; such matches neither skip
            leading whitespace nor require the output line to end.

    Returns:
        A MatchResult describing success or the first failure.
    """
    assert not in_func or not skip_lines

    if not in_func:
        output.skip_whitespace(False)
        pattern.skip_whitespace(False)

    # Shallow snapshot so a failed search attempt can be undone. The insert
    # queue list is shared with the snapshot, so remember its length too:
    # '@' functions may queue checks during an attempt that later fails.
    old_g = copy.copy(g)
    insert_queue = g.get('insert_queue')
    old_queue_len = len(insert_queue) if insert_queue is not None else 0
    res = MatchResult(pattern)
    escape = False
    while True:
        res.set_pos(pattern, output)

        c = pattern.read(1)
        if c == '':
            break
        elif output.peek() == '':
            res.fail('unexpected end of output')
        elif not escape and c == '\\':
            # Checking `escape` lets '\\' match a literal backslash.
            escape = True
            continue
        elif c == '\n':
            old_line = output.pos.line
            output.skip_whitespace(True)
            if output.pos.line == old_line:
                res.fail('expected newline in output')
        elif not escape and c == '#':
            num = output.get_number()
            if num == '':
                res.fail('expected number in output')
            elif pattern.check_identifier():
                name = pattern.get_identifier()
                if name in g and int(num) != g[name]:
                    res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name]))
                elif name != '_':
                    g[name] = int(num)
        elif not escape and c == '$':
            name = pattern.get_identifier()

            val = ''
            # Stop at end of output as well: peek_test() is False there, so
            # testing only for whitespace never terminated.
            while output.peek() != '' and not output.peek_test(string.whitespace):
                val += output.read(1)

            if name in g and val != g[name]:
                res.fail('unexpected value for \'%s\': \'%s\' (expected \'%s\')' % (name, val, g[name]))
            elif name != '_':
                g[name] = val
        elif not escape and c == '%' and pattern.check_identifier():
            if output.read(1) != '%':
                res.fail('expected \'%\' in output')
            else:
                num = output.get_number()
                if num == '':
                    res.fail('expected number in output')
                else:
                    name = pattern.get_identifier()
                    if name in g and int(num) != g[name]:
                        res.fail('unexpected number for \'%s\': %d (expected %d)' % (name, int(num), g[name]))
                    elif name != '_':
                        g[name] = int(num)
        elif not escape and c == '@' and pattern.check_identifier():
            name = pattern.get_identifier()
            args = ''
            if pattern.peek_test('('):
                pattern.read(1)
                while pattern.peek() not in ['', ')']:
                    args += pattern.read(1)
                # Consume ')' outside the assert so it still happens under
                # `python -O`, which strips assert statements.
                closing = pattern.read(1)
                assert closing == ')'
            func_res = g['funcs'][name](args)
            match_res = do_match(g, StringStream(func_res, 'expansion of "%s(%s)"' % (name, args)), output, False, True)
            if not match_res.success:
                res.func_res = match_res
                res.output_pos = match_res.output_pos
                res.fail(match_res.fail_message)
        elif not escape and c == ' ':
            while pattern.peek_test(' '):
                pattern.read(1)

            read_whitespace = False
            while output.peek_test(' \t'):
                output.read(1)
                read_whitespace = True
            if not read_whitespace:
                res.fail('expected whitespace in output, got %r' % (output.peek(1)))
        else:
            outc = output.peek(1)
            if outc != c:
                res.fail('expected %r in output, got %r' % (c, outc))
            else:
                output.read(1)
        if not res.success:
            if skip_lines and output.peek() != '':
                # Undo this attempt and retry on the next output line.
                g.clear()
                g.update(old_g)
                if insert_queue is not None:
                    del insert_queue[old_queue_len:]
                res.success = True
                res.func_res = None
                output.skip_line()
                pattern.reset()
                output.skip_whitespace(False)
                pattern.skip_whitespace(False)
            else:
                return res

        escape = False

    if not in_func:
        # The pattern must account for the whole output line.
        # NOTE(review): in search mode a line whose prefix matches but which
        # has trailing text fails here rather than continuing the search.
        while output.peek() in [' ', '\t']:
            output.read(1)

        if output.read(1) not in ['', '\n']:
            res.fail('expected end of output')
        return res

    return res
322
class PatternCheck(Check):
    """A check that matches a pattern against the test's output stream.

    With `search` false the pattern must match at the current output position;
    with `search` true, output lines are skipped until it matches. `position`
    describes where the check came from and is used in failure messages.
    """

    def __init__(self, data, search, position):
        super().__init__(data)
        self.search = search
        self.position = position

    def run(self, state):
        """Match the pattern; on failure, log diagnostics and return False."""
        stream = StringStream(self.data.rstrip(), 'pattern')
        res = do_match(state.g, stream, state.g['output'], self.search)
        if res.success:
            return True

        state.result.log += 'pattern at %s failed: %s\n' % (self.position, res.fail_message)
        state.result.log += res.format_pattern_pos() + '\n\n'
        output = state.g['output']
        if self.search:
            # A search may have scanned many lines; show the whole output.
            state.result.log += 'output was:\n'
            state.result.log += output.data.rstrip() + '\n'
        else:
            where = res.output_pos
            marker = format_error_lines('at', where.line, where.column, 'output',
                                        output.get_line(where.line))
            state.result.log += '\n'.join(marker)
        return False
343
class CheckState:
    """State shared by all checks while checking one test's output.

    `g` is the globals dict in which code checks run and into which patterns
    bind captures; it also exposes the output stream, the '@' function table
    and the queue of checks to insert after the current one.
    """

    def __init__(self, result, variant, checks, output):
        self.result = result
        self.variant = variant
        # The setup code runs before every check from the test source.
        # Note: this inserts into the caller's `checks` list.
        checks.insert(0, CodeCheck(initial_code))
        self.checks = checks
        self.insert_queue = []

        self.g = {
            'success': True,
            'funcs': {},
            'insert_queue': self.insert_queue,
            'variant': variant,
            'log': '',
            'output': StringStream(output, 'output'),
            'CodeCheck': CodeCheck,
            'PatternCheck': PatternCheck,
        }
356
class TestResult:
    """Outcome of one test: the actual `result`, the `expected` one and a log."""

    def __init__(self, expected):
        self.expected = expected
        self.result = ''
        self.log = ''
362
def check_output(result, variant, checks, output):
    """Run `checks` against `output`, setting `result.result` to 'passed' or 'failed'.

    Checks run in order and stop at the first failure. Checks queued by a
    check (through the insert queue) run right after it, in queue order.
    """
    state = CheckState(result, variant, checks, output)

    while state.checks:
        current = state.checks.pop(0)
        if not current.run(state):
            result.result = 'failed'
            return
        # Splice the queued checks in front of the remaining ones.
        state.checks[:0] = state.insert_queue
        state.insert_queue.clear()

    result.result = 'passed'
378
def parse_check(variant, line, checks, pos):
    """Parse one check line (the text after '//') and append it to `checks`.

    Prefixes:
      ';'    Python code; consecutive code lines merge into one CodeCheck.
      '!'    a pattern that must match at the current output position.
      '>>'   a pattern searched for in the following output lines.
      '~re'  parse the rest of the line only if regex `re` matches the whole
             variant name; `re` ends at the first ';', '!' or '>>'.
    Lines with any other prefix are ignored. `pos` describes the line's
    location for error messages.
    """
    if line.startswith(';'):
        line = line[1:]
        if checks and isinstance(checks[-1], CodeCheck):
            checks[-1].data += '\n' + line
        else:
            checks.append(CodeCheck(line))
    elif line.startswith('!'):
        checks.append(PatternCheck(line[1:], False, pos))
    elif line.startswith('>>'):
        checks.append(PatternCheck(line[2:], True, pos))
    elif line.startswith('~'):
        # The regex ends at the earliest check prefix on the line.
        found = [i for i in (line.find(c) for c in (';', '!', '>>')) if i != -1]
        if found:
            end = min(found)
            # fullmatch backtracks into later alternatives: re.match with
            # 'gfx10|gfx10_3' stops at 'gfx10' and then fails an end-position
            # comparison against the variant 'gfx10_3'.
            if re.fullmatch(line[1:end], variant):
                parse_check(variant, line[end:], checks, pos)
400
def parse_test_source(test_name, variant, fname):
    """Collect the checks for `test_name` from the test source file `fname`.

    The test body lies between a BEGIN_TEST(name), BEGIN_TEST_TODO(name) or
    BEGIN_TEST_FAIL(name) line and the next END_TEST line. Body lines starting
    with '//' (after stripping) are parsed as checks for `variant`.

    Returns:
        (checks, expected_result), where expected_result is 'passed', 'todo'
        or 'failed' depending on the BEGIN_TEST form used.
    """
    in_test = False
    test = []
    expected_result = 'passed'
    # Use a context manager so the file is closed promptly, and read it as
    # UTF-8 (the encoding used for the output) rather than the locale default.
    with open(fname, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, start=1):
            if line.startswith('BEGIN_TEST(%s)' % test_name):
                in_test = True
            elif line.startswith('BEGIN_TEST_TODO(%s)' % test_name):
                in_test = True
                expected_result = 'todo'
            elif line.startswith('BEGIN_TEST_FAIL(%s)' % test_name):
                in_test = True
                expected_result = 'failed'
            elif line.startswith('END_TEST'):
                in_test = False
            elif in_test:
                test.append((line_num, line.strip()))

    checks = []
    file_name = os.path.split(fname)[1]
    for line_num, text in test:
        if text.startswith('//'):
            parse_check(variant, text[2:], checks, 'line %d of %s' % (line_num, file_name))

    return checks, expected_result
426
def parse_and_check_test(test_name, variant, test_file, output, current_result):
    """Produce the TestResult for one test variant.

    When `current_result` is None, `output` is checked against the test's
    checks. Otherwise `current_result` is a (result, log) pair that is used
    as-is. A test without checks is reported as 'empty', and a failure of a
    test marked TODO is reported as 'todo'.
    """
    checks, expected = parse_test_source(test_name, variant, test_file)

    result = TestResult(expected)
    if not checks:
        result.result = 'empty'
        result.log = 'no checks found'
    elif current_result is not None:
        result.result, result.log = current_result
    else:
        check_output(result, variant, checks, output)
    if result.result == 'failed' and expected == 'todo':
        result.result = 'todo'

    return result
442
def print_results(results, output, expected):
    """Print the tests whose result equals `output`, returning how many were printed.

    Only tests whose result does (expected=True) or does not (expected=False)
    match their expected result are included. Unexpected ones are shown in
    red, each followed by its log.
    """
    selected = {
        name: res for name, res in results.items()
        if res.result == output and (res.result == res.expected) == expected
    }
    if not selected:
        return 0

    print('%s tests (%s):' % (output, 'expected' if expected else 'unexpected'))
    color = set_red if not expected else ''
    for name, res in selected.items():
        print(' %s%s%s' % (color, name, set_normal))
        if res.log.strip() != '':
            for log_line in res.log.rstrip().split('\n'):
                print(' ' + log_line.rstrip())
    print('')

    return len(selected)
460
def get_cstr(fp):
    """Read a NUL-terminated UTF-8 string from the binary stream `fp`.

    Raises:
        EOFError: if the stream ends before the terminating NUL byte.
    """
    res = bytearray()
    while True:
        c = fp.read(1)
        if c == b'\x00':
            return res.decode('utf-8')
        if c == b'':
            # read() keeps returning b'' at end of stream; without this check
            # a truncated stream made the loop spin forever.
            raise EOFError('unexpected end of stream while reading a NUL-terminated string')
        res += c
469
if __name__ == "__main__":
    # Read test packets from stdin until EOF. Each packet holds:
    #  - 4 bytes of packet type (only used here to detect EOF);
    #  - NUL-terminated test name, variant and test source file path;
    #  - one flag byte; if non-zero, NUL-terminated result and log strings;
    #  - a 32-bit size in native byte order ("=L") and that many bytes of
    #    UTF-8 output to check.
    results = {}

    stdin = sys.stdin.buffer
    while True:
        packet_type = stdin.read(4)
        if packet_type == b'':
            break

        test_name = get_cstr(stdin)
        test_variant = get_cstr(stdin)
        if test_variant != '':
            full_name = test_name + '/' + test_variant
        else:
            full_name = test_name

        test_source_file = get_cstr(stdin)
        current_result = None
        if ord(stdin.read(1)):
            current_result = (get_cstr(stdin), get_cstr(stdin))
        code_size = struct.unpack("=L", stdin.read(4))[0]
        code = stdin.read(code_size).decode('utf-8')

        results[full_name] = parse_and_check_test(test_name, test_variant, test_source_file, code, current_result)

    result_types = ['passed', 'failed', 'todo', 'empty']
    num_expected = 0
    num_unexpected = 0
    for t in result_types:
        num_expected += print_results(results, t, True)
    for t in result_types:
        num_unexpected += print_results(results, t, False)
    num_expected_skipped = print_results(results, 'skipped', True)
    num_unexpected_skipped = print_results(results, 'skipped', False)

    num_unskipped = len(results) - num_expected_skipped - num_unexpected_skipped
    # Guard against ZeroDivisionError when no tests ran or all were skipped.
    if num_unskipped:
        percent = floor(num_expected / num_unskipped * 100)
    else:
        percent = 100
    color = set_red if num_unexpected else set_green
    print('%s%d (%.0f%%) of %d unskipped tests had an expected result%s' % (color, num_expected, percent, num_unskipped, set_normal))
    if num_unexpected_skipped:
        print('%s%d tests had been unexpectedly skipped%s' % (set_red, num_unexpected_skipped, set_normal))

    if num_unexpected:
        sys.exit(1)