Add framework to test OpenOCD directly.
[riscv-tests.git] / debug / testlib.py
1 import os.path
2 import re
3 import shlex
4 import subprocess
5 import sys
6 import time
7 import traceback
8
9 import pexpect
10
11 # Note that gdb comes with its own testsuite. I was unable to figure out how to
12 # run that testsuite against the spike simulator.
13
def find_file(path):
    """Locate `path` relative to the current directory or this module's directory.

    The current working directory is tried first, then the directory that
    contains this file. Returns the first joined path that exists, or None
    when neither location has it. Because os.path.join discards the base
    directory for an absolute `path`, an existing absolute path is returned
    as-is.
    """
    search_dirs = [os.getcwd(), os.path.dirname(__file__)]
    candidates = (os.path.join(base, path) for base in search_dirs)
    return next((cand for cand in candidates if os.path.exists(cand)), None)
20
def compile(args, xlen=32): # pylint: disable=redefined-builtin
    """Compile `args` with $RISCV/bin/riscv<xlen>-unknown-elf-gcc -g.

    Each entry of `args` that find_file() can locate is replaced by its full
    path; any other entry (flags, output names, ...) is passed unchanged.
    The joined command line is run through the shell with os.system.

    Raises AssertionError (with the failing command line) if the compiler
    exits with a non-zero status.
    """
    cc = os.path.expandvars("$RISCV/bin/riscv%d-unknown-elf-gcc" % xlen)
    cmd = [cc, "-g"]
    for arg in args:
        cmd.append(find_file(arg) or arg)
    cmd = " ".join(cmd)
    result = os.system(cmd)
    # An explicit check instead of `assert`: assert statements are stripped
    # under `python -O`, which would let a failed build go unnoticed.
    if result != 0:
        raise AssertionError("%r failed" % cmd)
33
def unused_port():
    """Return a TCP port number that was free at the time of the call.

    Binds a throwaway socket to port 0 so the OS picks a port, reads the
    chosen number back, and closes the socket. Nothing reserves the port
    afterwards, so another process could still claim it before the caller
    uses it.
    """
    # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python/2838309#2838309
    import socket
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        return s.getsockname()[1]
    finally:
        # Close even if bind() fails, so the descriptor is not leaked.
        s.close()
42
class Spike(object):
    """Run the spike simulator as a child process, logging to spike.log."""
    logname = "spike.log"

    def __init__(self, cmd, binary=None, halted=False, with_gdb=True,
            timeout=None, xlen=64):
        """Launch spike in the background; the handle is self.process.

        cmd: command line (split with shlex) to use instead of "spike"; a
            false value selects plain "spike".
        binary: optional program path appended after "pk".
        halted: if true, pass -H.
        with_gdb: if true, pick a free TCP port with unused_port(), store it
            in self.port and pass it via --gdb-port. Otherwise self.port is
            None.
        timeout: if set, run the whole command under `timeout <timeout>`.
        xlen: 32 adds "--isa RV32".
        """
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["spike"]
        if xlen == 32:
            cmd += ["--isa", "RV32"]

        if timeout:
            cmd = ["timeout", str(timeout)] + cmd

        if halted:
            cmd.append('-H')
        # Always define the attribute so callers never hit AttributeError.
        self.port = None
        if with_gdb:
            self.port = unused_port()
            cmd += ['--gdb-port', str(self.port)]
        cmd.append("-m32")
        cmd.append('pk')
        if binary:
            cmd.append(binary)
        # The child inherits its own copy of the descriptor, so the parent's
        # handle can be closed as soon as the process has been started.
        with open(self.logname, "w") as logfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                    stdout=logfile, stderr=logfile)

    def __del__(self):
        """Kill and reap the child; no-op if it was never started."""
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass

    def wait(self, *args, **kwargs):
        """Wait for spike to exit; arguments are forwarded to Popen.wait()."""
        return self.process.wait(*args, **kwargs)
