Clean up .pyc files.
[riscv-tests.git] / debug / testlib.py
1 import os.path
2 import re
3 import shlex
4 import subprocess
5 import sys
6 import time
7 import traceback
8
9 import pexpect
10
11 # Note that gdb comes with its own testsuite. I was unable to figure out how to
12 # run that testsuite against the spike simulator.
13
def find_file(path):
    """Locate *path* in the working directory or next to this module.

    The current working directory is searched first, then the directory
    containing this file.  Returns the first joined path that exists, or
    None when neither location has it.
    """
    search_roots = (os.getcwd(), os.path.dirname(__file__))
    candidates = (os.path.join(root, path) for root in search_roots)
    return next((c for c in candidates if os.path.exists(c)), None)
20
def compile(args, xlen=32): # pylint: disable=redefined-builtin
    """Compile a program with the RISC-V cross gcc, adding -g.

    args: gcc arguments.  Any argument that names an existing file (as found
        by find_file) is replaced by that file's full path; every other
        argument is passed through unchanged.
    xlen: selects $RISCV/bin/riscv<xlen>-unknown-elf-gcc.

    On failure, prints the command line and gcc's stdout/stderr between
    headers, then raises Exception("Compile failed!").
    """
    cc = os.path.expandvars("$RISCV/bin/riscv%d-unknown-elf-gcc" % xlen)
    cmd = [cc, "-g"]
    for arg in args:
        cmd.append(find_file(arg) or arg)
    # universal_newlines gives text (str) output on both Python 2 and 3, so
    # it can be written to stdout directly instead of printing bytes reprs.
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE,
                               universal_newlines=True)
    stdout, stderr = process.communicate()
    if process.returncode:
        print("")
        header("Compile failed")
        print("+ " + " ".join(cmd))
        sys.stdout.write(stdout)
        sys.stdout.write(stderr)
        header("")
        raise Exception("Compile failed!")
41
def unused_port():
    """Return a TCP port number that was free at the moment of the call.

    Binds a throwaway IPv4 socket to port 0 so the OS picks a port, reads the
    chosen number, and releases the socket.  Another process may still grab
    the port before the caller uses it; callers must tolerate that race.
    """
    # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python/2838309#2838309
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        # Release the descriptor even when bind() fails.
        s.close()
