Fix conflict: JTAG VPI changes vs openocd refactor
[riscv-tests.git] / debug / testlib.py
1 import os.path
2 import re
3 import shlex
4 import subprocess
5 import sys
6 import time
7 import traceback
8
9 import pexpect
10
11 # Note that gdb comes with its own testsuite. I was unable to figure out how to
12 # run that testsuite against the spike simulator.
13
def find_file(path):
    """Return the full path of *path*, or None if it can't be found.

    The current working directory is tried first, then the directory that
    contains this module.
    """
    search_dirs = (os.getcwd(), os.path.dirname(__file__))
    candidates = (os.path.join(base, path) for base in search_dirs)
    return next((c for c in candidates if os.path.exists(c)), None)
20
def compile(args, xlen=32): # pylint: disable=redefined-builtin
    """Compile *args* with $RISCV/bin/riscv<xlen>-unknown-elf-gcc.

    ``-g`` is always passed. Each argument that names an existing file (as
    located by find_file()) is replaced by that full path; every other
    argument is passed through unchanged.

    Raises AssertionError if the compiler can't be started or exits with a
    non-zero status.
    """
    cc = os.path.expandvars("$RISCV/bin/riscv%d-unknown-elf-gcc" % xlen)
    cmd = [cc, "-g"]
    for arg in args:
        found = find_file(arg)
        cmd.append(found if found else arg)
    cmd_text = " ".join(cmd)
    # Run the compiler directly from an argument list (no shell), so paths
    # containing spaces or shell metacharacters stay single arguments.
    try:
        result = subprocess.call(cmd)
    except OSError as err:
        # Keep the historical AssertionError contract for a missing compiler.
        raise AssertionError("%r failed: %s" % (cmd_text, err))
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if result != 0:
        raise AssertionError("%r failed" % cmd_text)
33
def unused_port():
    """Return a TCP port number that was free when this function ran.

    The OS chooses the port (bind to port 0). The socket is closed before
    returning, so another process could still take the port before the
    caller uses it.
    """
    # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python/2838309#2838309
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        # Close even when bind()/getsockname() fails, so the descriptor
        # isn't leaked.
        s.close()