84
class VcsSim(object):
    """Run a simv simulator and wait until it reports it is listening."""

    def __init__(self, simv=None, debug=False):
        """Launch the simulator and block until it prints its listening banner.

        simv: command line (split with shlex) to use instead of "simv".
        debug: run "<simv>-debug" and add +vcdplusfile=output/gdbserver.vpd.

        Output goes to simv.log. Raises Exception if the simulator exits
        before printing "Listening on port 5555".
        """
        if simv:
            cmd = shlex.split(simv)
        else:
            cmd = ["simv"]
        cmd += ["+jtag_vpi_enable"]
        if debug:
            cmd[0] = cmd[0] + "-debug"
            cmd += ["+vcdplusfile=output/gdbserver.vpd"]
        with open("simv.log", "w") as logfile:
            logfile.write("+ %s\n" % " ".join(cmd))
            logfile.flush()
            with open("simv.log", "r") as listenfile:
                # Start reading after the header line written above.
                listenfile.seek(0, 2)
                self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                        stdout=logfile, stderr=logfile)
                self._wait_for_listening(listenfile)

    def _wait_for_listening(self, listenfile):
        """Poll `listenfile` until the banner appears or the process exits."""
        line = ""
        while True:
            # Sample exit status *before* reading, so any output written
            # before the process exited is seen by the read below.
            exited = self.process.poll() is not None
            chunk = listenfile.readline()
            if chunk:
                # readline() can return a partial line at EOF; accumulate it
                # so a banner split across two reads is still recognised.
                line += chunk
                if "Listening on port 5555" in line:
                    return
                if line.endswith("\n"):
                    line = ""
                continue
            if exited:
                raise Exception(
                    "simv exited before it started listening on port 5555")
            time.sleep(0.1)

    def __del__(self):
        """Kill and reap the simulator; no-op if it was never started."""
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
116
117 class Openocd(object):
118 logname = "openocd.log"
119
120 def __init__(self, cmd=None, config=None, debug=False):
121 if cmd:
122 cmd = shlex.split(cmd)
123 else:
124 cmd = ["openocd"]
125 if config:
126 cmd += ["-f", find_file(config)]
127 if debug:
128 cmd.append("-d")
129 logfile = open(Openocd.logname, "w")
130 logfile.write("+ %s\n" % " ".join(cmd))
131 logfile.flush()
132 self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
133 stdout=logfile, stderr=logfile)
134 # TODO: Pick a random port
135 self.port = 3333
136
137 # Wait for OpenOCD to have made it through riscv_examine(). When using
138 # OpenOCD to communicate with a simulator this may take a long time,
139 # and gdb will time out when trying to connect if we attempt too early.
140 start = time.time()
141 messaged = False
142 while True:
143 log = open(Openocd.logname).read()
144 if "Examined RISCV core" in log:
145 break
146 if not self.process.poll() is None:
147 raise Exception(
148 "OpenOCD exited before completing riscv_examine()")
149 if not messaged and time.time() - start > 1:
150 messaged = True
151 print "Waiting for OpenOCD to examine RISCV core..."
152
153 def __del__(self):
154 try:
155 self.process.kill()
156 self.process.wait()
157 except OSError:
158 pass
159
class OpenocdCli(object):
    """Talk to OpenOCD's telnet command interface through pexpect.

    The session is piped through `tee`, so it is also written to
    openocd-cli.log.
    """

    def __init__(self, port=4444):
        pipeline = "telnet localhost %d | tee openocd-cli.log" % port
        self.child = pexpect.spawn("sh -c '%s'" % pipeline)
        # Block until the first "> " prompt appears.
        self.child.expect("> ")

    def command(self, cmd):
        """Send one command line and return the stripped text before the next prompt."""
        session = self.child
        session.sendline(cmd)
        # Skip past the end of the current line (presumably the echoed
        # command) before waiting for the prompt.
        session.expect("\n")
        session.expect("> ")
        reply = session.before
        return reply.strip()
