2 # Copyright (c) 2020 Valve Corporation
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
28 from math
import floor
30 if os
.isatty(sys
.stdout
.fileno()):
32 set_green
= "\033[1;32m"
33 set_normal
= "\033[0m"
def insert_code(code):
    """Queue ``code`` as a new CodeCheck.

    check_output moves queued checks to the front of the pending list, so the
    new check runs immediately after the check that queued it.
    ``insert_queue`` and ``CodeCheck`` are looked up as globals (the check
    namespace defines both) — presumably this runs inside that namespace.
    """
    new_check = CodeCheck(code)
    insert_queue.append(new_check)
def insert_pattern(pattern, search=False, position='an insert_pattern() call'):
    """Queue a PatternCheck for ``pattern``; check_output runs it right after the current check.

    Args:
        pattern: pattern text to match against the test output.
        search: forwarded to do_match as ``skip_lines``. False behaves like a
            ``!`` check line; True behaves like a ``>>`` line, which allows
            skipping output lines.
        position: label printed in the failure log ("pattern at <position> failed").
    """
    # PatternCheck.__init__ takes (data, search, position); calling it with
    # only the pattern raised TypeError on every use.
    insert_queue.append(PatternCheck(pattern, search, position))
46 def vector_gpr(prefix, name, size, align):
47 insert_code(f'{name} = {name}0')
49 insert_code(f'{name}{i} = {name}0 + {i}')
50 insert_code(f'success = {name}0 + {size - 1} == {name}{size - 1}')
51 insert_code(f'success = {name}0 % {align} == 0')
52 return f'{prefix}[#{name}0:#{name}{size - 1}]'
def sgpr_vector(name, size, align):
    """Expand a scalar-register ('s') vector pattern.

    This is shorthand for vector_gpr with the 's' prefix. It returns the
    pattern text and queues the code checks that vector_gpr generates.
    """
    return vector_gpr(prefix='s', name=name, size=size, align=align)
58 's64': lambda name: vector_gpr('s', name, 2, 2),
59 's96': lambda name: vector_gpr('s', name, 3, 2),
60 's128': lambda name: vector_gpr('s', name, 4, 4),
61 's256': lambda name: vector_gpr('s', name, 8, 4),
62 's512': lambda name: vector_gpr('s', name, 16, 4),
# Register one 'v' vector helper per size: funcs['v64'] .. funcs['v416'].
# Each covers i = 2..13 registers and is named by its total width i * 32.
# ``size=i`` freezes the current loop value as a default argument. A plain
# ``lambda name: vector_gpr('v', name, i, 1)`` reads ``i`` only when called,
# and by then ``i`` holds 13, so every helper would expand to 13 registers.
for i in range(2, 14):
    funcs['v%d' % (i * 32)] = lambda name, size=i: vector_gpr('v', name, size, 1)
def __init__(self, data):
    # Keep the check text with trailing whitespace and newlines removed.
    cleaned = data.rstrip()
    self.data = cleaned
75 class CodeCheck(Check
):
78 first_line
= [l
for l
in self
.data
.split('\n') if l
.strip() != ''][0]
79 indent_amount
= len(first_line
) - len(first_line
.lstrip())
80 indent
= first_line
[:indent_amount
]
82 for line
in self
.data
.split('\n'):
83 if line
.strip() == '':
86 if line
[:indent_amount
] != indent
:
87 state
.result
.log
+= 'unexpected indent in code check:\n'
88 state
.result
.log
+= self
.data
+ '\n'
90 new_lines
.append(line
[indent_amount
:])
91 code
= '\n'.join(new_lines
)
95 state
.result
.log
+= state
.g
['log']
97 except BaseException
as e
:
98 state
.result
.log
+= 'code check raised exception:\n'
99 state
.result
.log
+= code
+ '\n'
100 state
.result
.log
+= str(e
)
102 if not state
.g
['success']:
103 state
.result
.log
+= 'code check failed:\n'
104 state
.result
.log
+= code
+ '\n'
114 def __init__(self
, data
, name
):
118 self
.pos
= StringStream
.Pos()
122 self
.pos
= StringStream
.Pos()
def peek(self, num=1):
    """Return up to ``num`` characters starting at the current offset, without consuming them.

    Returns a shorter (possibly empty) string near the end of the data.
    """
    start = self.offset
    end = start + num
    return self.data[start:end]
