6b01d8df68dad2256e483299974408cd255a2cd0
[riscv-tests.git] / debug / testlib.py
1 import os.path
2 import re
3 import shlex
4 import subprocess
5 import sys
6 import time
7 import traceback
8
9 import pexpect
10
11 # Note that gdb comes with its own testsuite. I was unable to figure out how to
12 # run that testsuite against the spike simulator.
13
def find_file(path):
    """Resolve *path* against the working directory, then this module's dir.

    Returns the first joined path that exists, or None if neither does. An
    absolute *path* is returned unchanged when it exists, because
    os.path.join discards the directory in that case.
    """
    search_dirs = (os.getcwd(), os.path.dirname(__file__))
    candidates = (os.path.join(base, path) for base in search_dirs)
    return next((c for c in candidates if os.path.exists(c)), None)
20
def compile(args, xlen=32):  # pylint: disable=redefined-builtin
    """Compile *args* with the RISC-V gcc for *xlen*, adding debug info (-g).

    Each argument that names an existing file (as resolved by find_file())
    is replaced by its resolved path; other arguments pass through as-is.
    The compiler is executed directly rather than through a shell, so
    resolved paths containing spaces or shell metacharacters stay intact.

    Raises AssertionError if the compiler cannot be run or exits non-zero.
    """
    cc = os.path.expandvars("$RISCV/bin/riscv%d-unknown-elf-gcc" % xlen)
    cmd = [cc, "-g"]
    for arg in args:
        found = find_file(arg)
        cmd.append(found if found else arg)
    try:
        result = subprocess.call(cmd)
    except OSError:
        # e.g. the compiler executable does not exist.
        result = -1
    if result != 0:
        raise AssertionError("%r failed" % " ".join(cmd))
33
def unused_port():
    """Return a TCP port number that the OS reported as free.

    Binds a socket to port 0 so the OS picks a free port, reads the port
    back, and closes the socket. Another process may take the port before
    the caller binds it.
    """
    # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python/2838309#2838309
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        # Close even if bind()/getsockname() raises.
        s.close()