170
class CannotAccess(Exception):
    """Raised when gdb reports it cannot access memory at `address`."""

    def __init__(self, address):
        # Hand the address to the base class so str(e), e.args and printed
        # tracebacks identify which address failed.
        Exception.__init__(self, address)
        self.address = address
175
class Gdb(object):
    """pexpect wrapper around an interactive gdb session (logged to gdb.log)."""

    def __init__(self,
            cmd=os.path.expandvars("$RISCV/bin/riscv64-unknown-elf-gdb")):
        """Spawn gdb, wait for its prompt and apply session settings."""
        self.child = pexpect.spawn(cmd)
        self.child.logfile = open("gdb.log", "w")
        self.child.logfile.write("+ %s\n" % cmd)
        self.wait()
        self.command("set confirm off")
        self.command("set width 0")
        self.command("set height 0")
        # Force consistency.
        self.command("set print entry-values no")

    def wait(self):
        """Wait for prompt."""
        self.child.expect(r"\(gdb\)")

    def command(self, command, timeout=-1):
        """Send `command` and return the stripped output before the next prompt.

        timeout is forwarded to pexpect's expect(); -1 means pexpect's
        default timeout.
        """
        self.child.sendline(command)
        self.child.expect("\n", timeout=timeout)
        self.child.expect(r"\(gdb\)", timeout=timeout)
        return self.child.before.strip()

    def c(self, wait=True):
        """Continue execution.

        With wait=True, block until gdb prompts again and return its output.
        Otherwise return (None) as soon as "Continuing" has been printed.
        """
        if wait:
            output = self.command("c")
            assert "Continuing" in output
            return output
        else:
            self.child.sendline("c")
            self.child.expect("Continuing")

    def interrupt(self):
        """Send Ctrl-C, wait up to 60s for the prompt and return the output."""
        self.child.send("\003")
        self.child.expect(r"\(gdb\)", timeout=60)
        return self.child.before.strip()

    def x(self, address, size='w'):
        """Examine memory with `x/<size> <address>` and return the value as int.

        The value is taken from the text after the first ':' in gdb's output.
        """
        output = self.command("x/%s %s" % (size, address))
        value = int(output.split(':')[1].strip(), 0)
        return value

    def p(self, obj):
        """Evaluate `p/x obj` and return the result as an int.

        Raises CannotAccess if gdb reports the memory is inaccessible.
        """
        output = self.command("p/x %s" % obj)
        m = re.search("Cannot access memory at address (0x[0-9a-f]+)", output)
        if m:
            raise CannotAccess(int(m.group(1), 0))
        value = int(output.split('=')[-1].strip(), 0)
        return value

    def p_string(self, obj):
        """Evaluate `p obj` and return the quoted string in the result.

        Only the text after the *first* '=' is parsed, so string values that
        themselves contain '=' are not cut short. The last shlex token is
        returned, which covers bare char arrays ($1 = "s") as well as
        pointers with or without a symbol annotation
        ($1 = 0x10 <sym> "s", $1 = 0x10 "s").
        """
        output = self.command("p %s" % obj)
        value = output.split('=', 1)[-1].strip()
        return shlex.split(value)[-1]

    def stepi(self):
        """Single-step one instruction and return gdb's output."""
        output = self.command("stepi")
        return output

    def load(self):
        """Run `load` (60s timeout) and assert it succeeded."""
        output = self.command("load", timeout=60)
        assert "failed" not in output
        assert "Transfer rate" in output

    def b(self, location):
        """Set a breakpoint at `location`, asserting gdb accepted it."""
        output = self.command("b %s" % location)
        assert "not defined" not in output
        assert "Breakpoint" in output
        return output

    def hbreak(self, location):
        """Set a hardware breakpoint at `location`, asserting gdb accepted it."""
        output = self.command("hbreak %s" % location)
        assert "not defined" not in output
        assert "Hardware assisted breakpoint" in output
        return output