42
class Spike(object):
    """Runs the spike simulator as a child process.

    spike's stdout and stderr go to ``logname`` in the current directory;
    the first line written there is the command that was launched.
    """
    logname = "spike.log"

    def __init__(self, cmd, binary=None, halted=False, with_gdb=True,
                 timeout=None, xlen=64):
        """Launch spike.

        cmd: command line to use instead of plain ``spike`` (split with
            shlex).
        binary: program appended after ``pk``, if given.
        halted: pass ``-H``.
        with_gdb: choose a free TCP port, store it in self.port and pass it
            with ``--gdb-port``. When False, self.port is None.
        timeout: if set, run the command under ``timeout <timeout>``.
        xlen: 32 adds ``--isa RV32``.
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["spike"]
        if xlen == 32:
            cmd += ["--isa", "RV32"]

        if timeout:
            cmd = ["timeout", str(timeout)] + cmd

        if halted:
            cmd.append('-H')
        # Always define self.port so callers don't hit AttributeError when gdb
        # support is disabled.
        self.port = None
        if with_gdb:
            self.port = unused_port()
            cmd += ['--gdb-port', str(self.port)]
        # NOTE(review): presumably spike's memory-size option; confirm.
        cmd.append("-m32")
        cmd.append('pk')
        if binary:
            cmd.append(binary)
        # The child keeps its own copy of the descriptor, so the parent's
        # handle can be closed once spike has been started.
        with open(self.logname, "w") as logfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)

    def __del__(self):
        """Kill spike and reap it (best effort)."""
        process = getattr(self, "process", None)
        if process is None:
            # __init__ failed before spike was started.
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass

    def wait(self, *args, **kwargs):
        """Wait for spike to exit; arguments go to Popen.wait()."""
        return self.process.wait(*args, **kwargs)
84
class VcsSim(object):
    """Runs a VCS simulator binary (``simv`` by default) with JTAG VPI enabled.

    The simulator's stdout/stderr go to simv.log in the current directory.
    The constructor returns once the simulator has logged
    "Listening on port <n>"; n is stored in self.port and exported as
    $JTAG_VPI_PORT.
    """
    def __init__(self, simv=None, debug=False):
        """Start the simulator and wait for it to report its JTAG VPI port.

        simv: command line to run instead of plain ``simv`` (split with
            shlex).
        debug: run ``<binary>-debug`` instead and add
            ``+vcdplusfile=output/gdbserver.vpd``.

        Raises Exception if the simulator exits before reporting its port.
        """
        if simv:
            cmd = shlex.split(simv)
        else:
            cmd = ["simv"]
        cmd += ["+jtag_vpi_enable"]
        if debug:
            cmd[0] = cmd[0] + "-debug"
            cmd += ["+vcdplusfile=output/gdbserver.vpd"]

        # Both handles are closed on every path; the child keeps its own
        # descriptor for the log.
        with open("simv.log", "w") as logfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            with open("simv.log", "r") as listenfile:
                # Only scan output produced after the command line above.
                listenfile.seek(0, 2)
                self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                                stdout=logfile, stderr=logfile)
                self.port = self._wait_for_port(listenfile)
        os.environ['JTAG_VPI_PORT'] = str(self.port)

    def _wait_for_port(self, listenfile):
        """Return the port from the first "Listening on port <n>" line.

        Raises Exception if the simulator exits without printing it.
        """
        while True:
            # Sample the exit status *before* reading: an empty read after
            # the process had already exited means no port will ever appear.
            exited = self.process.poll() is not None
            line = listenfile.readline()
            if line:
                match = re.match(r"^Listening on port (\d+)$", line)
                if match:
                    return int(match.group(1))
            elif exited:
                raise Exception(
                    "simv exited before reporting its JTAG VPI port")
            else:
                time.sleep(0.1)

    def __del__(self):
        """Kill the simulator and reap it (best effort)."""
        process = getattr(self, "process", None)
        if process is None:
            # __init__ failed before the simulator was started.
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
119
class Openocd(object):
    """Runs OpenOCD as the gdb server for a test.

    OpenOCD's stdout/stderr go to ``logname`` in the current directory. A
    free gdb port is chosen and stored in self.port. The constructor returns
    only after OpenOCD has logged "Examined RISCV core".
    """
    logname = "openocd.log"

    def __init__(self, cmd=None, config=None, debug=False):
        """Start OpenOCD and wait until it has examined the RISC-V core.

        cmd: command line to run instead of plain ``openocd`` (split with
            shlex).
        config: config script passed with ``-f``; located with find_file().
        debug: pass ``-d``.

        Raises Exception if *config* can't be found, or if OpenOCD exits
        before examining the core.
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["openocd"]
        if config:
            config_path = find_file(config)
            if config_path is None:
                # Fail clearly here instead of putting None in the argv list,
                # which later died with an unhelpful TypeError.
                raise Exception("Can't find OpenOCD config file %r" % config)
            cmd += ["-f", config_path]
        if debug:
            cmd.append("-d")

        # Assign port
        self.port = unused_port()
        # This command needs to come before any config scripts on the command
        # line, since they are executed in order.
        cmd[1:1] = ["--command", "gdb_port %d" % self.port]

        # The child keeps its own descriptor; close the parent's handle.
        with open(Openocd.logname, "w") as logfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)

        # Wait for OpenOCD to have made it through riscv_examine(). When using
        # OpenOCD to communicate with a simulator this may take a long time,
        # and gdb will time out when trying to connect if we attempt too early.
        start = time.time()
        messaged = False
        while True:
            with open(Openocd.logname) as log_file:
                log = log_file.read()
            if "Examined RISCV core" in log:
                break
            if self.process.poll() is not None:
                raise Exception(
                    "OpenOCD exited before completing riscv_examine()")
            if not messaged and time.time() - start > 1:
                messaged = True
                print("Waiting for OpenOCD to examine RISCV core...")
            # Poll instead of spinning a CPU core; a simulator on the other
            # end may need that CPU.
            time.sleep(0.1)

    def __del__(self):
        """Kill OpenOCD and reap it (best effort)."""
        process = getattr(self, "process", None)
        if process is None:
            # __init__ failed before OpenOCD was started.
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
167
class OpenocdCli(object):
    """Interactive session with OpenOCD's telnet interface.

    Everything the session prints is also copied to openocd-cli.log via tee.
    """
    def __init__(self, port=4444):
        """Telnet to localhost:*port* and wait for the first "> " prompt."""
        shell_cmd = "telnet localhost %d | tee openocd-cli.log" % port
        self.child = pexpect.spawn("sh -c '%s'" % shell_cmd)
        self.child.expect("> ")

    def command(self, cmd):
        """Send *cmd*; return the stripped text printed before the next
        "> " prompt."""
        session = self.child
        session.sendline(cmd)
        # Consume through the first newline (presumably the echo of cmd).
        session.expect("\n")
        session.expect("> ")
        return session.before.strip()
