Nicely display compile failures.
[riscv-tests.git] / debug / testlib.py
1 import os.path
2 import re
3 import shlex
4 import subprocess
5 import sys
6 import time
7 import traceback
8
9 import pexpect
10
11 # Note that gdb comes with its own testsuite. I was unable to figure out how to
12 # run that testsuite against the spike simulator.
13
def find_file(path):
    """Resolve *path* against the working directory, then this module's directory.

    Returns the first joined path that exists on disk, or None when neither
    location contains it.
    """
    search_dirs = (os.getcwd(), os.path.dirname(__file__))
    candidates = (os.path.join(base, path) for base in search_dirs)
    return next((full for full in candidates if os.path.exists(full)), None)
20
def compile(args, xlen=32): # pylint: disable=redefined-builtin
    """Compile *args* with $RISCV/bin/riscv<xlen>-unknown-elf-gcc.

    args: gcc arguments. Each argument that names an existing file (as
        located by find_file(), i.e. relative to the cwd or to this module's
        directory) is replaced by that file's full path; all other arguments
        are passed through unchanged. "-g" is always added.
    xlen: selects which riscv<xlen> toolchain binary is used.

    On a non-zero exit status, prints the command line and the compiler's
    stdout/stderr between banners, then raises Exception. Returns None.
    """
    cc = os.path.expandvars("$RISCV/bin/riscv%d-unknown-elf-gcc" % xlen)
    cmd = [cc, "-g"]
    for arg in args:
        found = find_file(arg)
        cmd.append(found if found else arg)
    # universal_newlines makes communicate() return text on both Python 2
    # and 3, so the output can be written to sys.stdout as-is.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               universal_newlines=True)
    stdout, stderr = process.communicate()
    if process.returncode:
        print("")
        header("Compile failed")
        print("+ " + " ".join(cmd))
        for output in (stdout, stderr):
            sys.stdout.write(output)
            # Keep the closing banner on its own line even when the
            # compiler's output does not end with a newline.
            if output and not output.endswith("\n"):
                sys.stdout.write("\n")
        header("")
        raise Exception("Compile failed!")
41
def unused_port():
    """Return a TCP port number that was free at the time of the call.

    Binds a socket to port 0 so the OS assigns a free port, reads that port
    back, and closes the socket. Because the socket is closed before
    returning, another process may claim the port before the caller uses it.
    """
    # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python/2838309#2838309
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        # Close even if bind()/getsockname() raised, so the fd is not leaked.
        s.close()
