Merge pull request #35 from richardxia/have-openocd-pick-gdb-server-port
[riscv-tests.git] / debug / testlib.py
1 import os.path
2 import re
3 import shlex
4 import subprocess
5 import sys
6 import time
7 import traceback
8
9 import pexpect
10
11 # Note that gdb comes with its own testsuite. I was unable to figure out how to
12 # run that testsuite against the spike simulator.
13
def find_file(path):
    """Resolve *path* against the working directory, then this module's directory.

    Returns the first joined path that exists on disk, or None if *path*
    is found in neither location.
    """
    search_dirs = (os.getcwd(), os.path.dirname(__file__))
    candidates = (os.path.join(base, path) for base in search_dirs)
    return next((candidate for candidate in candidates
                 if os.path.exists(candidate)), None)
20
def compile(args, xlen=32): # pylint: disable=redefined-builtin
    """Compile *args* with the RISC-V cross gcc for the given *xlen*.

    The compiler is $RISCV/bin/riscv<xlen>-unknown-elf-gcc and always gets
    "-g". Each argument that find_file() can locate is replaced by its full
    path; every other argument is passed through unchanged.

    Raises:
        AssertionError: if the compiler cannot be run or exits non-zero.
    """
    cc = os.path.expandvars("$RISCV/bin/riscv%d-unknown-elf-gcc" % xlen)
    cmd = [cc, "-g"]
    for arg in args:
        # Prefer a resolved file path; fall back to the raw argument.
        cmd.append(find_file(arg) or arg)
    # Run the argv list directly (no shell) so paths/arguments containing
    # spaces or shell metacharacters are passed through intact.
    try:
        result = subprocess.call(cmd)
    except OSError as e:  # e.g. the compiler binary does not exist
        raise AssertionError("%r failed: %s" % (" ".join(cmd), e))
    if result != 0:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise AssertionError("%r failed" % " ".join(cmd))
33
def unused_port():
    """Return a TCP port number that was free at the time of the call.

    Binds a throwaway socket to port 0 so the OS picks a free port, reads
    the chosen port, and closes the socket. The port is not reserved, so
    another process could claim it before the caller uses it.
    """
    # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python/2838309#2838309
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        # Close even if bind()/getsockname() raises so the fd is not leaked.
        s.close()