127 def peek_test(self
, chars
):
129 return c
!= '' and c
in chars
131 def read(self
, num
=4294967296):
133 self
.offset
+= len(res
)
def get_line(self, num):
    """Return line ``num`` (1-based) of the stream's data with trailing whitespace removed."""
    all_lines = self.data.split('\n')
    wanted = all_lines[num - 1]
    return wanted.rstrip()
146 while self
.peek(1) not in ['\n', '']:
150 def skip_whitespace(self
, inc_line
):
151 chars
= [' ', '\t'] + (['\n'] if inc_line
else [])
152 while self
.peek(1) in chars
:
155 def get_number(self
):
157 while self
.peek() in string
.digits
:
def check_identifier(self):
    """Return whether the next character can start an identifier (ASCII letter or '_')."""
    leading_chars = string.ascii_letters + '_'
    return self.peek_test(leading_chars)
164 def get_identifier(self
):
166 if self
.check_identifier():
167 while self
.peek_test(string
.ascii_letters
+ string
.digits
+ '_'):
def format_error_lines(at, line_num, column_num, ctx, line):
    """Build a two-line error excerpt.

    The first line is ``'<at> line L, column C of <ctx>: "<line>"'``. The
    second is a dashed ruler ending in '^' beneath 1-based column
    ``column_num`` of the quoted text.
    """
    prefix = '%s line %d, column %d of %s: "' % (at, line_num, column_num, ctx)
    quoted = prefix + line + '"'
    ruler = '-' * (len(prefix) + column_num - 1) + '^'
    return [quoted, ruler]
177 def __init__(self
, pattern
):
180 self
.pattern
= pattern
181 self
.pattern_pos
= StringStream
.Pos()
182 self
.output_pos
= StringStream
.Pos()
183 self
.fail_message
= ''
def set_pos(self, pattern, output):
    """Copy the current line/column of the pattern and output streams into this result."""
    pairs = ((self.pattern_pos, pattern.pos), (self.output_pos, output.pos))
    for target, source in pairs:
        target.line = source.line
        target.column = source.column