50
class Spike(object):
    """A spike simulator running as a child process.

    The command line and all of spike's stdout/stderr are written to
    ``logname``.
    """
    logname = "spike.log"

    def __init__(self, cmd, binary=None, halted=False, with_gdb=True,
                 timeout=None, xlen=64):
        """Launch spike.

        cmd: spike command line as a single string (split with shlex);
            ["spike"] is used when it is empty or None.
        binary: program passed to pk, if any.
        halted: pass -H.
        with_gdb: choose an unused port, pass it with --gdb-port and store
            it in self.port.
        timeout: if set, run spike under `timeout <timeout>`.
        xlen: 32 adds "--isa RV32".

        Sets self.process to the Popen object (and self.port when with_gdb
        is true); nothing is returned.
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["spike"]
        if xlen == 32:
            cmd += ["--isa", "RV32"]

        if timeout:
            cmd = ["timeout", str(timeout)] + cmd

        if halted:
            cmd.append('-H')
        if with_gdb:
            self.port = unused_port()
            cmd += ['--gdb-port', str(self.port)]
        cmd.append("-m32")
        cmd.append('pk')
        if binary:
            cmd.append(binary)
        with open(self.logname, "w") as logfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            # Flush so the command line precedes anything spike writes.
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)
        # The child holds its own copy of the log descriptor, so closing the
        # parent's file object here does not affect spike's output.

    def __del__(self):
        process = getattr(self, "process", None)
        if process is None:
            # __init__ failed before spike was started.
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            # Best effort: ignore OS errors while killing/reaping.
            pass

    def wait(self, *args, **kwargs):
        """Wait for spike to exit; arguments are forwarded to Popen.wait()."""
        return self.process.wait(*args, **kwargs)
92
class VcsSim(object):
    """A VCS simulation binary (simv) running as a child process.

    simv's command line and output are written to simv.log.
    """
    def __init__(self, simv=None, debug=False):
        """Start simv and wait until it reports the port it listens on.

        simv: command line as a single string (split with shlex); ["simv"]
            is used when it is empty or None.
        debug: run "<simv>-debug" instead and pass
            +vcdplusfile=output/gdbserver.vpd.

        Sets self.process and self.port, and exports the port as
        $JTAG_VPI_PORT. Raises Exception if simv exits before printing
        "Listening on port N".
        """
        if simv:
            cmd = shlex.split(simv)
        else:
            cmd = ["simv"]
        cmd += ["+jtag_vpi_enable"]
        if debug:
            cmd[0] = cmd[0] + "-debug"
            cmd += ["+vcdplusfile=output/gdbserver.vpd"]
        with open("simv.log", "w") as logfile, \
                open("simv.log", "r") as listenfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            # Skip our own header so only simv's output is scanned.
            listenfile.seek(0, 2)
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)
            self.port = self._wait_for_port(listenfile)
        os.environ['JTAG_VPI_PORT'] = str(self.port)

    def _wait_for_port(self, listenfile):
        """Poll simv's log until it prints "Listening on port N"; return N."""
        while True:
            # Sample the exit status before reading, so everything simv wrote
            # before exiting is still scanned before we give up.
            exited = self.process.poll() is not None
            line = listenfile.readline()
            if line:
                match = re.match(r"^Listening on port (\d+)$", line)
                if match:
                    return int(match.group(1))
            elif exited:
                raise Exception(
                    "simv exited before reporting its JTAG VPI port")
            else:
                time.sleep(1)

    def __del__(self):
        process = getattr(self, "process", None)
        if process is None:
            # __init__ failed before simv was started.
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            # Best effort: ignore OS errors while killing/reaping.
            pass
