76989806b8ff57278c341c0776a3bb754dd4384e
[riscv-tests.git] / debug / testlib.py
1 import os.path
2 import re
3 import shlex
4 import subprocess
5 import sys
6 import time
7 import traceback
8
9 import pexpect
10
11 # Note that gdb comes with its own testsuite. I was unable to figure out how to
12 # run that testsuite against the spike simulator.
13
def find_file(path):
    """Resolve *path* against the working directory, then this module's directory.

    Returns the first joined path that exists, or None when neither does.
    """
    search_dirs = (os.getcwd(), os.path.dirname(__file__))
    candidates = (os.path.join(base, path) for base in search_dirs)
    return next((c for c in candidates if os.path.exists(c)), None)
20
def compile(args, xlen=32):  # pylint: disable=redefined-builtin
    """Compile *args* with $RISCV/bin/riscv<xlen>-unknown-elf-gcc -g.

    Each argument that find_file() can locate is replaced by its full path;
    other arguments (flags, -o targets, ...) are passed through unchanged.
    The joined command line is run through the shell with os.system().

    Raises AssertionError if the compiler exits with a non-zero status.
    """
    cc = os.path.expandvars("$RISCV/bin/riscv%d-unknown-elf-gcc" % xlen)
    cmd = [cc, "-g"]
    for arg in args:
        cmd.append(find_file(arg) or arg)
    cmd = " ".join(cmd)
    result = os.system(cmd)
    # Raise explicitly: an `assert` statement is stripped under `python -O`,
    # which would let a failed build go unnoticed.
    if result != 0:
        raise AssertionError("%r failed" % cmd)
33
def unused_port():
    """Return a TCP port number that was free at the time of the call.

    The OS chooses the port when binding to port 0. The socket is closed
    before returning, so another process could still claim the port before
    the caller uses it.
    """
    # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python/2838309#2838309
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        # Close even when bind() fails, so the descriptor is never leaked.
        s.close()