42
class Spike(object):
    """Runs the spike simulator as a background process logging to spike.log."""
    logname = "spike.log"

    def __init__(self, cmd, binary=None, halted=False, with_gdb=True,
            timeout=None, xlen=64):
        """Launch spike in the background.

        Sets self.process to the Popen object and, when with_gdb is true,
        self.port to the port passed via --gdb-port. spike's stdout and
        stderr go to self.logname, whose first line is the command run.
        Nothing is returned.
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["spike"]
        if xlen == 32:
            cmd += ["--isa", "RV32"]

        if timeout:
            # Prefix with the `timeout` command to bound spike's runtime.
            cmd = ["timeout", str(timeout)] + cmd

        if halted:
            cmd.append('-H')
        if with_gdb:
            self.port = unused_port()
            cmd += ['--gdb-port', str(self.port)]
        # NOTE(review): presumably spike's memory-size option (32 MiB) — confirm.
        cmd.append("-m32")
        cmd.append('pk')
        if binary:
            cmd.append(binary)
        logfile = open(self.logname, "w")
        try:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                    stdout=logfile, stderr=logfile)
        finally:
            # The child holds its own copy of the descriptor.
            logfile.close()

    def __del__(self):
        # __init__ may have failed before the process was started.
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass

    def wait(self, *args, **kwargs):
        """Wait for the spike process to exit and return its exit status."""
        return self.process.wait(*args, **kwargs)
84
class VcsSim(object):
    """Runs a VCS simulation binary (simv) with JTAG VPI enabled.

    After construction self.port holds the port simv reported in its
    "Listening on port N" log line, and os.environ['JTAG_VPI_PORT'] is set
    to the same value. simv's output goes to simv.log.
    """

    def __init__(self, simv=None, debug=False):
        if simv:
            cmd = shlex.split(simv)
        else:
            cmd = ["simv"]
        cmd += ["+jtag_vpi_enable"]
        if debug:
            # Use the "-debug" variant of the binary and request a VPD dump.
            cmd[0] = cmd[0] + "-debug"
            cmd += ["+vcdplusfile=output/gdbserver.vpd"]
        logfile = open("simv.log", "w")
        logfile.write("+ %s\n" % " ".join(cmd))
        logfile.flush()
        listenfile = open("simv.log", "r")
        # Skip the command line written above; scan only simv's own output.
        listenfile.seek(0, 2)
        try:
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                    stdout=logfile, stderr=logfile)
            self.port = self._wait_for_port(listenfile)
        finally:
            listenfile.close()
            logfile.close()
        os.environ['JTAG_VPI_PORT'] = str(self.port)

    def _wait_for_port(self, listenfile, poll_interval=0.1):
        """Return the port from simv's "Listening on port N" line.

        Only complete (newline-terminated) lines are matched, so a line simv
        has only partly written is never parsed into a truncated port.

        Raises:
            Exception: if simv exits without reporting a port.
        """
        pending = ""
        while True:
            # Sample exit status before reading: if simv had already exited,
            # an empty read below means all of its output has been consumed.
            exited = self.process.poll() is not None
            chunk = listenfile.readline()
            if not chunk:
                if exited:
                    raise Exception(
                        "simv exited before reporting a listening port")
                time.sleep(poll_interval)
                continue
            pending += chunk
            if not pending.endswith("\n"):
                continue  # partial line; wait for the rest of it
            line, pending = pending, ""
            match = re.match(r"^Listening on port (\d+)$", line)
            if match:
                return int(match.group(1))

    def __del__(self):
        # __init__ may have failed before the process was started.
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
119
class Openocd(object):
    """Runs OpenOCD with its gdb server bound to an OS-chosen port.

    After construction self.process is the OpenOCD Popen object and
    self.port the port its gdb server listens on. OpenOCD's stdout and
    stderr go to openocd.log.
    """
    logname = "openocd.log"

    def __init__(self, cmd=None, config=None, debug=False):
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["openocd"]
        if config:
            config_path = find_file(config)
            if config_path is None:
                # Fail clearly instead of handing None to Popen.
                raise Exception("Cannot find OpenOCD config file %r" % config)
            cmd += ["-f", config_path]
        if debug:
            cmd.append("-d")

        # This command needs to come before any config scripts on the command
        # line, since they are executed in order.
        # Tell OpenOCD to bind to an unused port.
        cmd[1:1] = ["--command", "gdb_port %d" % 0]

        logfile = open(Openocd.logname, "w")
        try:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                    stdout=logfile, stderr=logfile)
        finally:
            # The child holds its own copy of the descriptor.
            logfile.close()

        self._wait_for_examine()
        self.port = self._get_gdb_server_port()

    def _wait_for_examine(self, poll_interval=0.1):
        """Block until openocd.log shows OpenOCD made it through riscv_examine().

        When OpenOCD talks to a simulator this may take a long time, and gdb
        will time out when trying to connect if we attempt too early.

        Raises:
            Exception: if OpenOCD exits first.
        """
        start = time.time()
        messaged = False
        while True:
            with open(Openocd.logname) as log:
                if "Examined RISCV core" in log.read():
                    return
            if self.process.poll() is not None:
                raise Exception(
                    "OpenOCD exited before completing riscv_examine()")
            if not messaged and time.time() - start > 1:
                messaged = True
                print("Waiting for OpenOCD to examine RISCV core...")
            # Sleep between polls instead of spinning a CPU core.
            time.sleep(poll_interval)

    def _get_gdb_server_port(self):
        """Get port that OpenOCD's gdb server is listening on.

        Polls `lsof` for TCP sockets in LISTEN state owned by the OpenOCD
        process, up to MAX_ATTEMPTS times, 0.1 s apart.

        Raises:
            Exception: if more than one listening port is found, or none
                appears in time.
        """
        MAX_ATTEMPTS = 50
        PORT_REGEX = re.compile(r'(?P<port>\d+) \(LISTEN\)')
        # NOTE(review): without `-P`, lsof may print well-known ports (e.g.
        # OpenOCD's telnet/tcl ports) as service names, which this regex
        # skips — confirm whether that is relied on.
        for _ in range(MAX_ATTEMPTS):
            with open(os.devnull, 'w') as devnull:
                try:
                    output = subprocess.check_output([
                        'lsof',
                        '-a',  # Take the AND of the following selectors
                        '-p{}'.format(self.process.pid),  # Filter on PID
                        '-iTCP',  # Filter only TCP sockets
                    ], stderr=devnull)
                except subprocess.CalledProcessError:
                    # lsof exits non-zero when nothing matches, e.g. before
                    # the gdb server has bound its socket; just retry.
                    output = ""
            if not isinstance(output, str):
                output = output.decode("utf-8", "replace")
            matches = list(PORT_REGEX.finditer(output))
            if len(matches) > 1:
                raise Exception(
                    "OpenOCD listening on multiple ports. Cannot uniquely "
                    "identify gdb server port.")
            elif matches:
                [match] = matches
                return int(match.group('port'))
            time.sleep(0.1)
        raise Exception("Timed out waiting for gdb server to obtain port.")

    def __del__(self):
        # __init__ may have failed before the process was started.
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
191
class OpenocdCli(object):
    """Drives OpenOCD's telnet command interface through pexpect.

    The telnet session is piped through `tee` into openocd-cli.log.
    """
    def __init__(self, port=4444):
        self.child = pexpect.spawn(
                "sh -c 'telnet localhost %d | tee openocd-cli.log'" % port)
        self.child.expect("> ")

    def command(self, cmd):
        """Send *cmd* and return its output up to the next "> " prompt.

        Surrounding tabs, CR/LF, spaces and NULs are stripped.
        """
        self.child.sendline(cmd)
        # Match the echoed command literally: it may contain regex
        # metacharacters such as '.', '(', '+' or '['.
        self.child.expect_exact(cmd)
        self.child.expect("\n")
        self.child.expect("> ")
        return self.child.before.strip("\t\r\n \0")

    def reg(self, reg=''):
        """Read registers with OpenOCD's `reg` command.

        Returns a dict mapping register name to int value, or only the value
        of *reg* when one is named (KeyError if it is not in the output).
        """
        output = self.command("reg %s" % reg)
        # Accept either hex-digit case so values are never truncated.
        matches = re.findall(r"(\w+) \(/\d+\): (0x[0-9A-Fa-f]+)", output)
        values = {r: int(v, 0) for r, v in matches}
        if reg:
            return values[reg]
        return values

    def load_image(self, image):
        """Run `load_image` for *image*.

        Raises TestNotApplicable if OpenOCD reports that only 32-bit ELF
        files are supported.
        """
        output = self.command("load_image %s" % image)
        if 'invalid ELF file, only 32bits files are supported' in output:
            raise TestNotApplicable(output)
217
class CannotAccess(Exception):
    """Raised when gdb reports it cannot access memory at an address."""
    def __init__(self, address):
        # Pass the address to the base class so str(e)/tracebacks show it.
        Exception.__init__(self, address)
        self.address = address
222
class Gdb(object):
    """A gdb session driven through pexpect, transcribed to gdb.log.

    Commands are sent one line at a time; their output is whatever gdb
    prints before the next "(gdb)" prompt, stripped.
    """

    # gdb's report when an expression touches memory it cannot read.
    _CANNOT_ACCESS_PATTERN = "Cannot access memory at address (0x[0-9a-f]+)"

    def __init__(self,
            cmd=os.path.expandvars("$RISCV/bin/riscv64-unknown-elf-gdb")):
        child = pexpect.spawn(cmd)
        self.child = child
        child.logfile = open("gdb.log", "w")
        child.logfile.write("+ %s\n" % cmd)
        self.wait()
        startup_commands = (
            "set confirm off",
            "set width 0",
            "set height 0",
            # Force consistency.
            "set print entry-values no",
        )
        for startup_command in startup_commands:
            self.command(startup_command)

    def wait(self):
        """Block until gdb shows its "(gdb)" prompt."""
        self.child.expect(r"\(gdb\)")

    def command(self, command, timeout=-1):
        """Send *command* and return gdb's stripped output for it.

        The first newline matched presumably ends the echoed command line;
        the output is the text before the following prompt.
        """
        child = self.child
        child.sendline(command)
        child.expect("\n", timeout=timeout)
        child.expect(r"\(gdb\)", timeout=timeout)
        return child.before.strip()

    def c(self, wait=True):
        """Continue the target.

        With *wait*, block until the next prompt and return the output.
        Otherwise return (None) as soon as gdb prints "Continuing".
        """
        if not wait:
            self.child.sendline("c")
            self.child.expect("Continuing")
            return None
        output = self.command("c")
        assert "Continuing" in output
        return output

    def interrupt(self):
        """Send Ctrl-C and return the output shown before the next prompt."""
        child = self.child
        child.send("\003")
        child.expect(r"\(gdb\)", timeout=60)
        return child.before.strip()

    def x(self, address, size='w'):
        """Examine one *size* unit at *address*; return it as an int."""
        fields = self.command("x/%s %s" % (size, address)).split(':')
        return int(fields[1].strip(), 0)

    def _raise_if_inaccessible(self, output):
        """Raise CannotAccess when *output* reports unreadable memory."""
        found = re.search(self._CANNOT_ACCESS_PATTERN, output)
        if found:
            raise CannotAccess(int(found.group(1), 0))

    def p_raw(self, obj):
        """Print *obj*; return the stripped text after the last '='."""
        output = self.command("p %s" % obj)
        self._raise_if_inaccessible(output)
        return output.split('=')[-1].strip()

    def p(self, obj):
        """Print *obj* in hex and return its value as an int."""
        output = self.command("p/x %s" % obj)
        self._raise_if_inaccessible(output)
        return int(output.split('=')[-1].strip(), 0)

    def p_string(self, obj):
        """Print *obj*; return the second shell-style token of its value.

        For gdb's `0xADDR "text"` rendering of a char pointer, that token
        is the quoted text.
        """
        rhs = self.command("p %s" % obj).split('=')[-1].strip()
        return shlex.split(rhs)[1]

    def stepi(self):
        """Step one instruction and return gdb's output."""
        return self.command("stepi")

    def load(self):
        """Run gdb's `load` (60 s timeout) and assert it reported success."""
        output = self.command("load", timeout=60)
        assert "failed" not in output
        assert "Transfer rate" in output

    def b(self, location):
        """Set a breakpoint at *location*; return gdb's confirmation."""
        output = self.command("b %s" % location)
        assert "not defined" not in output
        assert "Breakpoint" in output
        return output

    def hbreak(self, location):
        """Set a hardware breakpoint at *location*; return gdb's confirmation."""
        output = self.command("hbreak %s" % location)
        assert "not defined" not in output
        assert "Hardware assisted breakpoint" in output
        return output