127
class Openocd(object):
    """An OpenOCD process whose command line and output go to ``logname``."""
    logname = "openocd.log"

    def __init__(self, cmd=None, config=None, debug=False):
        """Start OpenOCD and wait until it has examined the RISC-V core.

        cmd: OpenOCD command line as a single string (split with shlex);
            ["openocd"] is used when it is empty or None.
        config: config script passed with -f. It is resolved with
            find_file(); when not found locally it is passed through
            unchanged so OpenOCD can search its own script path.
        debug: pass -d.

        Sets self.process and self.port (the gdb server's port). Raises
        Exception if OpenOCD exits before logging "Examined RISCV core", or
        if the gdb server port cannot be determined.
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["openocd"]
        if config:
            cmd += ["-f", find_file(config) or config]
        if debug:
            cmd.append("-d")

        # This command needs to come before any config scripts on the command
        # line, since they are executed in order.
        # Tell OpenOCD to bind to an unused port.
        cmd[1:1] = ["--command", "gdb_port %d" % 0]

        with open(Openocd.logname, "w") as logfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)
        # The child has its own copy of the log descriptor.

        # Wait for OpenOCD to have made it through riscv_examine(). When using
        # OpenOCD to communicate with a simulator this may take a long time,
        # and gdb will time out when trying to connect if we attempt too early.
        start = time.time()
        messaged = False
        while True:
            # Check for exit before reading the log, so a process that logged
            # the message and then exited is not reported as a failure.
            exited = self.process.poll() is not None
            with open(Openocd.logname) as log:
                if "Examined RISCV core" in log.read():
                    break
            if exited:
                raise Exception(
                    "OpenOCD exited before completing riscv_examine()")
            if not messaged and time.time() - start > 1:
                messaged = True
                print("Waiting for OpenOCD to examine RISCV core...")
            # Don't spin the CPU re-reading the log.
            time.sleep(0.1)

        self.port = self._get_gdb_server_port()

    def _get_gdb_server_port(self):
        """Get port that OpenOCD's gdb server is listening on.

        Runs lsof to list the TCP sockets the OpenOCD process listens on,
        ignores ports 6666 and 4444 (OpenOCD's default Tcl and telnet
        ports), and retries every 0.1 s for up to MAX_ATTEMPTS tries.

        Raises Exception if several candidate ports are found, or if none
        appears in time.
        """
        MAX_ATTEMPTS = 50
        PORT_REGEX = re.compile(r'(?P<port>\d+) \(LISTEN\)')
        for _ in range(MAX_ATTEMPTS):
            with open(os.devnull, 'w') as devnull:
                try:
                    output = subprocess.check_output([
                        'lsof',
                        '-a',  # Take the AND of the following selectors
                        '-p{}'.format(self.process.pid),  # Filter on PID
                        '-iTCP',  # Filter only TCP sockets
                    ], stderr=devnull)
                except subprocess.CalledProcessError:
                    output = ""
            if not isinstance(output, str):
                # check_output() returns bytes on Python 3.
                output = output.decode("utf-8", "replace")
            matches = list(PORT_REGEX.finditer(output))
            matches = [m for m in matches
                       if m.group('port') not in ('6666', '4444')]
            if len(matches) > 1:
                print(output)
                raise Exception(
                    "OpenOCD listening on multiple ports. Cannot uniquely "
                    "identify gdb server port.")
            elif matches:
                [match] = matches
                return int(match.group('port'))
            time.sleep(0.1)
        raise Exception("Timed out waiting for gdb server to obtain port.")

    def __del__(self):
        process = getattr(self, "process", None)
        if process is None:
            # __init__ failed before OpenOCD was started.
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            # Best effort: ignore OS errors while killing/reaping.
            pass
205
class OpenocdCli(object):
    """Drive OpenOCD's telnet command interface through pexpect.

    The session transcript is tee'd to openocd-cli.log.
    """
    def __init__(self, port=4444):
        self.child = pexpect.spawn(
            "sh -c 'telnet localhost %d | tee openocd-cli.log'" % port)
        self.child.expect("> ")

    def command(self, cmd):
        """Send *cmd* and return its output up to the next "> " prompt.

        Surrounding tabs, CR/LF, spaces and NUL bytes are stripped.
        """
        self.child.sendline(cmd)
        # Match the echoed command literally: cmd may contain regex
        # metacharacters (Tcl brackets, '$', '.', '*', ...).
        self.child.expect_exact(cmd)
        self.child.expect("\n")
        self.child.expect("> ")
        return self.child.before.strip("\t\r\n \0")

    def reg(self, reg=''):
        """Read registers with OpenOCD's `reg` command.

        Parses lines of the form "name (/width): 0x<hex>". Returns a dict of
        register name -> int when *reg* is empty, otherwise the int value of
        *reg* (KeyError if it does not appear in the output).
        """
        output = self.command("reg %s" % reg)
        # Accept both hex digit cases; an uppercase-only class would stop at
        # the first lowercase digit and silently truncate the value.
        matches = re.findall(r"(\w+) \(/\d+\): (0x[0-9A-Fa-f]+)", output)
        values = {r: int(v, 0) for r, v in matches}
        if reg:
            return values[reg]
        return values

    def load_image(self, image):
        """Run `load_image <image>`.

        Raises TestNotApplicable if OpenOCD reports that only 32-bit ELF
        files are supported.
        """
        output = self.command("load_image %s" % image)
        if 'invalid ELF file, only 32bits files are supported' in output:
            raise TestNotApplicable(output)
231
class CannotAccess(Exception):
    """Raised when gdb reports that memory cannot be accessed.

    address: the inaccessible address, as an int.
    """
    def __init__(self, address):
        super(CannotAccess, self).__init__()
        self.address = address
236
class Gdb(object):
    """A gdb process driven through pexpect; the session is logged to gdb.log.

    Most helpers send one command, wait for the next "(gdb)" prompt and
    parse the text gdb printed in between.
    """
    def __init__(self,
                 cmd=os.path.expandvars("$RISCV/bin/riscv64-unknown-elf-gdb")):
        # Note: the default cmd is expanded once, when the class is defined.
        self.child = pexpect.spawn(cmd)
        self.child.logfile = open("gdb.log", "w")
        self.child.logfile.write("+ %s\n" % cmd)
        self.wait()
        self.command("set confirm off")
        self.command("set width 0")
        self.command("set height 0")
        # Force consistency.
        self.command("set print entry-values no")

    def wait(self):
        """Wait for prompt."""
        self.child.expect(r"\(gdb\)")

    def command(self, command, timeout=-1):
        """Send *command*; return gdb's output up to the next prompt, stripped.

        timeout=-1 uses the pexpect child's default timeout.
        """
        self.child.sendline(command)
        # Skip the echoed command line, then collect output up to the prompt.
        self.child.expect("\n", timeout=timeout)
        self.child.expect(r"\(gdb\)", timeout=timeout)
        return self.child.before.strip()

    def c(self, wait=True):
        """Continue the target.

        With wait, block until the next prompt and return gdb's output;
        otherwise return as soon as gdb prints "Continuing".
        """
        if wait:
            output = self.command("c")
            assert "Continuing" in output
            return output
        else:
            self.child.sendline("c")
            self.child.expect("Continuing")

    def interrupt(self):
        """Send Ctrl-C and return the output up to the next prompt (60 s timeout)."""
        self.child.send("\003")
        self.child.expect(r"\(gdb\)", timeout=60)
        return self.child.before.strip()

    def x(self, address, size='w'):
        """Examine one unit of memory (`x/<size>`) and return it as an int."""
        output = self.command("x/%s %s" % (size, address))
        value = int(output.split(':')[1].strip(), 0)
        return value

    @staticmethod
    def _check_access(output):
        """Raise CannotAccess if *output* reports an inaccessible address."""
        m = re.search("Cannot access memory at address (0x[0-9a-f]+)", output)
        if m:
            raise CannotAccess(int(m.group(1), 0))

    @staticmethod
    def _value(output):
        """Return the text after the first '=' of a `print` result, stripped.

        The printed value can itself contain '=' (struct members, string
        contents), so only the first '=' is split on. Output without '=' is
        returned whole.
        """
        return output.split('=', 1)[-1].strip()

    def p_raw(self, obj):
        """Return `p <obj>`'s value as text; raise CannotAccess on bad memory."""
        output = self.command("p %s" % obj)
        self._check_access(output)
        return self._value(output)

    def p(self, obj):
        """Return `p/x <obj>`'s value as an int; raise CannotAccess on bad memory."""
        output = self.command("p/x %s" % obj)
        self._check_access(output)
        return int(self._value(output), 0)

    def p_string(self, obj):
        """Return the string printed for *obj*.

        Assumes the value is printed as an address followed by a quoted
        string (e.g. 0x80001000 "text"); the second shell-style token is
        returned.
        """
        output = self.command("p %s" % obj)
        return shlex.split(self._value(output))[1]

    def stepi(self):
        output = self.command("stepi")
        return output

    def load(self):
        """Run `load` (60 s timeout) and assert that it reported success."""
        output = self.command("load", timeout=60)
        assert "failed" not in output
        assert "Transfer rate" in output

    def b(self, location):
        """Set a breakpoint at *location*; assert gdb accepted it."""
        output = self.command("b %s" % location)
        assert "not defined" not in output
        assert "Breakpoint" in output
        return output

    def hbreak(self, location):
        """Set a hardware breakpoint at *location*; assert gdb accepted it."""
        output = self.command("hbreak %s" % location)
        assert "not defined" not in output
        assert "Hardware assisted breakpoint" in output
        return output