42
class Spike(object):
    """Run the spike simulator as a child process.

    The child's stdout and stderr go to ``logname`` (spike.log). The first
    line of that file records the command that was run.
    """
    logname = "spike.log"

    def __init__(self, cmd, binary=None, halted=False, with_gdb=True,
                 timeout=None, xlen=64):
        """Launch spike and keep the process in self.process.

        cmd: spike command line as one string (split with shlex); a falsy
            value means plain "spike".
        binary: optional program appended after "pk".
        halted: when true, pass -H.
        with_gdb: when true, pick a free port with unused_port() and pass it
            via --gdb-port; the port is stored in self.port (None otherwise).
        timeout: when truthy, wrap the command in `timeout <timeout>`.
        xlen: 32 adds "--isa RV32"; other values add no --isa option.
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["spike"]
        if xlen == 32:
            cmd += ["--isa", "RV32"]

        if timeout:
            cmd = ["timeout", str(timeout)] + cmd

        if halted:
            cmd.append('-H')
        # Always define the attribute so callers can test it without
        # risking AttributeError when gdb support is disabled.
        self.port = None
        if with_gdb:
            self.port = unused_port()
            cmd += ['--gdb-port', str(self.port)]
        cmd.append("-m32")
        cmd.append('pk')
        if binary:
            cmd.append(binary)
        # The child gets its own copy of the descriptor, so the parent's
        # handle can be closed as soon as the process has been started.
        with open(self.logname, "w") as logfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)

    def __del__(self):
        # __init__ may have failed before the process was created.
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass

    def wait(self, *args, **kwargs):
        """Wait for the spike process; arguments go to Popen.wait()."""
        return self.process.wait(*args, **kwargs)
84
class VcsSim(object):
    """Run a VCS simulator binary (simv) with its JTAG VPI server enabled.

    Simulator output goes to ``logname`` (simv.log). The constructor blocks
    until the simulator logs "Listening on port N" and stores N in self.port.
    """
    logname = "simv.log"

    def __init__(self, simv=None, debug=False):
        """Start the simulator and wait for it to report its JTAG VPI port.

        simv: simulator command line as one string (split with shlex); a
            falsy value means "simv".
        debug: append "-debug" to the executable name and pass
            +vcdplusfile=output/gdbserver.vpd.

        Raises Exception if the simulator exits without reporting a port.
        """
        if simv:
            cmd = shlex.split(simv)
        else:
            cmd = ["simv"]
        cmd += ["+jtag_vpi_enable"]
        if debug:
            cmd[0] = cmd[0] + "-debug"
            cmd += ["+vcdplusfile=output/gdbserver.vpd"]
        with open(self.logname, "w") as logfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            with open(self.logname, "r") as listenfile:
                # Start reading after the command line written above, so
                # only the simulator's own output is scanned.
                listenfile.seek(0, 2)
                self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                                stdout=logfile, stderr=logfile)
                self.port = self._wait_for_port(listenfile)
        print("Using port %d for JTAG VPI" % self.port)

    def _wait_for_port(self, listenfile, poll_interval=0.1):
        """Return N from the first "Listening on port N" line in *listenfile*.

        A line delivered across several reads is reassembled before being
        matched. Raises Exception once the process has exited and all of its
        output has been consumed without finding the line.
        """
        pending = ""
        while True:
            # Sample exit status *before* reading: if the process had already
            # exited, everything it wrote is in the file for this read.
            exited = self.process.poll() is not None
            chunk = listenfile.readline()
            pending += chunk
            if pending.endswith("\n") or (exited and pending):
                line, pending = pending, ""
                match = re.match(r"^Listening on port (\d+)$", line)
                if match:
                    return int(match.group(1))
                continue
            if exited:
                raise Exception(
                    "simv exited before reporting its JTAG VPI port")
            if not chunk:
                time.sleep(poll_interval)

    def __del__(self):
        # __init__ may have failed before the process was created.
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
119
class Openocd(object):
    """Run OpenOCD as a gdb server; its output is logged to ``logname``."""
    logname = "openocd.log"

    def __init__(self, cmd=None, config=None, debug=False, otherProcess=None):
        """Start OpenOCD and wait until it has examined the RISC-V core.

        cmd: OpenOCD command line as one string (split with shlex); a falsy
            value means "openocd".
        config: config script passed with -f. It is located with find_file();
            when that fails, the path is passed through as given.
        debug: when true, pass -d.
        otherProcess: optional object whose ``port`` attribute is exported
            to OpenOCD as the JTAG_VPI_PORT environment variable (e.g. a
            VcsSim). When None, the variable is inherited from os.environ.

        The gdb port is chosen with unused_port() and stored in self.port.
        Raises Exception if OpenOCD exits before examining the core.
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["openocd"]
        if config:
            # Fall back to the raw path so OpenOCD reports a missing file
            # itself, instead of failing here on a None argument.
            cmd += ["-f", find_file(config) or config]
        if debug:
            cmd.append("-d")

        # Assign port
        self.port = unused_port()
        print("Using port %d for gdb server" % self.port)
        # This command needs to come before any config scripts on the command
        # line, since they are executed in order.
        cmd[1:1] = ["--command", "gdb_port %d" % self.port]

        env = os.environ.copy()
        if otherProcess is not None:
            env['JTAG_VPI_PORT'] = str(otherProcess.port)

        # The child keeps its own descriptor; the parent's handle can close.
        with open(self.logname, "w") as logfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile,
                                            env=env)

        self._wait_for_examine()

    def _wait_for_examine(self, poll_interval=0.1):
        """Block until the log contains "Examined RISCV core".

        Wait for OpenOCD to have made it through riscv_examine(). When using
        OpenOCD to communicate with a simulator this may take a long time,
        and gdb will time out when trying to connect if we attempt too early.
        Prints a progress note once after about a second. Raises Exception
        if OpenOCD exits first.
        """
        start = time.time()
        messaged = False
        while True:
            # Check for exit before reading the log, so output written just
            # before exiting is still seen by the read below.
            exited = self.process.poll() is not None
            with open(self.logname) as log:
                examined = "Examined RISCV core" in log.read()
            if examined:
                return
            if exited:
                raise Exception(
                    "OpenOCD exited before completing riscv_examine()")
            if not messaged and time.time() - start > 1:
                messaged = True
                print("Waiting for OpenOCD to examine RISCV core...")
            # Avoid spinning on the log file at full CPU.
            time.sleep(poll_interval)

    def __del__(self):
        # __init__ may have failed before the process was created.
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
171
class OpenocdCli(object):
    """Interactive session with OpenOCD's command interface over telnet.

    The session is piped through `tee` into openocd-cli.log.
    """
    def __init__(self, port=4444):
        shell_cmd = "telnet localhost %d | tee openocd-cli.log" % port
        self.child = pexpect.spawn("sh -c '%s'" % shell_cmd)
        # Block until the first "> " prompt appears.
        self.child.expect("> ")

    def command(self, cmd):
        """Send *cmd* and return, stripped, the text preceding the next prompt."""
        child = self.child
        child.sendline(cmd)
        # Skip past the first newline (presumably the echoed command) before
        # waiting for the prompt that ends the command's output.
        child.expect("\n")
        child.expect("> ")
        return child.before.strip()