42
class Spike(object):
    """Run the spike simulator (with pk) as a child process.

    The command line and all output are written to spike.log.
    """
    logname = "spike.log"

    def __init__(self, cmd, binary=None, halted=False, with_gdb=True,
                 timeout=None, xlen=64):
        """Launch spike.

        Sets self.process to the Popen object and self.port to the port given
        to --gdb-port (None when with_gdb is False).

        cmd: spike command line as a shell-style string; "spike" if empty.
        binary: program path appended after "pk".
        halted: pass -H to spike.
        timeout: if set, run spike under `timeout <seconds>`.
        xlen: 32 adds "--isa RV32".
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["spike"]
        if xlen == 32:
            cmd += ["--isa", "RV32"]

        if timeout:
            cmd = ["timeout", str(timeout)] + cmd

        if halted:
            cmd.append('-H')
        self.port = None
        if with_gdb:
            self.port = unused_port()
            cmd += ['--gdb-port', str(self.port)]
        # NOTE(review): presumably spike's -m<MiB> memory-size option, not an
        # xlen flag -- confirm against spike's usage text.
        cmd.append("-m32")
        cmd.append('pk')
        if binary:
            cmd.append(binary)
        logfile = open(self.logname, "w")
        try:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)
        finally:
            # The child has its own copy of the descriptor.
            logfile.close()

    def __del__(self):
        # __init__ may have failed before the process was started.
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass

    def wait(self, *args, **kwargs):
        """Wait for the spike process; arguments go to Popen.wait()."""
        return self.process.wait(*args, **kwargs)
84
class VcsSim(object):
    """Run a VCS simulator binary (simv) and discover its JTAG VPI port.

    simv's stdout/stderr go to simv.log. The constructor waits until simv
    prints "Listening on port N", stores N in self.port and exports it as
    $JTAG_VPI_PORT. It raises Exception if simv exits without printing it.
    """

    def __init__(self, simv=None, debug=False):
        if simv:
            cmd = shlex.split(simv)
        else:
            cmd = ["simv"]
        cmd += ["+jtag_vpi_enable"]
        if debug:
            cmd[0] = cmd[0] + "-debug"
            cmd += ["+vcdplusfile=output/gdbserver.vpd"]
        logfile = open("simv.log", "w")
        logfile.write("+ %s\n" % " ".join(cmd))
        logfile.flush()
        listenfile = open("simv.log", "r")
        try:
            # Only scan what simv writes, not the command line logged above.
            listenfile.seek(0, 2)
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)
            self.port = self._wait_for_port(listenfile)
        finally:
            # simv holds its own copy of the log descriptor.
            logfile.close()
            listenfile.close()
        os.environ['JTAG_VPI_PORT'] = str(self.port)

    def _wait_for_port(self, listenfile):
        """Return the port simv announces in *listenfile*; raise if it exits."""
        pending = ""
        while True:
            # Sample exit status before reading, so output written just
            # before simv exits is still seen by the read below.
            exited = self.process.poll() is not None
            chunk = listenfile.readline()
            if chunk:
                pending += chunk
                if not pending.endswith("\n"):
                    continue  # partial line; keep reading the rest of it
            elif not pending:
                if exited:
                    raise Exception(
                        "simv exited before reporting its JTAG VPI port")
                time.sleep(0.1)
                continue
            elif not exited:
                time.sleep(0.1)
                continue
            # pending holds a complete line, or simv's unterminated last line.
            match = re.match(r"^Listening on port (\d+)$", pending)
            pending = ""
            if match:
                return int(match.group(1))

    def __del__(self):
        # __init__ may have failed before the process was started.
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
119
class Openocd(object):
    """Run OpenOCD with its gdb server bound to an OS-chosen port.

    The constructor waits until openocd.log contains "Examined RISCV core",
    then finds the gdb server port with lsof and stores it in self.port.
    """
    logname = "openocd.log"

    def __init__(self, cmd=None, config=None, debug=False):
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["openocd"]
        if config:
            config_path = find_file(config)
            if config_path is None:
                raise Exception("Can't find OpenOCD config file %r" % config)
            cmd += ["-f", config_path]
        if debug:
            cmd.append("-d")

        # This command needs to come before any config scripts on the command
        # line, since they are executed in order.
        # Tell OpenOCD to bind to an unused port.
        cmd[1:1] = ["--command", "gdb_port %d" % 0]

        logfile = open(Openocd.logname, "w")
        try:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)
        finally:
            # The child has its own copy of the descriptor.
            logfile.close()

        self._wait_for_examine()
        self.port = self._get_gdb_server_port()

    def _wait_for_examine(self):
        """Block until the log shows the RISC-V core was examined.

        Raises Exception if OpenOCD exits first.
        """
        # Wait for OpenOCD to have made it through riscv_examine(). When using
        # OpenOCD to communicate with a simulator this may take a long time,
        # and gdb will time out when trying to connect if we attempt too early.
        start = time.time()
        messaged = False
        while True:
            # Sample exit status before reading the log so output written just
            # before OpenOCD exits is not missed.
            exited = self.process.poll() is not None
            with open(Openocd.logname) as log:
                if "Examined RISCV core" in log.read():
                    return
            if exited:
                raise Exception(
                    "OpenOCD exited before completing riscv_examine()")
            if not messaged and time.time() - start > 1:
                messaged = True
                print("Waiting for OpenOCD to examine RISCV core...")
            # Poll instead of spinning on the CPU.
            time.sleep(0.1)

    def _get_gdb_server_port(self):
        """Get port that OpenOCD's gdb server is listening on."""
        MAX_ATTEMPTS = 50
        PORT_REGEX = re.compile(r'(?P<port>\d+) \(LISTEN\)')
        for _ in range(MAX_ATTEMPTS):
            with open(os.devnull, 'w') as devnull:
                try:
                    output = subprocess.check_output([
                        'lsof',
                        '-a',  # Take the AND of the following selectors
                        '-p{}'.format(self.process.pid),  # Filter on PID
                        '-iTCP',  # Filter only TCP sockets
                    ], stderr=devnull, universal_newlines=True)
                except subprocess.CalledProcessError:
                    output = ""
            matches = list(PORT_REGEX.finditer(output))
            # Ignore OpenOCD's default Tcl (6666) and telnet (4444) ports.
            matches = [m for m in matches
                       if m.group('port') not in ('6666', '4444')]
            if len(matches) > 1:
                print(output)
                raise Exception(
                    "OpenOCD listening on multiple ports. Cannot uniquely "
                    "identify gdb server port.")
            elif matches:
                [match] = matches
                return int(match.group('port'))
            time.sleep(0.1)
        raise Exception("Timed out waiting for gdb server to obtain port.")

    def __del__(self):
        # __init__ may have failed before the process was started.
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
196
class OpenocdCli(object):
    """Drive OpenOCD's command interface over telnet through pexpect.

    The session is also copied to openocd-cli.log via tee.
    """
    def __init__(self, port=4444):
        self.child = pexpect.spawn(
            "sh -c 'telnet localhost %d | tee openocd-cli.log'" % port)
        self.child.expect("> ")

    def command(self, cmd):
        """Send *cmd*, wait for the next "> " prompt, and return the output."""
        self.child.sendline(cmd)
        # Match the echoed command literally; expect() takes a regex, so
        # characters such as "[", "(", "+" or "." in cmd would otherwise be
        # interpreted as pattern syntax and could mis-sync the session.
        self.child.expect(re.escape(cmd))
        self.child.expect("\n")
        self.child.expect("> ")
        return self.child.before.strip("\t\r\n \0")

    def reg(self, reg=''):
        """Parse `reg` output into register values.

        With *reg* empty, return a dict mapping every listed register name to
        its integer value; otherwise return that register's value (KeyError
        if it was not listed).
        """
        output = self.command("reg %s" % reg)
        # Accept hex digits in either case.
        matches = re.findall(r"(\w+) \(/\d+\): (0x[0-9A-Fa-f]+)", output)
        values = {r: int(v, 0) for r, v in matches}
        if reg:
            return values[reg]
        return values

    def load_image(self, image):
        """Run load_image; raise TestNotApplicable if OpenOCD rejects 64-bit ELF."""
        output = self.command("load_image %s" % image)
        if 'invalid ELF file, only 32bits files are supported' in output:
            raise TestNotApplicable(output)