305
def run_all_tests(module, target, tests, fail_fast):
    """Run every test class in *module* against *target* and print a summary.

    A test class is any class in *module* that has a `test` attribute. When
    *tests* is non-empty, only classes whose name contains one of its
    entries as a substring are run. Each class is instantiated with
    *target* and its run() result is tallied. With *fail_fast*, stop after
    the first result other than 'pass' or 'not_applicable'.

    Returns 0 if every result was 'pass' or 'not_applicable', else 1.
    """
    good_results = set(('pass', 'not_applicable'))

    start = time.time()

    results = {}
    count = 0
    for name in dir(module):
        definition = getattr(module, name)
        if not (isinstance(definition, type) and hasattr(definition, 'test')):
            continue
        if tests and not any(test in name for test in tests):
            continue
        result = definition(target).run()
        results.setdefault(result, []).append(name)
        count += 1
        if fail_fast and result not in good_results:
            break

    header("ran %d tests in %.0fs" % (count, time.time() - start), dash=':')

    result = 0
    # Sort (by string form, in case a test returns a non-string) so the
    # summary order is stable from run to run.
    for key, names in sorted(results.items(), key=lambda kv: str(kv[0])):
        print("%d tests returned %s" % (len(names), key))
        if key not in good_results:
            result = 1
            for test in names:
                print("  %s" % test)

    return result