50
class Spike(object):
    """Run the spike simulator as a child process.

    The command line and all of spike's stdout/stderr go to ``logname``.
    """
    logname = "spike.log"

    def __init__(self, cmd, binary=None, halted=False, with_gdb=True,
                 timeout=None, xlen=64):
        """Launch spike.

        Sets ``self.process`` to the spike subprocess and ``self.port`` to the
        gdb port spike was told to use, or None when with_gdb is false.

        cmd: spike command line as a string; "spike" is used when empty.
        binary: optional program appended after pk.
        halted: pass -H to spike.
        with_gdb: pick a free port with unused_port() and pass --gdb-port.
        timeout: when set, run the command under timeout(1) for this many
            seconds.
        xlen: 32 adds "--isa RV32".
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["spike"]
        if xlen == 32:
            cmd += ["--isa", "RV32"]

        if timeout:
            cmd = ["timeout", str(timeout)] + cmd

        if halted:
            cmd.append('-H')
        # Always define port so callers can test it instead of hitting an
        # AttributeError when gdb support is off.
        self.port = None
        if with_gdb:
            self.port = unused_port()
            cmd += ['--gdb-port', str(self.port)]
        cmd.append("-m32")
        cmd.append('pk')
        if binary:
            cmd.append(binary)
        logfile = open(self.logname, "w")
        try:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)
        finally:
            # The child holds its own copy of the descriptor; the parent's
            # handle is no longer needed.
            logfile.close()

    def __del__(self):
        try:
            self.process.kill()
            self.process.wait()
        except (AttributeError, OSError):
            # AttributeError: Popen never succeeded, so there is no process.
            pass

    def wait(self, *args, **kwargs):
        """Wait for spike to exit and return its exit status."""
        return self.process.wait(*args, **kwargs)
92
class VcsSim(object):
    """Run a VCS simulator binary and learn its JTAG VPI port.

    The command line and all simulator output go to ``logname``.  Once the
    simulator prints "Listening on port N", N is stored in ``self.port`` and
    exported as $JTAG_VPI_PORT.
    """
    logname = "simv.log"

    def __init__(self, simv=None, debug=False):
        """Start the simulator and block until it reports its port.

        simv: simulator command line as a string; "simv" when omitted.
        debug: run "<simv>-debug" and dump waves to output/gdbserver.vpd.

        Raises Exception if the simulator exits without reporting a port.
        """
        if simv:
            cmd = shlex.split(simv)
        else:
            cmd = ["simv"]
        cmd += ["+jtag_vpi_enable"]
        if debug:
            cmd[0] = cmd[0] + "-debug"
            cmd += ["+vcdplusfile=output/gdbserver.vpd"]
        logfile = open(self.logname, "w")
        try:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            # Open the reader and move it past our header *before* the child
            # starts, so none of the child's output can be skipped.
            listenfile = open(self.logname, "r")
            listenfile.seek(0, 2)
            try:
                self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                                stdout=logfile,
                                                stderr=logfile)
            except Exception:
                listenfile.close()
                raise
        finally:
            # The child keeps its own copy of the descriptor.
            logfile.close()
        try:
            self.port = self._wait_for_port(listenfile)
        finally:
            listenfile.close()
        os.environ['JTAG_VPI_PORT'] = str(self.port)

    def _wait_for_port(self, listenfile):
        """Poll listenfile until the simulator announces its port; return it."""
        while True:
            # Sample the exit status before reading: once the child has
            # exited, everything it wrote is already in the file.
            exited = self.process.poll() is not None
            position = listenfile.tell()
            line = listenfile.readline()
            # A line without its newline may still be being written; only
            # trust such a line once the writer has exited.
            if line.endswith("\n") or (line and exited):
                match = re.match(r"^Listening on port (\d+)$", line)
                if match:
                    return int(match.group(1))
                continue
            # Rewind so a partial line is re-read whole next time.
            listenfile.seek(position)
            if exited:
                raise Exception(
                    "simv exited before reporting its JTAG VPI port")
            time.sleep(0.1)

    def __del__(self):
        try:
            self.process.kill()
            self.process.wait()
        except (AttributeError, OSError):
            # AttributeError: Popen never succeeded, so there is no process.
            pass
127
class Openocd(object):
    """Run OpenOCD as a child process and discover its gdb server port.

    The command line and all of OpenOCD's output go to ``logname``.
    """
    logname = "openocd.log"

    def __init__(self, cmd=None, config=None, debug=False):
        """Start OpenOCD and wait until it has examined the RISC-V core.

        cmd: OpenOCD command line as a string; "openocd" when omitted.
        config: config script passed with -f.  A copy found by find_file() is
            preferred; otherwise the name is passed through unchanged so
            OpenOCD's own script search path can resolve it.
        debug: pass -d to OpenOCD.

        Sets ``self.process`` and ``self.port`` (the gdb server port).
        Raises Exception if OpenOCD exits before examining the core, or if
        its gdb port cannot be determined.
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["openocd"]
        if config:
            # Falling back to the raw name avoids putting None into the
            # argument list, which made Popen fail with a TypeError.
            cmd += ["-f", find_file(config) or config]
        if debug:
            cmd.append("-d")

        # This command needs to come before any config scripts on the command
        # line, since they are executed in order.
        cmd[1:1] = [
            # Tell OpenOCD to bind gdb to an unused, ephemeral port.
            "--command",
            "gdb_port 0",
            # Disable tcl and telnet servers, since they are unused and because
            # the port numbers will conflict if multiple OpenOCD processes are
            # running on the same server.
            "--command",
            "tcl_port disabled",
            "--command",
            "telnet_port disabled",
        ]

        logfile = open(Openocd.logname, "w")
        try:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)
        finally:
            # The child keeps its own copy of the descriptor.
            logfile.close()

        self._wait_for_examine()
        self.port = self._get_gdb_server_port()

    def _wait_for_examine(self):
        """Block until the log shows OpenOCD made it through riscv_examine().

        When using OpenOCD to communicate with a simulator this may take a
        long time, and gdb will time out when trying to connect if we attempt
        too early.
        """
        start = time.time()
        messaged = False
        while True:
            with open(Openocd.logname) as log:
                if "Examined RISCV core" in log.read():
                    return
            if self.process.poll() is not None:
                raise Exception(
                    "OpenOCD exited before completing riscv_examine()")
            if not messaged and time.time() - start > 1:
                messaged = True
                print("Waiting for OpenOCD to examine RISCV core...")
            # Poll instead of spinning flat out re-reading the log.
            time.sleep(0.1)

    def _get_gdb_server_port(self):
        """Get port that OpenOCD's gdb server is listening on.

        Asks lsof for TCP sockets owned by the OpenOCD process, up to
        MAX_ATTEMPTS times 0.1s apart, ignoring ports 6666 and 4444
        (presumably OpenOCD's default tcl/telnet ports).  Raises Exception if
        more than one candidate port is found or none appears in time.
        """
        MAX_ATTEMPTS = 50
        PORT_REGEX = re.compile(r'(?P<port>\d+) \(LISTEN\)')
        for _ in range(MAX_ATTEMPTS):
            with open(os.devnull, 'w') as devnull:
                try:
                    # universal_newlines: text output so PORT_REGEX (a str
                    # pattern) also works on Python 3.
                    output = subprocess.check_output([
                        'lsof',
                        '-a',  # Take the AND of the following selectors
                        '-p{}'.format(self.process.pid),  # Filter on PID
                        '-iTCP',  # Filter only TCP sockets
                    ], stderr=devnull, universal_newlines=True)
                except subprocess.CalledProcessError:
                    output = ""
            matches = [m for m in PORT_REGEX.finditer(output)
                       if m.group('port') not in ('6666', '4444')]
            if len(matches) > 1:
                print(output)
                raise Exception(
                    "OpenOCD listening on multiple ports. Cannot uniquely "
                    "identify gdb server port.")
            elif matches:
                [match] = matches
                return int(match.group('port'))
            time.sleep(0.1)
        raise Exception("Timed out waiting for gdb server to obtain port.")

    def __del__(self):
        try:
            self.process.kill()
            self.process.wait()
        except (AttributeError, OSError):
            # AttributeError: Popen never succeeded, so there is no process.
            pass