179
class CannotAccess(Exception):
    """Memory at ``address`` could not be accessed."""
    def __init__(self, address):
        # Hand a readable message to the base class so str(e) and traceback
        # lines name the address instead of being empty.
        try:
            shown = "%#x" % address
        except TypeError:
            shown = repr(address)
        Exception.__init__(self, "Cannot access memory at address %s" % shown)
        self.address = address
184
class Gdb(object):
    """Drives an interactive gdb process through pexpect.

    The session is logged to gdb.log. Most helpers send one command, wait
    for the next ``(gdb)`` prompt and parse the text printed in between.
    """
    def __init__(self,
                 cmd=os.path.expandvars("$RISCV/bin/riscv64-unknown-elf-gdb")):
        """Spawn gdb (*cmd*) and configure it for scripted use.

        The default *cmd* is expanded once, when this module is imported.
        """
        self.child = pexpect.spawn(cmd)
        self.child.logfile = open("gdb.log", "w")
        self.child.logfile.write("+ %s\n" % cmd)
        self.wait()
        # Turn off confirmation prompts and output width/height limits.
        self.command("set confirm off")
        self.command("set width 0")
        self.command("set height 0")
        # Force consistency.
        self.command("set print entry-values no")

    def wait(self):
        """Wait for prompt."""
        self.child.expect(r"\(gdb\)")

    def command(self, command, timeout=-1):
        """Send *command*; return gdb's stripped output before the next prompt.

        The first expect() consumes through the end of the echoed line.
        *timeout* is passed to pexpect (-1 selects pexpect's default).
        """
        self.child.sendline(command)
        self.child.expect("\n", timeout=timeout)
        self.child.expect(r"\(gdb\)", timeout=timeout)
        return self.child.before.strip()

    def c(self, wait=True):
        """Continue.

        With wait=True, wait for the next prompt, assert the output contains
        "Continuing" and return it. Otherwise return as soon as "Continuing"
        is seen, without waiting for a prompt.
        """
        if wait:
            output = self.command("c")
            assert "Continuing" in output
            return output
        else:
            self.child.sendline("c")
            self.child.expect("Continuing")

    def interrupt(self):
        """Send Ctrl-C (0x03); return the output before the next prompt."""
        self.child.send("\003")
        self.child.expect(r"\(gdb\)", timeout=60)
        return self.child.before.strip()

    @staticmethod
    def _check_access(output):
        """Raise CannotAccess if *output* reports a memory-access error."""
        m = re.search("Cannot access memory at address (0x[0-9a-f]+)", output)
        if m:
            raise CannotAccess(int(m.group(1), 0))

    def x(self, address, size='w'):
        """Examine memory with ``x/<size> <address>``; return the value.

        Raises CannotAccess if gdb can't read the address.
        """
        output = self.command("x/%s %s" % (size, address))
        self._check_access(output)
        # The value follows the colon; use the last colon so a colon earlier
        # in the line (e.g. in a symbol name) doesn't break parsing.
        value = int(output.rsplit(':', 1)[-1].strip(), 0)
        return value

    def p(self, obj):
        """Print *obj* in hex (``p/x``) and return it as an int.

        Raises CannotAccess if gdb can't read the memory involved.
        """
        output = self.command("p/x %s" % obj)
        self._check_access(output)
        value = int(output.split('=')[-1].strip(), 0)
        return value

    def p_string(self, obj):
        """Print *obj* with ``p`` and return the quoted string in the output."""
        output = self.command("p %s" % obj)
        # Split at the first '=' only so an '=' inside the string survives,
        # and take the last token so a leading address/<symbol> is skipped.
        value = shlex.split(output.split('=', 1)[-1].strip())[-1]
        return value

    def stepi(self):
        """Step one instruction; return gdb's output."""
        output = self.command("stepi")
        return output

    def load(self):
        """Run ``load`` (60 s timeout) and assert it reported success."""
        output = self.command("load", timeout=60)
        assert "failed" not in output
        assert "Transfer rate" in output

    def b(self, location):
        """Set a breakpoint at *location*; assert success; return output."""
        output = self.command("b %s" % location)
        assert "not defined" not in output
        assert "Breakpoint" in output
        return output

    def hbreak(self, location):
        """Set a hardware breakpoint at *location*; assert success; return
        output."""
        output = self.command("hbreak %s" % location)
        assert "not defined" not in output
        assert "Hardware assisted breakpoint" in output
        return output