335
def add_test_run_options(parser):
    """Register the test-selection options on an argparse *parser*.

    Adds a --fail-fast/-f flag and a positional list of test-name filters.
    """
    add = parser.add_argument
    add("--fail-fast", "-f",
        help="Exit as soon as any test fails.",
        action="store_true")
    add("test",
        help="Run only tests that are named here.",
        nargs='*')
341
def header(title, dash='-', width=36):
    """Print *title* as a banner such as ``----[ title ]-----``.

    The title is padded with ``width - len(title)`` copies of *dash*, split
    around it with any odd copy placed after the title. No padding is
    added when the title is at least *width* characters long.
    """
    dashes = dash * (width - len(title))
    # Floor division keeps the slice index an int on Python 3 as well.
    half = len(dashes) // 2
    print("%s[ %s ]%s" % (dashes[:half], title, dashes[half:]))
347
348 class BaseTest(object):
349 compiled = {}
350 logs = []
351
352 def __init__(self, target):
353 self.target = target
354 self.server = None
355 self.target_process = None
356 self.binary = None
357 self.start = 0
358
359 def early_applicable(self):
360 """Return a false value if the test has determined it cannot run
361 without ever needing to talk to the target or server."""
362 # pylint: disable=no-self-use
363 return True
364
365 def setup(self):
366 pass
367
368 def compile(self):
369 compile_args = getattr(self, 'compile_args', None)
370 if compile_args:
371 if compile_args not in BaseTest.compiled:
372 try:
373 # pylint: disable=star-args
374 BaseTest.compiled[compile_args] = \
375 self.target.compile(*compile_args)
376 except Exception: # pylint: disable=broad-except
377 print "exception while compiling in %.2fs" % (
378 time.time() - self.start)
379 print "=" * 40
380 header("Traceback")
381 traceback.print_exc(file=sys.stdout)
382 print "/" * 40
383 return "exception"
384 self.binary = BaseTest.compiled.get(compile_args)
385
386 def classSetup(self):
387 self.compile()
388 self.target_process = self.target.target()
389 self.server = self.target.server()
390 self.logs.append(self.server.logname)
391
392 def classTeardown(self):
393 del self.server
394 del self.target_process
395
396 def run(self):
397 """
398 If compile_args is set, compile a program and set self.binary.
399
400 Call setup().
401
402 Then call test() and return the result, displaying relevant information
403 if an exception is raised.
404 """
405
406 print "Running", type(self).__name__, "...",
407 sys.stdout.flush()
408
409 if not self.early_applicable():
410 print "not_applicable"
411 return "not_applicable"
412
413 self.start = time.time()
414
415 self.classSetup()
416
417 try:
418 self.setup()
419 result = self.test() # pylint: disable=no-member
420 except TestNotApplicable:
421 result = "not_applicable"
422 except Exception as e: # pylint: disable=broad-except
423 if isinstance(e, TestFailed):
424 result = "fail"
425 else:
426 result = "exception"
427 print "%s in %.2fs" % (result, time.time() - self.start)
428 print "=" * 40
429 if isinstance(e, TestFailed):
430 header("Message")
431 print e.message
432 header("Traceback")
433 traceback.print_exc(file=sys.stdout)
434 for log in self.logs:
435 header(log)
436 print open(log, "r").read()
437 print "/" * 40
438 return result
439
440 finally:
441 self.classTeardown()
442
443 if not result:
444 result = 'pass'
445 print "%s in %.2fs" % (result, time.time() - self.start)
446 return result
447
class TestFailed(Exception):
    """Raised to report a test failure; BaseTest.run() prints .message."""
    def __init__(self, message):
        # Also hand the message to the base class so str(e) and printed
        # tracebacks include it instead of being empty.
        Exception.__init__(self, message)
        self.message = message