215
class OpenocdCli(object):
    """Drive OpenOCD's telnet command interface through pexpect.

    The session is tee'd to openocd-cli.log.
    """
    def __init__(self, port=4444):
        """Connect to OpenOCD's telnet server on localhost:port."""
        self.child = pexpect.spawn(
            "sh -c 'telnet localhost %d | tee openocd-cli.log'" % port)
        self.child.expect("> ")

    def command(self, cmd):
        """Run one OpenOCD command and return its output, stripped.

        The echoed command is matched literally (expect_exact), so commands
        containing regex metacharacters such as '+', '(' or '[' work.
        """
        self.child.sendline(cmd)
        self.child.expect_exact(cmd)
        self.child.expect("\n")
        self.child.expect("> ")
        return self.child.before.strip("\t\r\n \0")

    def reg(self, reg=''):
        """Read registers with OpenOCD's `reg` command.

        Returns the integer value of *reg* when given, otherwise a dict of
        every register name to its value.  Raises KeyError if *reg* is not
        in the output.
        """
        output = self.command("reg %s" % reg)
        # Accept both hex digit cases; matching only A-F would silently cut
        # a lowercase value such as 0x8000abcd short at the first letter.
        matches = re.findall(r"(\w+) \(/\d+\): (0x[0-9A-Fa-f]+)", output)
        values = {r: int(v, 0) for r, v in matches}
        if reg:
            return values[reg]
        return values

    def load_image(self, image):
        """Load *image* through OpenOCD.

        Raises TestNotApplicable when OpenOCD rejects a non-32-bit ELF.
        """
        output = self.command("load_image %s" % image)
        if 'invalid ELF file, only 32bits files are supported' in output:
            raise TestNotApplicable(output)