222
class CannotAccess(Exception):
    """Raised when gdb reports that it cannot access memory at an address.

    The address is available as self.address.
    """
    def __init__(self, address):
        # Give the base class a message so str(e) and tracebacks name the
        # address instead of being empty.
        try:
            where = "0x%x" % (address,)
        except TypeError:
            where = repr(address)
        Exception.__init__(self, "Cannot access memory at address %s" % where)
        self.address = address
227
class Gdb(object):
    """Drive an interactive gdb session through pexpect.

    The session transcript is written to gdb.log.
    """

    # gdb's message for an unreadable address; group 1 is the address.
    _CANNOT_ACCESS_RE = re.compile(
        r"Cannot access memory at address (0x[0-9a-f]+)")
    # A value-history result such as "$3 = 0x10"; group 1 is the value text.
    _VALUE_RE = re.compile(r"\$\d+ = (.*)")

    def __init__(self,
                 cmd=os.path.expandvars("$RISCV/bin/riscv64-unknown-elf-gdb")):
        self.child = pexpect.spawn(cmd)
        self.child.logfile = open("gdb.log", "w")
        self.child.logfile.write("+ %s\n" % cmd)
        self.wait()
        self.command("set confirm off")
        self.command("set width 0")
        self.command("set height 0")
        # Force consistency.
        self.command("set print entry-values no")

    def wait(self):
        """Wait for prompt."""
        self.child.expect(r"\(gdb\)")

    def command(self, command, timeout=-1):
        """Send *command*; return the text printed before the next prompt.

        The first expect skips past the first newline (presumably the echo of
        the command itself). *timeout* is passed to pexpect; -1 selects the
        spawn's default timeout.
        """
        self.child.sendline(command)
        self.child.expect("\n", timeout=timeout)
        self.child.expect(r"\(gdb\)", timeout=timeout)
        return self.child.before.strip()

    def c(self, wait=True):
        """Continue; with *wait*, block until gdb's prompt returns."""
        if wait:
            output = self.command("c")
            assert "Continuing" in output
            return output
        else:
            self.child.sendline("c")
            self.child.expect("Continuing")

    def interrupt(self):
        """Send Ctrl-C and return the output up to the next prompt."""
        self.child.send("\003")
        self.child.expect(r"\(gdb\)", timeout=60)
        return self.child.before.strip()

    def x(self, address, size='w'):
        """Examine one unit at *address* and return it as an int."""
        output = self.command("x/%s %s" % (size, address))
        value = int(output.split(':')[1].strip(), 0)
        return value

    def _raise_if_cannot_access(self, output):
        """Raise CannotAccess if *output* reports an unreadable address."""
        m = self._CANNOT_ACCESS_RE.search(output)
        if m:
            raise CannotAccess(int(m.group(1), 0))

    def _value(self, output):
        """Return the value text of a "$N = value" result in *output*."""
        m = self._VALUE_RE.search(output)
        if m:
            # Everything after "$N = ", so values that themselves contain
            # '=' (structs, strings) are kept whole.
            return m.group(1).strip()
        # No value-history line; fall back to the text after the last '='.
        return output.split('=')[-1].strip()

    def p_raw(self, obj):
        """Print *obj* and return gdb's value text unparsed."""
        output = self.command("p %s" % obj)
        self._raise_if_cannot_access(output)
        return self._value(output)

    def p(self, obj):
        """Print *obj* in hex and return it as an int."""
        output = self.command("p/x %s" % obj)
        self._raise_if_cannot_access(output)
        value = int(self._value(output), 0)
        return value

    def p_string(self, obj):
        """Print *obj* and return the quoted string that follows its address."""
        output = self.command("p %s" % obj)
        value = shlex.split(self._value(output))[1]
        return value

    def stepi(self):
        """Step one instruction and return gdb's output."""
        output = self.command("stepi")
        return output

    def load(self):
        """Load the program into the target and check that it succeeded."""
        output = self.command("load", timeout=60)
        assert "failed" not in output
        assert "Transfer rate" in output

    def b(self, location):
        """Set a breakpoint at *location* and return gdb's output."""
        output = self.command("b %s" % location)
        assert "not defined" not in output
        assert "Breakpoint" in output
        return output

    def hbreak(self, location):
        """Set a hardware breakpoint at *location* and return gdb's output."""
        output = self.command("hbreak %s" % location)
        assert "not defined" not in output
        assert "Hardware assisted breakpoint" in output
        return output