452
class TestNotApplicable(Exception):
    """Raised to mark a test as not applicable; run() reports 'not_applicable'."""
    def __init__(self, message):
        # Also hand the message to the base class so str(e) and printed
        # tracebacks include it instead of being empty.
        Exception.__init__(self, message)
        self.message = message
457
def assertEqual(a, b):
    """Raise TestFailed when ``a != b``, showing the repr of both values."""
    if a != b:
        detail = "%r != %r" % (a, b)
        raise TestFailed(detail)
461
def assertNotEqual(a, b):
    """Raise TestFailed when ``a == b``, showing the repr of both values."""
    if a == b:
        detail = "%r == %r" % (a, b)
        raise TestFailed(detail)
465
def assertIn(a, b):
    """Raise TestFailed unless *a* is a member of *b*."""
    if a not in b:
        detail = "%r not in %r" % (a, b)
        raise TestFailed(detail)
469
def assertNotIn(a, b):
    """Raise TestFailed if *a* is a member of *b*."""
    if a in b:
        detail = "%r in %r" % (a, b)
        raise TestFailed(detail)
473
def assertGreater(a, b):
    """Raise TestFailed unless ``a > b`` holds."""
    if not (a > b):
        detail = "%r not greater than %r" % (a, b)
        raise TestFailed(detail)
477
def assertLess(a, b):
    """Raise TestFailed unless ``a < b`` holds."""
    if not (a < b):
        detail = "%r not less than %r" % (a, b)
        raise TestFailed(detail)
481
def assertTrue(a):
    """Raise TestFailed if *a* is falsy."""
    if not a:
        # Wrap in a 1-tuple: a bare tuple such as () on the right of % would
        # be taken as the argument list and raise TypeError instead.
        raise TestFailed("%r is not True" % (a,))
485
def assertRegexpMatches(text, regexp):
    """Raise TestFailed unless re.search() finds *regexp* somewhere in *text*."""
    found = re.search(regexp, text)
    if found is None:
        raise TestFailed("can't find %r in %r" % (regexp, text))