241
class CannotAccess(Exception):
    """Raised when gdb reports that it cannot access memory at an address.

    The address is kept in ``self.address``; str() of the exception also
    names it, so tracebacks are informative.
    """
    def __init__(self, address):
        try:
            message = "Cannot access memory at address 0x%x" % address
        except TypeError:
            # Not an integer; fall back to its repr.
            message = "Cannot access memory at address %r" % (address,)
        Exception.__init__(self, message)
        self.address = address
246
class Gdb(object):
    """Interactive gdb session driven through pexpect.

    The full transcript is written to gdb.log.
    """
    def __init__(self, cmd=None):
        """Start gdb and configure it for scripted use.

        cmd: command used to launch gdb.  When omitted, defaults to
            $RISCV/bin/riscv64-unknown-elf-gdb, expanded when the Gdb object
            is created rather than when this module is imported.
        """
        if cmd is None:
            cmd = os.path.expandvars("$RISCV/bin/riscv64-unknown-elf-gdb")
        self.child = pexpect.spawn(cmd)
        self.child.logfile = open("gdb.log", "w")
        self.child.logfile.write("+ %s\n" % cmd)
        self.wait()
        self.command("set confirm off")
        self.command("set width 0")
        self.command("set height 0")
        # Force consistency.
        self.command("set print entry-values no")

    def wait(self):
        """Wait for prompt."""
        self.child.expect(r"\(gdb\)")

    def command(self, command, timeout=-1):
        """Send *command* and return gdb's output up to the next prompt.

        timeout: passed to pexpect; -1 means the spawn's default timeout.
        """
        self.child.sendline(command)
        self.child.expect("\n", timeout=timeout)
        self.child.expect(r"\(gdb\)", timeout=timeout)
        return self.child.before.strip()

    def c(self, wait=True):
        """Continue execution.

        With wait, block until gdb prompts again and return its output;
        otherwise return as soon as gdb says it is continuing.
        """
        if wait:
            output = self.command("c")
            assert "Continuing" in output
            return output
        else:
            self.child.sendline("c")
            self.child.expect("Continuing")

    def interrupt(self):
        """Send Ctrl-C and return the output up to the next prompt."""
        self.child.send("\003")
        self.child.expect(r"\(gdb\)", timeout=60)
        return self.child.before.strip()

    def x(self, address, size='w'):
        """Examine one unit of memory and return it as an integer."""
        output = self.command("x/%s %s" % (size, address))
        value = int(output.split(':')[1].strip(), 0)
        return value

    @staticmethod
    def _raise_if_inaccessible(output):
        """Raise CannotAccess if gdb reported an unreadable address."""
        m = re.search("Cannot access memory at address (0x[0-9a-f]+)", output)
        if m:
            raise CannotAccess(int(m.group(1), 0))

    def p_raw(self, obj):
        """Print *obj* and return the text after the last '='."""
        output = self.command("p %s" % obj)
        self._raise_if_inaccessible(output)
        return output.split('=')[-1].strip()

    def p(self, obj):
        """Print *obj* in hex and return it as an integer."""
        output = self.command("p/x %s" % obj)
        self._raise_if_inaccessible(output)
        value = int(output.split('=')[-1].strip(), 0)
        return value

    def p_string(self, obj):
        """Print a string expression and return the string's contents.

        Assumes gdb prints `<address> "<text>"` after the '=' - TODO confirm
        for every caller.
        """
        output = self.command("p %s" % obj)
        value = shlex.split(output.split('=')[-1].strip())[1]
        return value

    def stepi(self):
        """Execute one instruction and return gdb's output."""
        output = self.command("stepi")
        return output

    def load(self):
        """Load the program into the target; assert the transfer finished."""
        output = self.command("load", timeout=60)
        assert "failed" not in output
        assert "Transfer rate" in output

    def b(self, location):
        """Set a breakpoint at *location* and return gdb's output."""
        output = self.command("b %s" % location)
        assert "not defined" not in output
        assert "Breakpoint" in output
        return output

    def hbreak(self, location):
        """Set a hardware breakpoint at *location* and return the output."""
        output = self.command("hbreak %s" % location)
        assert "not defined" not in output
        assert "Hardware assisted breakpoint" in output
        return output