319
def run_all_tests(module, target, tests, fail_fast):
    """Run every test class in *module* against *target*; return exit status.

    A test class is any attribute of *module* whose type is exactly `type`
    and that has a `test` attribute. When *tests* is non-empty, only classes
    whose name contains one of its strings are run. With *fail_fast*, stop
    after the first result other than 'pass' / 'not_applicable'.

    Prints a summary per result (listing the names of tests with bad
    results) and returns 0 if every test passed or was not applicable,
    1 otherwise.
    """
    good_results = set(('pass', 'not_applicable'))

    start = time.time()

    results = {}
    count = 0
    for name in dir(module):
        definition = getattr(module, name)
        if type(definition) == type and hasattr(definition, 'test') and \
                (not tests or any(test in name for test in tests)):
            instance = definition(target)
            result = instance.run()
            results.setdefault(result, []).append(name)
            count += 1
            if result not in good_results and fail_fast:
                break

    header("ran %d tests in %.0fs" % (count, time.time() - start), dash=':')

    result = 0
    # Sort so the summary order does not depend on dict iteration order.
    for key, value in sorted(results.items(), key=lambda item: str(item[0])):
        print("%d tests returned %s" % (len(value), key))
        if key not in good_results:
            result = 1
            for test in value:
                print("  %s" % test)

    return result