260
def run_all_tests(module, target, tests, fail_fast):
    """Run every test class found in *module* against *target*.

    A test class is any attribute of *module* that is a class and has a
    ``test`` attribute. If *tests* is non-empty, only classes whose name
    contains one of its entries are run. Each class is instantiated with
    *target* and the value returned by its run() is tallied. With
    *fail_fast*, stop after the first result other than 'pass' or
    'not_applicable'.

    Returns 0 if every result was 'pass' or 'not_applicable', else 1.
    """
    good_results = set(('pass', 'not_applicable'))

    start = time.time()

    results = {}
    count = 0
    for name in dir(module):
        definition = getattr(module, name)
        # isinstance() also accepts classes created by a custom metaclass,
        # which the old `type(x) == type` check silently skipped.
        if not isinstance(definition, type) or \
                not hasattr(definition, 'test'):
            continue
        if tests and not any(test in name for test in tests):
            continue
        instance = definition(target)
        result = instance.run()
        results.setdefault(result, []).append(name)
        count += 1
        if result not in good_results and fail_fast:
            break

    header("ran %d tests in %.0fs" % (count, time.time() - start), dash=':')

    result = 0
    # Stable report order (plain dict iteration order is arbitrary in
    # Python 2); key=str tolerates non-string results.
    for key in sorted(results, key=str):
        value = results[key]
        print("%d tests returned %s" % (len(value), key))
        if key not in good_results:
            result = 1
            for test in value:
                print("  %s" % test)

    return result