329
def run_all_tests(module, target, tests, fail_fast):
    """Run every test class found in *module* against *target*.

    A test class is any class attribute of *module* that has a ``test``
    attribute.  Each is instantiated with *target*, and its run() result is
    recorded.

    tests: when non-empty, only classes whose name contains one of these
        substrings are run.
    fail_fast: stop after the first result other than 'pass' or
        'not_applicable'.

    Prints a summary grouped by result and returns 0 if every test passed
    or was not applicable, else 1.
    """
    good_results = set(('pass', 'not_applicable'))

    start = time.time()

    results = {}
    count = 0
    for name in dir(module):
        definition = getattr(module, name)
        # isinstance (rather than type(...) == type) also accepts classes
        # built with a custom metaclass.
        if not (isinstance(definition, type) and
                hasattr(definition, 'test')):
            continue
        if tests and not any(test in name for test in tests):
            continue
        instance = definition(target)
        result = instance.run()
        results.setdefault(result, []).append(name)
        count += 1
        if result not in good_results and fail_fast:
            break

    header("ran %d tests in %.0fs" % (count, time.time() - start), dash=':')

    result = 0
    # Sorted so the summary order is stable; key=str copes with non-string
    # results returned by test().
    for key in sorted(results, key=str):
        value = results[key]
        print("%d tests returned %s" % (len(value), key))
        if key not in good_results:
            result = 1
            for test in value:
                print("  %s" % test)

    return result
359
def add_test_run_options(parser):
    """Register the options every test runner shares on an argparse parser.

    Adds --fail-fast/-f (a boolean flag) and a positional ``test`` list that
    may be empty.
    """
    option_specs = (
        (("--fail-fast", "-f"),
         {"action": "store_true",
          "help": "Exit as soon as any test fails."}),
        (("test",),
         {"nargs": '*',
          "help": "Run only tests that are named here."}),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)