349
def add_test_run_options(parser):
    """Register the options shared by test-runner scripts on *parser*.

    Adds a --fail-fast/-f flag and a positional `test` list (zero or more
    name filters).
    """
    option_specs = (
        (("--fail-fast", "-f"),
         dict(action="store_true", help="Exit as soon as any test fails.")),
        (("test",),
         dict(nargs='*', help="Run only tests that are named here.")),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
355
def header(title, dash='-', width=40):
    """Print a one-line banner *width* characters wide, drawn with *dash*.

    With a non-empty *title* the line reads "<dashes>[ title ]<dashes>".
    The dashes are split evenly, with any odd one going after the title.
    A title too long to fit gets no dashes. An empty *title* prints a plain
    rule of *width* dashes.
    """
    if title:
        # "[ " and " ]" around the title take 4 characters.
        dashes = dash * (width - 4 - len(title))
        # Floor division keeps the slice index an int on Python 3 as well.
        half = len(dashes) // 2
        print("%s[ %s ]%s" % (dashes[:half], title, dashes[half:]))
    else:
        print(dash * width)
364
class BaseTest(object):
    """Base class for the tests that run_all_tests() discovers and runs.

    Subclasses define test() and may set compile_args (a hashable sequence
    of arguments for target.compile()), and may override
    early_applicable() and setup().
    """
    # Compile results keyed by compile_args, shared by all tests so each
    # distinct program is compiled only once per run.
    compiled = {}
    # Default list of log files to dump when a test fails.
    logs = []

    def __init__(self, target):
        self.target = target
        self.server = None
        self.target_process = None
        self.binary = None
        self.start = 0
        # Work on a per-instance copy: classSetup() appends to self.logs, and
        # appending to the shared class attribute made each later test dump
        # the server log once for every earlier test.
        self.logs = list(self.logs)

    def early_applicable(self):
        """Return a false value if the test has determined it cannot run
        without ever needing to talk to the target or server."""
        # pylint: disable=no-self-use
        return True

    def setup(self):
        pass

    def compile(self):
        """Build the program for compile_args (if set) and store it in self.binary.

        Results are cached in BaseTest.compiled, keyed by compile_args.
        """
        compile_args = getattr(self, 'compile_args', None)
        if compile_args:
            if compile_args not in BaseTest.compiled:
                # pylint: disable=star-args
                BaseTest.compiled[compile_args] = \
                        self.target.compile(*compile_args)
            self.binary = BaseTest.compiled.get(compile_args)

    def classSetup(self):
        """Compile, start the target and server, and register the server log."""
        self.compile()
        self.target_process = self.target.target()
        self.server = self.target.server()
        self.logs.append(self.server.logname)

    def classTeardown(self):
        """Drop the references to the server and target process.

        Presumably relies on their __del__ (as in Spike/Openocd) to stop the
        underlying processes.
        """
        del self.server
        del self.target_process

    def run(self):
        """
        If compile_args is set, compile a program and set self.binary.

        Call setup().

        Then call test() and return the result, displaying relevant information
        if an exception is raised.
        """

        # No newline: the result is printed on the same line afterwards.
        sys.stdout.write("Running %s ... " % type(self).__name__)
        sys.stdout.flush()

        if not self.early_applicable():
            print("not_applicable")
            return "not_applicable"

        self.start = time.time()

        self.classSetup()

        try:
            self.setup()
            result = self.test() # pylint: disable=no-member
        except TestNotApplicable:
            result = "not_applicable"
        except Exception as e: # pylint: disable=broad-except
            if isinstance(e, TestFailed):
                result = "fail"
            else:
                result = "exception"
            print("%s in %.2fs" % (result, time.time() - self.start))
            print("=" * 40)
            if isinstance(e, TestFailed):
                header("Message")
                print(e.message)
            header("Traceback")
            traceback.print_exc(file=sys.stdout)
            for log in self.logs:
                header(log)
                with open(log, "r") as logfile:
                    print(logfile.read())
            print("/" * 40)
            return result

        finally:
            self.classTeardown()

        if not result:
            result = 'pass'
        print("%s in %.2fs" % (result, time.time() - self.start))
        return result
455
class TestFailed(Exception):
    """Raised by a test (or an assert* helper) to report a failure.

    message: human-readable description of what went wrong.
    """
    def __init__(self, message):
        super(TestFailed, self).__init__()
        self.message = message
460
class TestNotApplicable(Exception):
    """Raised when a test cannot run in the current configuration.

    message: explanation of why the test does not apply.
    """
    def __init__(self, message):
        super(TestNotApplicable, self).__init__()
        self.message = message
465
def assertEqual(a, b):
    """Raise TestFailed when a != b; the message shows both reprs."""
    mismatch = a != b
    if mismatch:
        raise TestFailed("%r != %r" % (a, b))
469
def assertNotEqual(a, b):
    """Raise TestFailed when a == b; the message shows both reprs."""
    same = a == b
    if same:
        raise TestFailed("%r == %r" % (a, b))
473
def assertIn(a, b):
    """Raise TestFailed unless a is contained in b."""
    missing = a not in b
    if missing:
        raise TestFailed("%r not in %r" % (a, b))
477
def assertNotIn(a, b):
    """Raise TestFailed when a is contained in b."""
    present = a in b
    if present:
        raise TestFailed("%r in %r" % (a, b))
481
def assertGreater(a, b):
    """Raise TestFailed unless a > b."""
    is_greater = a > b
    if not is_greater:
        raise TestFailed("%r not greater than %r" % (a, b))
485
def assertLess(a, b):
    """Raise TestFailed unless a < b."""
    is_less = a < b
    if not is_less:
        raise TestFailed("%r not less than %r" % (a, b))
489
def assertTrue(a):
    """Raise TestFailed unless *a* is truthy.

    *a* is wrapped in a 1-tuple for %-formatting. With a bare `% a`, a tuple
    value (e.g. assertTrue(())) would be taken as the argument list, and
    building the message would itself raise TypeError.
    """
    if not a:
        raise TestFailed("%r is not True" % (a,))
493
def assertRegexpMatches(text, regexp):
    """Raise TestFailed unless re.search(regexp, text) finds a match."""
    found = re.search(regexp, text)
    if found is None:
        raise TestFailed("can't find %r in %r" % (regexp, text))