182
class CannotAccess(Exception):
    """Raised when gdb reports "Cannot access memory at address ..."."""
    def __init__(self, address):
        try:
            shown = "%#x" % address
        except TypeError:
            shown = repr(address)
        # Give the base class a message so str(exc) and printed tracebacks
        # show the address instead of being blank.
        Exception.__init__(self, "Cannot access memory at address %s" % shown)
        self.address = address
187
class Gdb(object):
    """Drive an interactive gdb session through pexpect.

    The session transcript is written to gdb.log.
    """
    def __init__(self,
                 cmd=os.path.expandvars("$RISCV/bin/riscv64-unknown-elf-gdb")):
        """Spawn *cmd* and apply settings that keep gdb's output easy to parse."""
        self.child = pexpect.spawn(cmd)
        self.child.logfile = open("gdb.log", "w")
        self.child.logfile.write("+ %s\n" % cmd)
        self.wait()
        self.command("set confirm off")
        self.command("set width 0")
        self.command("set height 0")
        # Force consistency.
        self.command("set print entry-values no")

    def wait(self):
        """Wait for prompt."""
        self.child.expect(r"\(gdb\)")

    def command(self, command, timeout=-1):
        """Send *command*; return the stripped output before the next prompt.

        *timeout* is passed to pexpect's expect() (-1 selects the child's
        default timeout).
        """
        self.child.sendline(command)
        self.child.expect("\n", timeout=timeout)
        self.child.expect(r"\(gdb\)", timeout=timeout)
        return self.child.before.strip()

    def c(self, wait=True):
        """Continue execution.

        With wait=True, block until gdb prompts again and return its output.
        Otherwise return (None) as soon as gdb prints "Continuing".
        """
        if wait:
            output = self.command("c")
            assert "Continuing" in output
            return output
        else:
            self.child.sendline("c")
            self.child.expect("Continuing")

    def interrupt(self):
        """Send Ctrl-C and return the output up to the next prompt (60 s limit)."""
        self.child.send("\003")
        self.child.expect(r"\(gdb\)", timeout=60)
        return self.child.before.strip()

    def x(self, address, size='w'):
        """Examine memory at *address* with x/<size>; return the value as an int."""
        output = self.command("x/%s %s" % (size, address))
        # The value follows the last ':'; splitting from the right keeps
        # symbol names that contain ':' (e.g. "<ns::f>") from breaking this.
        value = int(output.rsplit(':', 1)[1].strip(), 0)
        return value

    def p(self, obj):
        """Evaluate *obj* with p/x and return it as an int.

        Raises CannotAccess if gdb reports the memory as inaccessible.
        """
        output = self.command("p/x %s" % obj)
        m = re.search("Cannot access memory at address (0x[0-9a-f]+)", output)
        if m:
            raise CannotAccess(int(m.group(1), 0))
        value = int(output.split('=')[-1].strip(), 0)
        return value

    def p_string(self, obj):
        """Print the string *obj* with gdb's `p` and return its contents.

        Only the first '=' (after "$N") separates the value, so strings that
        themselves contain '=' are handled; the quoted string is the last
        shell-style token, which skips a pointer and any "<symbol>" before it.
        """
        output = self.command("p %s" % obj)
        value = shlex.split(output.partition('=')[2].strip())[-1]
        return value

    def stepi(self):
        """Execute one instruction and return gdb's output."""
        output = self.command("stepi")
        return output

    def load(self):
        """Run `load` (60 s limit) and check it reported a transfer rate."""
        output = self.command("load", timeout=60)
        assert "failed" not in output
        assert "Transfer rate" in output

    def b(self, location):
        """Set a breakpoint at *location* and return gdb's output."""
        output = self.command("b %s" % location)
        assert "not defined" not in output
        assert "Breakpoint" in output
        return output

    def hbreak(self, location):
        """Set a hardware breakpoint at *location* and return gdb's output."""
        output = self.command("hbreak %s" % location)
        assert "not defined" not in output
        assert "Hardware assisted breakpoint" in output
        return output