251
def run_all_tests(module, target, tests, fail_fast):
    """Run every test class found in `module` against `target`.

    A test class is any class attribute of `module` that has a `test`
    attribute. When `tests` is non-empty, only classes whose name contains
    at least one entry of `tests` as a substring are run. Each class is
    instantiated with `target` and its run() result is tallied. With
    `fail_fast`, stop after the first result that is not 'pass' or
    'not_applicable'.

    Prints a per-result summary (sorted, so the output order is stable)
    listing the tests behind each bad result. Returns 0 if every result was
    good, otherwise 1.
    """
    good_results = set(('pass', 'not_applicable'))

    start = time.time()

    results = {}
    count = 0
    for name in dir(module):
        definition = getattr(module, name)
        if not isinstance(definition, type) or not hasattr(definition, 'test'):
            continue
        if tests and not any(test in name for test in tests):
            continue
        instance = definition(target)
        result = instance.run()
        results.setdefault(result, []).append(name)
        count += 1
        if result not in good_results and fail_fast:
            break

    header("ran %d tests in %.0fs" % (count, time.time() - start), dash=':')

    result = 0
    # key=str keeps sorting safe even if a test returns a non-string result.
    for key in sorted(results, key=str):
        value = results[key]
        print("%d tests returned %s" % (len(value), key))
        if key not in good_results:
            result = 1
            for test in value:
                print("  %s" % test)

    return result
281
def add_test_run_options(parser):
    """Register the shared test-selection options on an argparse parser.

    Adds a --fail-fast/-f flag and a positional `test` argument that
    accepts zero or more test names.
    """
    parser.add_argument(
        "--fail-fast", "-f",
        action="store_true",
        help="Exit as soon as any test fails.",
    )
    parser.add_argument(
        "test",
        nargs='*',
        help="Run only tests that are named here.",
    )
287
def header(title, dash='-', width=36):
    """Print `title` as a banner: dashes, "[ title ]", dashes.

    The banner uses `dash * (width - len(title))` split in half around the
    bracketed title, with any odd character going after it. Titles of
    `width` characters or more get no padding.
    """
    dashes = dash * (width - len(title))
    # Floor division keeps the slice index an int on Python 3 as well.
    half = len(dashes) // 2
    print("%s[ %s ]%s" % (dashes[:half], title, dashes[half:]))