365
def header(title, dash='-', width=40):
    """Print a separator line, optionally with *title* centered in it.

    With a title, prints "<dashes>[ title ]<dashes>", padding to *width*
    characters (when the title fits); any odd dash goes on the right.
    Without one, prints *dash* repeated *width* times.
    """
    if title:
        # 4 columns are taken by the "[ " and " ]" around the title.
        dashes = dash * (width - 4 - len(title))
        # Floor division: plain "/" yields a float index on Python 3.
        before = dashes[:len(dashes) // 2]
        after = dashes[len(dashes) // 2:]
        print("%s[ %s ]%s" % (before, title, after))
    else:
        print(dash * width)
374
class BaseTest(object):
    """Base class for the debug tests discovered by run_all_tests().

    Subclasses provide test() and may override setup() and
    early_applicable().  A subclass may also set a hashable
    ``compile_args`` that is passed to target.compile() to build
    self.binary.
    """
    # compile_args -> value returned by target.compile().  Shared by every
    # test so each distinct program is built only once.
    compiled = {}
    # Log files printed when a test fails.  Lives on the class, so it is
    # shared by all tests.
    logs = []

    def __init__(self, target):
        self.target = target
        self.server = None
        self.target_process = None
        self.binary = None
        self.start = 0

    def early_applicable(self):
        """Return a false value if the test has determined it cannot run
        without ever needing to talk to the target or server."""
        # pylint: disable=no-self-use
        return True

    def setup(self):
        """Per-test preparation hook; does nothing by default."""
        pass

    def compile(self):
        """Build (or reuse) the binary for compile_args, if any, into
        self.binary."""
        compile_args = getattr(self, 'compile_args', None)
        if compile_args:
            if compile_args not in BaseTest.compiled:
                # pylint: disable=star-args
                BaseTest.compiled[compile_args] = \
                        self.target.compile(*compile_args)
        self.binary = BaseTest.compiled.get(compile_args)

    def classSetup(self):
        """Compile, then start the target and the debug server."""
        self.compile()
        self.target_process = self.target.target()
        self.server = self.target.server()
        # classSetup() runs for every test but self.logs is the shared class
        # list; record each log once so a failure doesn't print the same
        # file once per test run so far.
        if self.server.logname not in self.logs:
            self.logs.append(self.server.logname)

    def classTeardown(self):
        """Drop this test's references to the server and target process."""
        del self.server
        del self.target_process

    def run(self):
        """
        If compile_args is set, compile a program and set self.binary.

        Call setup().

        Then call test() and return the result, displaying relevant information
        if an exception is raised.
        """

        # Trailing space: the result is printed on the same line later.
        sys.stdout.write("Running %s ... " % type(self).__name__)
        sys.stdout.flush()

        if not self.early_applicable():
            print("not_applicable")
            return "not_applicable"

        self.start = time.time()

        self.classSetup()

        try:
            self.setup()
            result = self.test()    # pylint: disable=no-member
        except TestNotApplicable:
            result = "not_applicable"
        except Exception as e:  # pylint: disable=broad-except
            if isinstance(e, TestFailed):
                result = "fail"
            else:
                result = "exception"
            self._report_failure(result, e)
            return result
        finally:
            self.classTeardown()

        if not result:
            result = 'pass'
        print("%s in %.2fs" % (result, time.time() - self.start))
        return result

    def _report_failure(self, result, error):
        """Print the result, failure message, traceback and known logs.

        Must be called from inside an except block so that
        traceback.print_exc() can still see the exception being handled.
        """
        print("%s in %.2fs" % (result, time.time() - self.start))
        print("=" * 40)
        if isinstance(error, TestFailed):
            header("Message")
            print(error.message)
        header("Traceback")
        traceback.print_exc(file=sys.stdout)
        for log in self.logs:
            header(log)
            try:
                with open(log, "r") as logfile:
                    print(logfile.read())
            except IOError as err:
                # A missing log must not mask the test's own failure.
                print("(could not read %s: %s)" % (log, err))
        print("/" * 40)
465
class TestFailed(Exception):
    """Raised to report that a test's check did not hold.

    ``self.message`` holds the explanation.  It is also passed to Exception,
    so str() and tracebacks show it.
    """
    def __init__(self, message):
        Exception.__init__(self, message)
        self.message = message
470
class TestNotApplicable(Exception):
    """Raised when a test does not apply to the current setup.

    ``self.message`` holds the reason.  It is also passed to Exception, so
    str() shows it.
    """
    def __init__(self, message):
        Exception.__init__(self, message)
        self.message = message
475
def assertEqual(a, b):
    """Fail the current test if a and b compare unequal (using !=)."""
    differs = a != b
    if differs:
        raise TestFailed("%r != %r" % (a, b))
479
def assertNotEqual(a, b):
    """Fail the current test if a and b compare equal (using ==)."""
    same = a == b
    if same:
        raise TestFailed("%r == %r" % (a, b))
483
def assertIn(a, b):
    """Fail the current test unless a is contained in b."""
    present = a in b
    if not present:
        raise TestFailed("%r not in %r" % (a, b))
487
def assertNotIn(a, b):
    """Fail the current test if a is contained in b."""
    present = a in b
    if present:
        raise TestFailed("%r in %r" % (a, b))
491
def assertGreater(a, b):
    """Fail the current test unless a > b holds."""
    holds = a > b
    if not holds:
        raise TestFailed("%r not greater than %r" % (a, b))
495
def assertLess(a, b):
    """Fail the current test unless a < b holds."""
    holds = a < b
    if not holds:
        raise TestFailed("%r not less than %r" % (a, b))
499
def assertTrue(a):
    """Fail the current test unless a is truthy."""
    if not a:
        # Wrap in a 1-tuple: a bare "% a" would treat a tuple argument (for
        # example an empty tuple) as the format-argument list and raise
        # TypeError instead of TestFailed.
        raise TestFailed("%r is not True" % (a,))
503
def assertRegexpMatches(text, regexp):
    """Fail the current test unless regexp matches somewhere in text.

    Uses re.search, so the pattern may match anywhere, not only at the start.
    """
    found = re.search(regexp, text)
    if found is None:
        raise TestFailed("can't find %r in %r" % (regexp, text))