263
def run_all_tests(module, target, tests, fail_fast):
    """Run every test class found in *module* against *target*.

    A test class is any class in dir(module) that has a ``test`` attribute.
    When *tests* is non-empty, only classes whose name contains one of its
    strings are run. Each class is instantiated with *target* and its run()
    result is tallied. With *fail_fast*, the first result other than "pass"
    or "not_applicable" stops the loop.

    Prints a summary and returns 0 if every result was "pass" or
    "not_applicable", else 1.
    """
    good_results = set(('pass', 'not_applicable'))

    start = time.time()

    results = {}
    count = 0
    for name in dir(module):
        definition = getattr(module, name)
        if not (isinstance(definition, type) and hasattr(definition, 'test')):
            continue
        if tests and not any(test in name for test in tests):
            continue
        result = definition(target).run()
        results.setdefault(result, []).append(name)
        count += 1
        if result not in good_results and fail_fast:
            break

    header("ran %d tests in %.0fs" % (count, time.time() - start), dash=':')

    result = 0
    # Sort for a stable summary; key=str tolerates non-string results that
    # test() may return.
    for key in sorted(results, key=str):
        value = results[key]
        print("%d tests returned %s" % (len(value), key))
        if key not in good_results:
            result = 1
            for test in value:
                print("  %s" % test)

    return result
293
def add_test_run_options(parser):
    """Register --fail-fast/-f and the positional test-name filters on *parser*."""
    options = (
        (("--fail-fast", "-f"),
         dict(action="store_true", help="Exit as soon as any test fails.")),
        (("test",),
         dict(nargs='*', help="Run only tests that are named here.")),
    )
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)
299
def header(title, dash='-', width=36):
    """Print "[ title ]" framed by a run of *dash* characters.

    The run is ``dash * (width - len(title))`` (empty when the title is too
    long), split in half with any odd extra character placed after the title.
    """
    dashes = dash * (width - len(title))
    # Floor division keeps the slice index an int on Python 3 as well.
    half = len(dashes) // 2
    print("%s[ %s ]%s" % (dashes[:half], title, dashes[half:]))