310
def run_all_tests(module, target, tests, fail_fast):
    """Run the test classes found in *module* against *target*.

    A test class is any attribute of *module* whose type is exactly `type`
    and that has a `test` attribute. When *tests* is non-empty, only classes
    whose name contains one of its entries are run. With *fail_fast*, stop
    after the first result that is not pass/not_applicable.

    Prints a summary and returns 1 if any result was bad, else 0.
    """
    good_results = set(('pass', 'not_applicable'))

    start = time.time()

    results = {}
    count = 0
    for name in dir(module):
        definition = getattr(module, name)
        if type(definition) == type and hasattr(definition, 'test') and \
                (not tests or any(test in name for test in tests)):
            instance = definition(target)
            result = instance.run()
            results.setdefault(result, []).append(name)
            count += 1
            if result not in good_results and fail_fast:
                break

    header("ran %d tests in %.0fs" % (count, time.time() - start), dash=':')

    result = 0
    # Sorted (by string form, since test() may return non-string results) so
    # the summary order is the same from run to run.
    for key in sorted(results, key=str):
        value = results[key]
        print("%d tests returned %s" % (len(value), key))
        if key not in good_results:
            result = 1
            for test in value:
                print("  %s" % test)

    return result
340
def add_test_run_options(parser):
    """Register the test-selection options shared by the test runners.

    Adds --fail-fast/-f (a store_true flag) and a positional `test` list of
    zero or more test names to the argparse-style *parser*.
    """
    parser.add_argument(
        "--fail-fast", "-f",
        action="store_true",
        help="Exit as soon as any test fails.",
    )
    parser.add_argument(
        "test",
        nargs='*',
        help="Run only tests that are named here.",
    )
346
def header(title, dash='-'):
    """Print *title* in brackets, padded on both sides with *dash*.

    For a one-character *dash* and a title of at most 36 characters, the line
    is 40 characters wide; any odd padding character goes after the title.
    """
    dashes = dash * (36 - len(title))
    # Floor division keeps the slice index an int on Python 3 as well.
    half = len(dashes) // 2
    before = dashes[:half]
    after = dashes[half:]
    print("%s[ %s ]%s" % (before, title, after))