293
class BaseTest(object):
    """Base class for tests run by run_all_tests().

    Subclasses define test(). They may also set `compile_args` (used as a
    dict key, so it must be hashable, and splatted into target.compile())
    and override setup() / early_applicable().
    """
    # compile_args -> result of target.compile(), shared by every test so
    # each distinct program is only built once.
    compiled = {}
    # Log file names dumped when a test fails; shared by every test.
    logs = []

    def __init__(self, target):
        self.target = target
        self.server = None
        self.target_process = None
        self.binary = None
        self.start = 0

    def early_applicable(self):
        """Return a false value if the test has determined it cannot run
        without ever needing to talk to the target or server."""
        # pylint: disable=no-self-use
        return True

    def setup(self):
        pass

    def compile(self):
        """Build the program described by compile_args and set self.binary.

        Results are cached in BaseTest.compiled. Returns "exception" (after
        printing the traceback) if target.compile() raised, else None.
        """
        compile_args = getattr(self, 'compile_args', None)
        if compile_args:
            if compile_args not in BaseTest.compiled:
                try:
                    # pylint: disable=star-args
                    BaseTest.compiled[compile_args] = \
                            self.target.compile(*compile_args)
                except Exception: # pylint: disable=broad-except
                    print("exception while compiling in %.2fs" % (
                        time.time() - self.start))
                    print("=" * 40)
                    header("Traceback")
                    traceback.print_exc(file=sys.stdout)
                    print("/" * 40)
                    return "exception"
            self.binary = BaseTest.compiled.get(compile_args)

    def classSetup(self):
        """Compile, then start the target and the server.

        Returns compile()'s failure result (without starting anything) if
        compilation failed, otherwise None.
        """
        result = self.compile()
        if result:
            return result
        self.target_process = self.target.target()
        self.server = self.target.server()
        # `logs` is shared by all tests; record each log file only once so a
        # failure does not print the same log over and over.
        if self.server.logname not in self.logs:
            self.logs.append(self.server.logname)
        return None

    def classTeardown(self):
        del self.server
        del self.target_process

    @staticmethod
    def _print_log(path):
        """Print the contents of log file `path`, tolerating missing files."""
        try:
            with open(path, "r") as logfile:
                print(logfile.read())
        except (IOError, OSError) as err:
            print("(could not read %s: %s)" % (path, err))

    def run(self):
        """
        If compile_args is set, compile a program and set self.binary.

        Call setup().

        Then call test() and return the result, displaying relevant information
        if an exception is raised. A falsy test() result becomes 'pass'.
        Failures while compiling or starting the target/server are reported
        like test exceptions instead of aborting the whole run.
        """

        # No newline: the result is printed on the same line later.
        sys.stdout.write("Running %s ... " % type(self).__name__)
        sys.stdout.flush()

        if not self.early_applicable():
            print("not_applicable")
            return "not_applicable"

        self.start = time.time()

        try:
            result = self.classSetup()
            if result:
                # compile() has already printed the failure report.
                return result
            self.setup()
            result = self.test() # pylint: disable=no-member
        except Exception as e: # pylint: disable=broad-except
            if isinstance(e, TestFailed):
                result = "fail"
            else:
                result = "exception"
            print("%s in %.2fs" % (result, time.time() - self.start))
            print("=" * 40)
            if isinstance(e, TestFailed):
                header("Message")
                print(e.message)
            header("Traceback")
            traceback.print_exc(file=sys.stdout)
            for log in self.logs:
                header(log)
                self._print_log(log)
            print("/" * 40)
            return result

        finally:
            self.classTeardown()

        if not result:
            result = 'pass'
        print("%s in %.2fs" % (result, time.time() - self.start))
        return result
391
class TestFailed(Exception):
    """Raised by the assert* helpers when a test check fails."""

    def __init__(self, message):
        # Pass the message to the base class too, so str(e) and printed
        # tracebacks show it; `.message` is kept for existing readers.
        Exception.__init__(self, message)
        self.message = message
396
def assertEqual(a, b):
    """Raise TestFailed unless a == b (checked via a != b)."""
    if not (a != b):
        return
    raise TestFailed("%r != %r" % (a, b))
400
def assertNotEqual(a, b):
    """Raise TestFailed if a == b."""
    if not (a == b):
        return
    raise TestFailed("%r == %r" % (a, b))
404
def assertIn(a, b):
    """Raise TestFailed unless a is contained in b."""
    if a in b:
        return
    raise TestFailed("%r not in %r" % (a, b))
408
def assertNotIn(a, b):
    """Raise TestFailed if a is contained in b."""
    if a not in b:
        return
    raise TestFailed("%r in %r" % (a, b))
412
def assertGreater(a, b):
    """Raise TestFailed unless a > b."""
    if a > b:
        return
    raise TestFailed("%r not greater than %r" % (a, b))
416
def assertTrue(a):
    """Raise TestFailed unless a is truthy."""
    if not a:
        # Wrap `a` in a 1-tuple: with a bare `% a`, a falsy tuple such as ()
        # is treated as the format-argument list and raises TypeError
        # instead of TestFailed.
        raise TestFailed("%r is not True" % (a,))
420
def assertRegexpMatches(text, regexp):
    """Raise TestFailed unless re.search(regexp, text) finds a match."""
    match = re.search(regexp, text)
    if match:
        return
    raise TestFailed("can't find %r in %r" % (regexp, text))