193 self
.fail_message
= msg
195 def format_pattern_pos(self
):
196 pat_pos
= self
.pattern_pos
197 pat_line
= self
.pattern
.get_line(pat_pos
.line
)
198 res
= format_error_lines('at', pat_pos
.line
, pat_pos
.column
, 'pattern', pat_line
)
199 func_res
= self
.func_res
201 pat_pos
= func_res
.pattern_pos
202 pat_line
= func_res
.pattern
.get_line(pat_pos
.line
)
203 res
+= format_error_lines('in', pat_pos
.line
, pat_pos
.column
, func_res
.pattern
.name
, pat_line
)
204 func_res
= func_res
.func_res
205 return '\n'.join(res
)
207 def do_match(g
, pattern
, output
, skip_lines
, in_func
=False):
208 assert(not in_func
or not skip_lines
)
211 output
.skip_whitespace(False)
212 pattern
.skip_whitespace(False)
215 old_g_keys
= list(g
.keys())
216 res
= MatchResult(pattern
)
219 res
.set_pos(pattern
, output
)
225 elif output
.peek() == '':
226 res
.fail('unexpected end of output')
231 old_line
= output
.pos
.line
232 output
.skip_whitespace(True)
233 if output
.pos
.line
== old_line
:
234 res
.fail('expected newline in output')
235 elif not escape
and c
== '#':
236 num
= output
.get_number()
238 res
.fail('expected number in output')
239 elif pattern
.check_identifier():
240 name
= pattern
.get_identifier()
241 if name
in g
and int(num
) != g
[name
]:
242 res
.fail('unexpected number for \'%s\': %d (expected %d)' % (name
, int(num
), g
[name
]))
245 elif not escape
and c
== '$':
246 name
= pattern
.get_identifier()
249 while not output
.peek_test(string
.whitespace
):
250 val
+= output
.read(1)
252 if name
in g
and val
!= g
[name
]:
253 res
.fail('unexpected value for \'%s\': \'%s\' (expected \'%s\')' % (name
, val
, g
[name
]))
256 elif not escape
and c
== '%' and pattern
.check_identifier():
257 if output
.read(1) != '%':
258 res
.fail('expected \'%\' in output')
260 num
= output
.get_number()
262 res
.fail('expected number in output')
264 name
= pattern
.get_identifier()
265 if name
in g
and int(num
) != g
[name
]:
266 res
.fail('unexpected number for \'%s\': %d (expected %d)' % (name
, int(num
), g
[name
]))
269 elif not escape
and c
== '@' and pattern
.check_identifier():
270 name
= pattern
.get_identifier()
272 if pattern
.peek_test('('):
274 while pattern
.peek() not in ['', ')']:
275 args
+= pattern
.read(1)
276 assert(pattern
.read(1) == ')')
277 func_res
= g
['funcs'][name
](args
)
278 match_res
= do_match(g
, StringStream(func_res
, 'expansion of "%s(%s)"' % (name
, args
)), output
, False, True)
279 if not match_res
.success
:
280 res
.func_res
= match_res
281 res
.output_pos
= match_res
.output_pos
282 res
.fail(match_res
.fail_message
)
283 elif not escape
and c
== ' ':
284 while pattern
.peek_test(' '):
287 read_whitespace
= False
288 while output
.peek_test(' \t'):
290 read_whitespace
= True
291 if not read_whitespace
:
292 res
.fail('expected whitespace in output, got %r' % (output
.peek(1)))
294 outc
= output
.peek(1)
296 res
.fail('expected %r in output, got %r' % (c
, outc
))
300 if skip_lines
and output
.peek() != '':
306 output
.skip_whitespace(False)
307 pattern
.skip_whitespace(False)
314 while output
.peek() in [' ', '\t']:
317 if output
.read(1) not in ['', '\n']:
318 res
.fail('expected end of output')
323 class PatternCheck(Check
):
324 def __init__(self
, data
, search
, position
):
325 Check
.__init
__(self
, data
)
327 self
.position
= position
329 def run(self
, state
):
330 pattern_stream
= StringStream(self
.data
.rstrip(), 'pattern')
331 res
= do_match(state
.g
, pattern_stream
, state
.g
['output'], self
.search
)
333 state
.result
.log
+= 'pattern at %s failed: %s\n' % (self
.position
, res
.fail_message
)
334 state
.result
.log
+= res
.format_pattern_pos() + '\n\n'
336 out_line
= state
.g
['output'].get_line(res
.output_pos
.line
)
337 state
.result
.log
+= '\n'.join(format_error_lines('at', res
.output_pos
.line
, res
.output_pos
.column
, 'output', out_line
))
339 state
.result
.log
+= 'output was:\n'
340 state
.result
.log
+= state
.g
['output'].data
.rstrip() + '\n'
345 def __init__(self
, result
, variant
, checks
, output
):
347 self
.variant
= variant
350 self
.checks
.insert(0, CodeCheck(initial_code
))
351 self
.insert_queue
= []
353 self
.g
= {'success': True, 'funcs': {}, 'insert_queue': self
.insert_queue
,
354 'variant': variant
, 'log': '', 'output': StringStream(output
, 'output'),
355 'CodeCheck': CodeCheck
, 'PatternCheck': PatternCheck
}
358 def __init__(self
, expected
):
360 self
.expected
= expected
363 def check_output(result
, variant
, checks
, output
):
364 state
= CheckState(result
, variant
, checks
, output
)
366 while len(state
.checks
):
367 check
= state
.checks
.pop(0)
368 if not check
.run(state
):
369 result
.result
= 'failed'
372 for check
in state
.insert_queue
[::-1]:
373 state
.checks
.insert(0, check
)
374 state
.insert_queue
.clear()
376 result
.result
= 'passed'
379 def parse_check(variant
, line
, checks
, pos
):
380 if line
.startswith(';'):
382 if len(checks
) and isinstance(checks
[-1], CodeCheck
):
383 checks
[-1].data
+= '\n' + line
385 checks
.append(CodeCheck(line
))
386 elif line
.startswith('!'):
387 checks
.append(PatternCheck(line
[1:], False, pos
))
388 elif line
.startswith('>>'):
389 checks
.append(PatternCheck(line
[2:], True, pos
))
390 elif line
.startswith('~'):
393 for c
in [';', '!', '>>']:
394 if line
.find(c
) != -1 and line
.find(c
) < end
:
397 match
= re
.match(line
[1:end
], variant
)
398 if match
and match
.end() == len(variant
):
399 parse_check(variant
, line
[end
:], checks
, pos
)
401 def parse_test_source(test_name
, variant
, fname
):
404 expected_result
= 'passed'
406 for line
in open(fname
, 'r').readlines():
407 if line
.startswith('BEGIN_TEST(%s)' % test_name
):
409 elif line
.startswith('BEGIN_TEST_TODO(%s)' % test_name
):
411 expected_result
= 'todo'
412 elif line
.startswith('BEGIN_TEST_FAIL(%s)' % test_name
):
414 expected_result
= 'failed'
415 elif line
.startswith('END_TEST'):
418 test
.append((line_num
, line
.strip()))
422 for line_num
, check
in [(line_num
, l
[2:]) for line_num
, l
in test
if l
.startswith('//')]:
423 parse_check(variant
, check
, checks
, 'line %d of %s' % (line_num
, os
.path
.split(fname
)[1]))
425 return checks
, expected_result
427 def parse_and_check_test(test_name
, variant
, test_file
, output
, current_result
):
428 checks
, expected
= parse_test_source(test_name
, variant
, test_file
)
430 result
= TestResult(expected
)
432 result
.result
= 'empty'
433 result
.log
= 'no checks found'
434 elif current_result
!= None:
435 result
.result
, result
.log
= current_result
437 check_output(result
, variant
, checks
, output
)
438 if result
.result
== 'failed' and expected
== 'todo':
439 result
.result
= 'todo'
443 def print_results(results
, output
, expected
):
444 results
= {name
: result
for name
, result
in results
.items() if result
.result
== output
}
445 results
= {name
: result
for name
, result
in results
.items() if (result
.result
== result
.expected
) == expected
}
450 print('%s tests (%s):' % (output
, 'expected' if expected
else 'unexpected'))
451 for test
, result
in results
.items():
452 color
= '' if expected
else set_red
453 print(' %s%s%s' % (color
, test
, set_normal
))
454 if result
.log
.strip() != '':
455 for line
in result
.log
.rstrip().split('\n'):
456 print(' ' + line
.rstrip())
466 return res
.decode('utf-8')
470 if __name__
== "__main__":
473 stdin
= sys
.stdin
.buffer
475 packet_type
= stdin
.read(4)
476 if packet_type
== b
'':
479 test_name
= get_cstr(stdin
)
480 test_variant
= get_cstr(stdin
)
481 if test_variant
!= '':
482 full_name
= test_name
+ '/' + test_variant
484 full_name
= test_name
486 test_source_file
= get_cstr(stdin
)
487 current_result
= None
488 if ord(stdin
.read(1)):
489 current_result
= (get_cstr(stdin
), get_cstr(stdin
))
490 code_size
= struct
.unpack("=L", stdin
.read(4))[0]
491 code
= stdin
.read(code_size
).decode('utf-8')
493 results
[full_name
] = parse_and_check_test(test_name
, test_variant
, test_source_file
, code
, current_result
)
495 result_types
= ['passed', 'failed', 'todo', 'empty']
498 for t
in result_types
:
499 num_expected
+= print_results(results
, t
, True)
500 for t
in result_types
:
501 num_unexpected
+= print_results(results
, t
, False)
502 num_expected_skipped
= print_results(results
, 'skipped', True)
503 num_unexpected_skipped
= print_results(results
, 'skipped', False)
505 num_unskipped
= len(results
) - num_expected_skipped
- num_unexpected_skipped
506 color
= set_red
if num_unexpected
else set_green
507 print('%s%d (%.0f%%) of %d unskipped tests had an expected result%s' % (color
, num_expected
, floor(num_expected
/ num_unskipped
* 100), num_unskipped
, set_normal
))
508 if num_unexpected_skipped
:
509 print('%s%d tests had been unexpectedly skipped%s' % (set_red
, num_unexpected_skipped
, set_normal
))