352
class BaseTest(object):
    """Base class for one test case run against a target.

    Subclasses provide test() and may set compile_args (a tuple passed to
    target.compile()) and override setup() and early_applicable().
    """
    # compile_args -> compiled binary, shared by all tests so each program is
    # only built once per run.
    compiled = {}
    # Default log list; each instance works on its own copy (see __init__).
    logs = []

    def __init__(self, target):
        self.target = target
        self.server = None
        self.target_process = None
        self.binary = None
        self.start = 0
        # Copy per instance: appending to the shared class-level list made
        # every test add another copy of the server log name, so failure
        # reports printed the same log repeatedly.
        self.logs = list(self.logs)

    def early_applicable(self):
        """Return a false value if the test has determined it cannot run
        without ever needing to talk to the target or server."""
        # pylint: disable=no-self-use
        return True

    def setup(self):
        """Per-test setup hook, run before test(); does nothing by default."""
        pass

    def compile(self):
        """Build this test's program when compile_args is set.

        Results are cached in BaseTest.compiled keyed by compile_args. Sets
        self.binary (None when compile_args is unset). If target.compile()
        raises, prints the traceback and returns "exception".
        """
        compile_args = getattr(self, 'compile_args', None)
        if compile_args:
            if compile_args not in BaseTest.compiled:
                try:
                    # pylint: disable=star-args
                    BaseTest.compiled[compile_args] = \
                        self.target.compile(*compile_args)
                except Exception:  # pylint: disable=broad-except
                    print("exception while compiling in %.2fs" % (
                        time.time() - self.start))
                    print("=" * 40)
                    header("Traceback")
                    traceback.print_exc(file=sys.stdout)
                    print("/" * 40)
                    return "exception"
        self.binary = BaseTest.compiled.get(compile_args)

    def classSetup(self):
        """Compile, start the target and server, and record the server log."""
        self.compile()
        self.target_process = self.target.target()
        self.server = self.target.server()
        self.logs.append(self.server.logname)

    def classTeardown(self):
        """Drop the references to the server and target process."""
        del self.server
        del self.target_process

    @staticmethod
    def _read_log(path):
        """Return the contents of log *path*, or a note if it can't be read."""
        try:
            with open(path, "r") as log_file:
                return log_file.read()
        except (IOError, OSError) as e:
            return "(could not read %s: %s)" % (path, e)

    def run(self):
        """
        If compile_args is set, compile a program and set self.binary.

        Call setup().

        Then call test() and return the result, displaying relevant information
        if an exception is raised.
        """

        # No newline: the result is printed on the same line.
        sys.stdout.write("Running %s ... " % type(self).__name__)
        sys.stdout.flush()

        if not self.early_applicable():
            print("not_applicable")
            return "not_applicable"

        self.start = time.time()

        # NOTE(review): classSetup() runs outside the try below, so a failure
        # starting the target/server propagates to the caller, and the
        # "exception" value compile() returns on a build failure is ignored.
        self.classSetup()

        try:
            self.setup()
            result = self.test()  # pylint: disable=no-member
        except TestNotApplicable:
            result = "not_applicable"
        except Exception as e:  # pylint: disable=broad-except
            if isinstance(e, TestFailed):
                result = "fail"
            else:
                result = "exception"
            print("%s in %.2fs" % (result, time.time() - self.start))
            print("=" * 40)
            if isinstance(e, TestFailed):
                header("Message")
                print(e.message)
            header("Traceback")
            traceback.print_exc(file=sys.stdout)
            for log in self.logs:
                header(log)
                print(self._read_log(log))
            print("/" * 40)
            return result

        finally:
            self.classTeardown()

        if not result:
            result = 'pass'
        print("%s in %.2fs" % (result, time.time() - self.start))
        return result
452
class TestFailed(Exception):
    """Raised when a test check fails; the reason is in self.message."""
    def __init__(self, message):
        # Also pass the message to Exception so str(e) and printed
        # tracebacks show it.
        Exception.__init__(self, message)
        self.message = message
457
class TestNotApplicable(Exception):
    """Raised when a test cannot run on this target; reason in self.message."""
    def __init__(self, message):
        # Also pass the message to Exception so str(e) and printed
        # tracebacks show it.
        Exception.__init__(self, message)
        self.message = message
462
def assertEqual(a, b):
    """Raise TestFailed unless *a* and *b* compare equal (tested with !=)."""
    if a != b:
        failure = "%r != %r" % (a, b)
        raise TestFailed(failure)
466
def assertNotEqual(a, b):
    """Raise TestFailed if *a* and *b* compare equal (tested with ==)."""
    if not a == b:
        return
    raise TestFailed("%r == %r" % (a, b))
470
def assertIn(a, b):
    """Raise TestFailed unless *a* is contained in *b*."""
    if a in b:
        return
    raise TestFailed("%r not in %r" % (a, b))
474
def assertNotIn(a, b):
    """Raise TestFailed if *a* is contained in *b*."""
    if a not in b:
        return
    raise TestFailed("%r in %r" % (a, b))
478
def assertGreater(a, b):
    """Raise TestFailed unless *a* > *b*."""
    if a > b:
        return
    raise TestFailed("%r not greater than %r" % (a, b))
482
def assertLess(a, b):
    """Raise TestFailed unless *a* < *b*."""
    if a < b:
        return
    raise TestFailed("%r not less than %r" % (a, b))
486
def assertTrue(a):
    """Raise TestFailed unless *a* is truthy."""
    if not a:
        # Wrap a in a 1-tuple: with a bare tuple argument (e.g. ()), "%"
        # would treat it as the argument list and raise TypeError instead.
        raise TestFailed("%r is not True" % (a,))
490
def assertRegexpMatches(text, regexp):
    """Raise TestFailed unless *regexp* matches somewhere in *text*."""
    found = re.search(regexp, text)
    if found:
        return
    raise TestFailed("can't find %r in %r" % (regexp, text))