305
class BaseTest(object):
    """Common driver for debug tests.

    Subclasses implement test() and may set ``compile_args``, a tuple of
    arguments for ``target.compile()``. run() compiles, starts the target
    process and server, runs the test and reports the outcome.
    """
    # Compiled binaries keyed by compile_args, shared by every test so each
    # distinct program is built only once per process.
    compiled = {}
    # Log files dumped when a test fails. Subclasses may pre-populate this at
    # class level; each instance works on its own copy (made in __init__).
    logs = []

    def __init__(self, target):
        self.target = target
        self.server = None
        self.target_process = None
        self.binary = None
        self.start = 0
        # Copy, so classSetup() appending the server's log does not grow the
        # list shared by every test (which duplicated entries in failure dumps).
        self.logs = list(self.logs)

    def early_applicable(self):
        """Return a false value if the test has determined it cannot run
        without ever needing to talk to the target or server."""
        # pylint: disable=no-self-use
        return True

    def setup(self):
        """Per-test hook called after classSetup(); does nothing by default."""
        pass

    def compile(self):
        """Build ``compile_args`` (if set) and store the result in self.binary.

        Results are cached in BaseTest.compiled. Returns "exception" after
        printing a traceback if target.compile() raised, otherwise None.
        """
        compile_args = getattr(self, 'compile_args', None)
        if not compile_args:
            return None
        if compile_args not in BaseTest.compiled:
            try:
                # pylint: disable=star-args
                BaseTest.compiled[compile_args] = \
                    self.target.compile(*compile_args)
            except Exception:  # pylint: disable=broad-except
                print("exception while compiling in %.2fs" % (
                    time.time() - self.start))
                print("=" * 40)
                header("Traceback")
                traceback.print_exc(file=sys.stdout)
                print("/" * 40)
                return "exception"
        self.binary = BaseTest.compiled.get(compile_args)
        return None

    def classSetup(self):
        """Compile, then start the target process and the server.

        Returns the compile() result. A truthy value ("exception") means the
        build failed and nothing was started.
        """
        result = self.compile()
        if result:
            return result
        self.target_process = self.target.target()
        self.server = self.target.server()
        self.logs.append(self.server.logname)
        return None

    def classTeardown(self):
        """Drop the references to the server and the target process.

        Releasing the last reference normally triggers __del__, which for the
        process wrappers in this file kills the child process.
        """
        del self.server
        del self.target_process

    def run(self):
        """
        If compile_args is set, compile a program and set self.binary.

        Call setup().

        Then call test() and return the result, displaying relevant information
        if an exception is raised.

        Returns "not_applicable" when early_applicable() is false,
        "exception" when compiling or setup/test raised (other than
        TestFailed), "fail" for TestFailed, "pass" when test() returned a
        falsy value, and otherwise whatever test() returned.
        """
        sys.stdout.write("Running %s ... " % type(self).__name__)
        sys.stdout.flush()

        if not self.early_applicable():
            print("not_applicable")
            return "not_applicable"

        self.start = time.time()

        try:
            # Inside the try so that a failure to start the target or server
            # is reported like any other exception and still torn down.
            setup_result = self.classSetup()
            if setup_result:
                # compile() has already printed the failure details.
                return setup_result
            self.setup()
            result = self.test()  # pylint: disable=no-member
        except Exception as e:  # pylint: disable=broad-except
            if isinstance(e, TestFailed):
                result = "fail"
            else:
                result = "exception"
            self._report_exception(result, e)
            return result
        finally:
            self.classTeardown()

        if not result:
            result = 'pass'
        print("%s in %.2fs" % (result, time.time() - self.start))
        return result

    def _report_exception(self, result, exc):
        """Print the result, message, traceback and logs; call from an except block."""
        print("%s in %.2fs" % (result, time.time() - self.start))
        print("=" * 40)
        if isinstance(exc, TestFailed):
            header("Message")
            print(exc.message)
        header("Traceback")
        traceback.print_exc(file=sys.stdout)
        self._dump_logs()
        print("/" * 40)

    def _dump_logs(self):
        """Print each file listed in self.logs under its own header."""
        for log in self.logs:
            header(log)
            try:
                with open(log, "r") as handle:
                    print(handle.read())
            except IOError as err:
                # A missing log must not hide the failure being reported.
                print("(unable to read %s: %s)" % (log, err))
403
class TestFailed(Exception):
    """Raised by the assert* helpers; BaseTest.run() reports it as "fail"."""
    def __init__(self, message):
        # Pass the message to Exception as well, so str(exc) and printed
        # tracebacks include it instead of being blank.
        Exception.__init__(self, message)
        self.message = message
408
def assertEqual(a, b):
    """Fail the current test (TestFailed) when ``a != b``."""
    if not a != b:
        return
    raise TestFailed("%r != %r" % (a, b))
412
def assertNotEqual(a, b):
    """Fail the current test (TestFailed) when ``a == b``."""
    if not a == b:
        return
    raise TestFailed("%r == %r" % (a, b))
416
def assertIn(a, b):
    """Fail the current test (TestFailed) unless *a* is contained in *b*."""
    if a in b:
        return
    raise TestFailed("%r not in %r" % (a, b))
420
def assertNotIn(a, b):
    """Fail the current test (TestFailed) when *a* is contained in *b*."""
    if a not in b:
        return
    raise TestFailed("%r in %r" % (a, b))
424
def assertGreater(a, b):
    """Fail the current test (TestFailed) unless ``a > b``."""
    if a > b:
        return
    raise TestFailed("%r not greater than %r" % (a, b))
428
def assertTrue(a):
    """Fail the current test (TestFailed) when *a* is falsy."""
    if not a:
        # Wrap *a* in a 1-tuple: a bare tuple argument (e.g. ()) would be
        # taken as the format-argument list and raise TypeError instead.
        raise TestFailed("%r is not True" % (a,))
432
def assertRegexpMatches(text, regexp):
    """Fail the current test (TestFailed) unless re.search finds *regexp* in *text*."""
    found = re.search(regexp, text)
    if found is not None:
        return
    raise TestFailed("can't find %r in %r" % (regexp, text))