290
def add_test_run_options(parser):
    """Register the shared test-selection options on an argparse *parser*.

    Adds ``--fail-fast``/``-f`` (a store_true flag) and a positional,
    optional list ``test``.
    """
    option_specs = (
        (("--fail-fast", "-f"),
         dict(action="store_true", help="Exit as soon as any test fails.")),
        (("test",),
         dict(nargs='*', help="Run only tests that are named here.")),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
296
def header(title, dash='-'):
    """Print *title* inside a banner line of *dash* characters.

    With a single-character *dash* the line is 40 characters wide; a title
    longer than 36 characters gets no padding. An odd leftover dash goes
    after the title.
    """
    dashes = dash * (36 - len(title))
    # Floor division keeps the slice index an int on Python 3 too.
    half = len(dashes) // 2
    before = dashes[:half]
    after = dashes[half:]
    print("%s[ %s ]%s" % (before, title, after))
302
class BaseTest(object):
    """Base class for debug tests.

    Subclasses implement test() and may set ``compile_args`` (a tuple handed
    to target.compile()) and override setup() / early_applicable().
    run() drives one test and returns its result string.
    """
    # compile_args -> result of target.compile(); shared by every test so
    # each distinct program is compiled only once.
    compiled = {}
    # Default list of log files to dump when a test fails. Each instance
    # works on its own copy, made in __init__.
    logs = []

    def __init__(self, target):
        self.target = target
        self.server = None
        self.target_process = None
        self.binary = None
        self.start = 0
        # Copy the class-level list: appending to the shared list made every
        # test inherit, and re-print, the logs of all earlier tests.
        self.logs = list(self.logs)

    def early_applicable(self):
        """Return a false value if the test has determined it cannot run
        without ever needing to talk to the target or server."""
        # pylint: disable=no-self-use
        return True

    def setup(self):
        """Hook run just before test(); does nothing by default."""
        pass

    def compile(self):
        """Compile ``compile_args`` (once per distinct value) and set
        self.binary.

        Returns "exception" after printing the traceback if compilation
        raised; otherwise None.
        """
        compile_args = getattr(self, 'compile_args', None)
        if compile_args:
            if compile_args not in BaseTest.compiled:
                try:
                    # pylint: disable=star-args
                    BaseTest.compiled[compile_args] = \
                            self.target.compile(*compile_args)
                except Exception: # pylint: disable=broad-except
                    print("exception while compiling in %.2fs" % (
                        time.time() - self.start))
                    print("=" * 40)
                    header("Traceback")
                    traceback.print_exc(file=sys.stdout)
                    print("/" * 40)
                    return "exception"
            self.binary = BaseTest.compiled.get(compile_args)

    def classSetup(self):
        """Compile the program, then start the target and the server.

        Raises Exception if compilation failed, so the test is reported as
        an exception instead of running without a binary.
        """
        if self.compile() == "exception":
            raise Exception("Compiling %r failed" % (self.compile_args,))
        self.target_process = self.target.target()
        self.server = self.target.server()
        self.logs.append(self.server.logname)

    def classTeardown(self):
        """Drop the server and target-process objects from classSetup()."""
        del self.server
        del self.target_process

    @staticmethod
    def _read_log(path):
        """Return the contents of *path*, or a short note if unreadable."""
        try:
            with open(path, "r") as log_file:
                return log_file.read()
        except IOError as err:
            return "(could not read %s: %s)" % (path, err)

    def run(self):
        """
        Call classSetup() (which compiles compile_args, if set, into
        self.binary and starts the target and server), then setup().

        Then call test() and return the result, displaying relevant information
        if an exception is raised. Teardown always runs.
        """

        # No newline, trailing space: the result is printed on the same line.
        sys.stdout.write("Running %s ... " % type(self).__name__)
        sys.stdout.flush()

        if not self.early_applicable():
            print("not_applicable")
            return "not_applicable"

        self.start = time.time()

        try:
            # classSetup() is inside the try so that a failure to compile or
            # to start the target/server is reported like any other test
            # exception instead of aborting the whole test run.
            self.classSetup()
            self.setup()
            result = self.test() # pylint: disable=no-member
        except Exception as e: # pylint: disable=broad-except
            if isinstance(e, TestFailed):
                result = "fail"
            else:
                result = "exception"
            print("%s in %.2fs" % (result, time.time() - self.start))
            print("=" * 40)
            if isinstance(e, TestFailed):
                header("Message")
                print(e.message)
            header("Traceback")
            traceback.print_exc(file=sys.stdout)
            for log in self.logs:
                header(log)
                print(self._read_log(log))
            print("/" * 40)
            return result

        finally:
            self.classTeardown()

        if not result:
            result = 'pass'
        print("%s in %.2fs" % (result, time.time() - self.start))
        return result
400
class TestFailed(Exception):
    """A test check failed; ``message`` describes the failure.

    Raised, for example, by this module's assert* helpers.
    """
    def __init__(self, message):
        # Also give the message to the base class so str(e) and the final
        # traceback line show it instead of being empty.
        Exception.__init__(self, message)
        self.message = message
405
def assertEqual(a, b):
    """Raise TestFailed when *a* != *b*."""
    mismatch = a != b
    if mismatch:
        raise TestFailed("%r != %r" % (a, b))
409
def assertNotEqual(a, b):
    """Raise TestFailed when *a* == *b*."""
    same = a == b
    if same:
        raise TestFailed("%r == %r" % (a, b))
413
def assertIn(a, b):
    """Raise TestFailed unless *a* is contained in *b*."""
    present = a in b
    if not present:
        raise TestFailed("%r not in %r" % (a, b))
417
def assertNotIn(a, b):
    """Raise TestFailed if *a* is contained in *b*."""
    if a not in b:
        return
    raise TestFailed("%r in %r" % (a, b))
421
def assertGreater(a, b):
    """Raise TestFailed unless *a* > *b*."""
    greater = a > b
    if not greater:
        raise TestFailed("%r not greater than %r" % (a, b))
425
def assertTrue(a):
    """Raise TestFailed if *a* is falsy."""
    if not a:
        # Wrap in a 1-tuple: a bare tuple such as () would otherwise be
        # taken as the format arguments and raise TypeError.
        raise TestFailed("%r is not True" % (a,))
429
def assertRegexpMatches(text, regexp):
    """Raise TestFailed unless *regexp* matches somewhere in *text*
    (re.search semantics)."""
    found = re.search(regexp, text)
    if found is None:
        raise TestFailed("can't find %r in %r" % (regexp, text))