Add basic floating point register test.
[riscv-tests.git] / debug / testlib.py
1 import os.path
2 import re
3 import shlex
4 import subprocess
5 import sys
6 import time
7 import traceback
8
9 import pexpect
10
11 # Note that gdb comes with its own testsuite. I was unable to figure out how to
12 # run that testsuite against the spike simulator.
13
def find_file(path):
    """Locate *path* relative to the working directory or to this module.

    The current working directory is tried first, then the directory that
    contains this file. Returns the first candidate that exists, or None
    when neither does.
    """
    search_roots = (os.getcwd(), os.path.dirname(__file__))
    candidates = (os.path.join(root, path) for root in search_roots)
    return next((c for c in candidates if os.path.exists(c)), None)
20
def compile(args, xlen=32): # pylint: disable=redefined-builtin
    """Build *args* with $RISCV/bin/riscv<xlen>-unknown-elf-gcc and -g.

    Any argument that find_file() can locate is replaced by its full path;
    every other argument (flags, output names, ...) is passed unchanged.
    The command line is run through the shell via os.system, and an
    AssertionError is raised when it exits with a non-zero status.
    """
    compiler = os.path.expandvars("$RISCV/bin/riscv%d-unknown-elf-gcc" % xlen)
    pieces = [compiler, "-g"] + [find_file(arg) or arg for arg in args]
    command_line = " ".join(pieces)
    status = os.system(command_line)
    assert status == 0, "%r failed" % command_line
33
def unused_port():
    """Return a TCP port number that was free when this function ran.

    A socket is bound to port 0 so the OS picks an unused port, and the
    socket is then closed again. Another process could take the port
    before the caller uses it.
    """
    # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python/2838309#2838309
    import contextlib
    import socket
    # closing() guarantees the socket is released even if bind() raises.
    with contextlib.closing(
            socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]
42
class Spike(object):
    """Run the spike RISC-V simulator (running pk) as a child process.

    The child's stdout and stderr are written to spike.log in the current
    directory.
    """
    logname = "spike.log"

    def __init__(self, cmd, binary=None, halted=False, with_gdb=True,
                 timeout=None, xlen=64):
        """Launch spike.

        cmd: shell-style command line for spike; empty/None means "spike".
        binary: program to pass to pk, if any.
        halted: add -H to spike's arguments.
        with_gdb: choose a free TCP port and pass it as --gdb-port.
        timeout: if set, run the command under `timeout <seconds>`.
        xlen: 32 adds "--isa RV32".

        Nothing is returned. Afterwards self.process is the Popen object and
        self.port is the gdb port (None when with_gdb is false).
        """
        # Set these first so __del__ works even if launching fails below.
        self.process = None
        self.port = None

        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["spike"]
        if xlen == 32:
            cmd += ["--isa", "RV32"]

        if timeout:
            cmd = ["timeout", str(timeout)] + cmd

        if halted:
            cmd.append('-H')
        if with_gdb:
            self.port = unused_port()
            cmd += ['--gdb-port', str(self.port)]
        cmd.append("-m32")
        cmd.append('pk')
        if binary:
            cmd.append(binary)

        logfile = open(self.logname, "w")
        logfile.write("+ %s\n" % " ".join(cmd))
        logfile.flush()
        try:
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)
        finally:
            # The child has its own copy of the descriptor; ours can go.
            logfile.close()

    def __del__(self):
        """Kill spike and reap it, ignoring OSError."""
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass

    def wait(self, *args, **kwargs):
        """Forward to the child process's wait()."""
        return self.process.wait(*args, **kwargs)
84
class VcsSim(object):
    """Run a VCS-built simulator (simv) with +jtag_vpi_enable.

    Simulator output goes to simv.log. The constructor blocks until the
    simulator logs "Listening on port N".
    """

    def __init__(self, simv=None, debug=False):
        """Start the simulator and wait for its JTAG VPI port.

        simv: shell-style command line for the simulator (default "simv").
        debug: append "-debug" to the executable name and pass
            +vcdplusfile=output/gdbserver.vpd.

        Sets self.process and self.port, and exports the port as
        $JTAG_VPI_PORT. Raises Exception if the simulator exits before
        printing its port.
        """
        # Set first so __del__ works even if launching fails below.
        self.process = None
        if simv:
            cmd = shlex.split(simv)
        else:
            cmd = ["simv"]
        cmd += ["+jtag_vpi_enable"]
        if debug:
            cmd[0] = cmd[0] + "-debug"
            cmd += ["+vcdplusfile=output/gdbserver.vpd"]
        logfile = open("simv.log", "w")
        logfile.write("+ %s\n" % " ".join(cmd))
        logfile.flush()
        # Reader positioned at the current end of the log, so the command
        # line written above is skipped and only simulator output is scanned.
        listenfile = open("simv.log", "r")
        listenfile.seek(0, 2)
        try:
            try:
                self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                                stdout=logfile, stderr=logfile)
            finally:
                # The child has its own copy of the descriptor.
                logfile.close()
            self.port = self._wait_for_port(listenfile)
        finally:
            listenfile.close()
        os.environ['JTAG_VPI_PORT'] = str(self.port)

    def _wait_for_port(self, listenfile):
        """Read *listenfile* until a "Listening on port N" line; return N.

        Lines the simulator has only partly written are buffered until their
        newline arrives. Raises Exception if self.process has exited and
        the file holds no further output.
        """
        pattern = re.compile(r"^Listening on port (\d+)$")
        partial = ""
        exited = False
        while True:
            chunk = listenfile.readline()
            if chunk:
                partial += chunk
                if not partial.endswith("\n"):
                    # The writer is mid-line; wait for the rest of it.
                    continue
                match = pattern.match(partial)
                partial = ""
                if match:
                    return int(match.group(1))
            elif exited:
                raise Exception(
                    "simv exited before reporting its JTAG VPI port")
            else:
                # After seeing the process exit, read once more so a port
                # line written just before exiting is not missed.
                exited = self.process.poll() is not None
                if not exited:
                    time.sleep(1)

    def __del__(self):
        """Kill the simulator and reap it, ignoring OSError."""
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
119
class Openocd(object):
    """Run OpenOCD as a gdb server, logging to openocd.log.

    The constructor blocks until the log contains "Examined RISCV core".
    """
    logname = "openocd.log"

    def __init__(self, cmd=None, config=None, debug=False):
        """Start OpenOCD.

        cmd: shell-style command line (default "openocd").
        config: config script to pass with -f, located via find_file().
        debug: add -d.

        Sets self.port (the gdb port) and self.process. Raises Exception if
        the config file can't be found or OpenOCD exits before examining
        the core.
        """
        # Set first so __del__ works even if launching fails below.
        self.process = None
        if cmd:
            cmd = shlex.split(cmd)
        else:
            cmd = ["openocd"]
        if config:
            config_path = find_file(config)
            if config_path is None:
                raise Exception("Can't find OpenOCD config file %r" % config)
            cmd += ["-f", config_path]
        if debug:
            cmd.append("-d")

        # Assign port
        self.port = unused_port()
        # This command needs to come before any config scripts on the command
        # line, since they are executed in order.
        cmd[1:1] = ["--command", "gdb_port %d" % self.port]

        logfile = open(Openocd.logname, "w")
        logfile.write("+ %s\n" % " ".join(cmd))
        logfile.flush()
        try:
            self.process = subprocess.Popen(cmd, stdin=subprocess.PIPE,
                                            stdout=logfile, stderr=logfile)
        finally:
            # The child has its own copy of the descriptor.
            logfile.close()

        # Wait for OpenOCD to have made it through riscv_examine(). When using
        # OpenOCD to communicate with a simulator this may take a long time,
        # and gdb will time out when trying to connect if we attempt too early.
        start = time.time()
        messaged = False
        while True:
            with open(Openocd.logname) as log:
                if "Examined RISCV core" in log.read():
                    break
            if self.process.poll() is not None:
                raise Exception(
                    "OpenOCD exited before completing riscv_examine()")
            if not messaged and time.time() - start > 1:
                messaged = True
                print("Waiting for OpenOCD to examine RISCV core...")
            # Don't spin: re-reading the whole log back-to-back burns a CPU.
            time.sleep(0.1)

    def __del__(self):
        """Kill OpenOCD and reap it, ignoring OSError."""
        process = getattr(self, "process", None)
        if process is None:
            return
        try:
            process.kill()
            process.wait()
        except OSError:
            pass
167
class OpenocdCli(object):
    """Send commands to OpenOCD's telnet interface.

    The telnet session is piped through tee, so the transcript is also
    written to openocd-cli.log.
    """

    # One line of "reg" output: register name, then its hex value.
    # Both hex digit cases are accepted.
    _REG_RE = re.compile(r"(\w+) \(/\d+\): (0x[0-9A-Fa-f]+)")

    def __init__(self, port=4444):
        """Connect with telnet to localhost:*port* and wait for the prompt."""
        self.child = pexpect.spawn(
            "sh -c 'telnet localhost %d | tee openocd-cli.log'" % port)
        self.child.expect("> ")

    def command(self, cmd):
        """Run *cmd* and return its output, stripped of surrounding whitespace
        and NULs."""
        self.child.sendline(cmd)
        # The command is echoed back. Match it literally: it can contain
        # regex metacharacters (file paths with '+', brackets, ...).
        self.child.expect_exact(cmd)
        self.child.expect("\n")
        self.child.expect("> ")
        return self.child.before.strip("\t\r\n \0")

    @staticmethod
    def _parse_reg_output(output):
        """Map each register name in "reg" command *output* to its int value."""
        return {name: int(value, 0)
                for name, value in OpenocdCli._REG_RE.findall(output)}

    def reg(self, reg=''):
        """Read registers with OpenOCD's "reg" command.

        With a register name, return that register's value as an int
        (KeyError if it isn't in the output). With no name, return a dict
        of every register found.
        """
        values = self._parse_reg_output(self.command("reg %s" % reg))
        if reg:
            return values[reg]
        return values

    def load_image(self, image):
        """Load *image* into the target with OpenOCD's "load_image" command.

        Raises TestNotApplicable when OpenOCD reports that only 32-bit ELF
        files are supported.
        """
        output = self.command("load_image %s" % image)
        if 'invalid ELF file, only 32bits files are supported' in output:
            raise TestNotApplicable(output)
193
class CannotAccess(Exception):
    """Raised when gdb reports it cannot access memory at an address.

    The address is kept in self.address and is also part of the message,
    so tracebacks show it.
    """
    def __init__(self, address):
        try:
            text = "Cannot access memory at address %#x" % address
        except TypeError:
            # Fallback for non-integer addresses.
            text = "Cannot access memory at address %r" % (address,)
        Exception.__init__(self, text)
        self.address = address
198
class Gdb(object):
    """Drive an interactive gdb session through pexpect.

    The session is logged to gdb.log in the current directory. Most helpers
    send one gdb command, wait for the next "(gdb)" prompt, and parse what
    gdb printed in between.
    """

    # gdb's report that it could not read or write target memory.
    _CANNOT_ACCESS_RE = re.compile(
        r"Cannot access memory at address (0x[0-9a-f]+)")

    def __init__(self,
                 cmd=os.path.expandvars("$RISCV/bin/riscv64-unknown-elf-gdb")):
        """Spawn gdb from *cmd* and set it up for scripted use.

        NOTE: the default *cmd* is expanded once, when the module is imported.
        """
        self.child = pexpect.spawn(cmd)
        self.child.logfile = open("gdb.log", "w")
        self.child.logfile.write("+ %s\n" % cmd)
        self.wait()
        self.command("set confirm off")
        self.command("set width 0")
        self.command("set height 0")
        # Force consistency.
        self.command("set print entry-values no")

    def wait(self):
        """Wait for prompt."""
        self.child.expect(r"\(gdb\)")

    def command(self, command, timeout=-1):
        """Send *command* and return gdb's output up to the next prompt.

        *timeout* is passed to pexpect (-1 selects the spawn's default).
        """
        self.child.sendline(command)
        # Consume the echoed command line before collecting the output.
        self.child.expect("\n", timeout=timeout)
        self.child.expect(r"\(gdb\)", timeout=timeout)
        return self.child.before.strip()

    @staticmethod
    def _inaccessible_address(output):
        """Return the address named in a gdb "Cannot access memory" error in
        *output*, or None if there is no such error."""
        match = Gdb._CANNOT_ACCESS_RE.search(output)
        if match:
            return int(match.group(1), 0)
        return None

    def _check_access(self, output):
        """Raise CannotAccess if *output* reports a memory-access failure."""
        address = self._inaccessible_address(output)
        if address is not None:
            raise CannotAccess(address)

    def c(self, wait=True):
        """Continue the target.

        With wait=True, block until gdb prompts again and return its output.
        With wait=False, return once gdb prints "Continuing".
        """
        if wait:
            output = self.command("c")
            assert "Continuing" in output
            return output
        self.child.sendline("c")
        self.child.expect("Continuing")

    def interrupt(self):
        """Send Ctrl-C to gdb and return what it printed before the prompt."""
        self.child.send("\003")
        self.child.expect(r"\(gdb\)", timeout=60)
        return self.child.before.strip()

    def x(self, address, size='w'):
        """Examine one unit (default 'w') at *address*; return it as an int.

        Raises CannotAccess if gdb can't read the memory.
        """
        output = self.command("x/%s %s" % (size, address))
        self._check_access(output)
        value = int(output.split(':')[1].strip(), 0)
        return value

    def p_raw(self, obj):
        """Print *obj* and return the text after the last '='.

        Raises CannotAccess if gdb can't read the memory.
        """
        output = self.command("p %s" % obj)
        self._check_access(output)
        return output.split('=')[-1].strip()

    def p(self, obj):
        """Print *obj* in hex (p/x) and return it as an int.

        Raises CannotAccess if gdb can't read the memory.
        """
        output = self.command("p/x %s" % obj)
        self._check_access(output)
        value = int(output.split('=')[-1].strip(), 0)
        return value

    def p_string(self, obj):
        """Print *obj* and return the second shell-style token after '='."""
        output = self.command("p %s" % obj)
        value = shlex.split(output.split('=')[-1].strip())[1]
        return value

    def stepi(self):
        """Step one instruction and return gdb's output."""
        output = self.command("stepi")
        return output

    def load(self):
        """Run gdb's "load" (60s timeout) and assert that it succeeded."""
        output = self.command("load", timeout=60)
        assert "failed" not in output
        assert "Transfer rate" in output

    def b(self, location):
        """Set a breakpoint at *location*; return gdb's output."""
        output = self.command("b %s" % location)
        assert "not defined" not in output
        assert "Breakpoint" in output
        return output

    def hbreak(self, location):
        """Set a hardware breakpoint at *location*; return gdb's output."""
        output = self.command("hbreak %s" % location)
        assert "not defined" not in output
        assert "Hardware assisted breakpoint" in output
        return output
281
def run_all_tests(module, target, tests, fail_fast):
    """Instantiate and run every test class found in *module*.

    A test class is any class in dir(module) that has a ``test`` attribute.
    If *tests* is non-empty, only classes whose name contains one of its
    strings are run. With *fail_fast*, stop after the first result other
    than 'pass' or 'not_applicable'.

    Returns 0 if every test passed or was not applicable, otherwise 1.
    """
    good_results = set(('pass', 'not_applicable'))

    start = time.time()

    results = {}
    count = 0
    for name in dir(module):
        definition = getattr(module, name)
        if not (isinstance(definition, type) and hasattr(definition, 'test')):
            continue
        if tests and not any(test in name for test in tests):
            continue
        result = definition(target).run()
        results.setdefault(result, []).append(name)
        count += 1
        if fail_fast and result not in good_results:
            break

    header("ran %d tests in %.0fs" % (count, time.time() - start), dash=':')

    status = 0
    # Sort so the summary comes out in the same order on every run.
    for key in sorted(results, key=str):
        names = results[key]
        print("%d tests returned %s" % (len(names), key))
        if key not in good_results:
            status = 1
            for test in names:
                print("  " + test)

    return status
311
def add_test_run_options(parser):
    """Register on *parser* the options that run_all_tests() consumes:
    --fail-fast/-f and an optional list of test-name filters."""
    option_specs = (
        (("--fail-fast", "-f"),
         dict(action="store_true", help="Exit as soon as any test fails.")),
        (("test",),
         dict(nargs='*', help="Run only tests that are named here.")),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
317
def header(title, dash='-'):
    """Print *title* as a banner like "-----[ title ]------".

    With a single-character *dash* and a title of at most 36 characters,
    the banner is 40 columns wide. Any odd dash goes on the right. Longer
    titles get no dashes at all.
    """
    dashes = dash * (36 - len(title))
    # Floor division keeps the slice index an int on Python 3 too.
    half = len(dashes) // 2
    print("%s[ %s ]%s" % (dashes[:half], title, dashes[half:]))
323
class BaseTest(object):
    """Base class for one test case run by run_all_tests().

    Subclasses define test() and may override setup(), early_applicable(),
    and set a compile_args tuple. run() drives the whole lifecycle.
    """
    # Built binaries shared by all tests, keyed by compile_args.
    compiled = {}
    # Default log files to print when a test fails. Each instance gets its
    # own copy in __init__, so appends don't leak between tests.
    logs = []

    def __init__(self, target):
        self.target = target
        self.server = None
        self.target_process = None
        self.binary = None
        self.start = 0
        # Copy the list: classSetup() appends to it. Appending to the shared
        # class attribute made each later failure print every log once for
        # every test that had run before it.
        self.logs = list(self.logs)

    def early_applicable(self):
        """Return a false value if the test has determined it cannot run
        without ever needing to talk to the target or server."""
        # pylint: disable=no-self-use
        return True

    def setup(self):
        """Per-test setup hook; does nothing by default."""
        pass

    def compile(self):
        """Build self.compile_args (if set) and store the result in
        self.binary.

        Builds are cached in BaseTest.compiled, so each distinct
        compile_args is built only once. On failure, print the traceback
        and return "exception"; otherwise return None.
        """
        compile_args = getattr(self, 'compile_args', None)
        if compile_args:
            if compile_args not in BaseTest.compiled:
                try:
                    # pylint: disable=star-args
                    BaseTest.compiled[compile_args] = \
                        self.target.compile(*compile_args)
                except Exception:  # pylint: disable=broad-except
                    print("exception while compiling in %.2fs" % (
                        time.time() - self.start))
                    print("=" * 40)
                    header("Traceback")
                    traceback.print_exc(file=sys.stdout)
                    print("/" * 40)
                    return "exception"
            self.binary = BaseTest.compiled.get(compile_args)

    def classSetup(self):
        """Compile, then start the target and the debug server."""
        if self.compile() == "exception":
            # compile() already printed the traceback. Don't start the target
            # without the binary the test expects.
            raise Exception("compilation failed for %s" % type(self).__name__)
        self.target_process = self.target.target()
        self.server = self.target.server()
        self.logs.append(self.server.logname)

    def classTeardown(self):
        """Drop references to the server and target process."""
        del self.server
        del self.target_process

    def _dump_logs(self):
        """Print every file in self.logs, each under its own header."""
        for log in self.logs:
            header(log)
            try:
                with open(log, "r") as log_file:
                    print(log_file.read())
            except EnvironmentError as err:
                print("(could not read %s: %s)" % (log, err))

    def run(self):
        """
        Call classSetup() (which compiles, if compile_args is set, and starts
        the target and server), then setup(), then test().

        Return "pass" (when test() returns a false value), test()'s own
        result, "not_applicable", "fail" (on TestFailed) or "exception".
        On failure, print the traceback and the log files. classTeardown()
        always runs once classSetup() has been attempted.
        """

        sys.stdout.write("Running %s ... " % type(self).__name__)
        sys.stdout.flush()

        if not self.early_applicable():
            print("not_applicable")
            return "not_applicable"

        self.start = time.time()

        try:
            # Inside the try so that a failure to build or start the target
            # is reported like any other exception and teardown still runs.
            self.classSetup()
            self.setup()
            result = self.test()  # pylint: disable=no-member
        except TestNotApplicable:
            result = "not_applicable"
        except Exception as e:  # pylint: disable=broad-except
            if isinstance(e, TestFailed):
                result = "fail"
            else:
                result = "exception"
            print("%s in %.2fs" % (result, time.time() - self.start))
            print("=" * 40)
            if isinstance(e, TestFailed):
                header("Message")
                print(e.message)
            header("Traceback")
            traceback.print_exc(file=sys.stdout)
            self._dump_logs()
            print("/" * 40)
            return result

        finally:
            self.classTeardown()

        if not result:
            result = 'pass'
        print("%s in %.2fs" % (result, time.time() - self.start))
        return result
423
class TestFailed(Exception):
    """Raised when a test check fails; the text is kept in self.message."""
    def __init__(self, message):
        # Pass the message to Exception as well, so str(e) and tracebacks
        # show it instead of an empty string.
        Exception.__init__(self, message)
        self.message = message
428
class TestNotApplicable(Exception):
    """Raised when a test can't run in the current configuration; the reason
    is kept in self.message."""
    def __init__(self, message):
        # Pass the message to Exception as well, so str(e) shows the reason.
        Exception.__init__(self, message)
        self.message = message
433
def assertEqual(a, b):
    """Raise TestFailed unless a equals b (tested with !=)."""
    differ = a != b
    if not differ:
        return
    raise TestFailed("%r != %r" % (a, b))
437
def assertNotEqual(a, b):
    """Raise TestFailed if a equals b (tested with ==)."""
    if not a == b:
        return
    raise TestFailed("%r == %r" % (a, b))
441
def assertIn(a, b):
    """Raise TestFailed unless a is contained in b."""
    if a in b:
        return
    raise TestFailed("%r not in %r" % (a, b))
445
def assertNotIn(a, b):
    """Raise TestFailed if a is contained in b."""
    found = a in b
    if found:
        raise TestFailed("%r in %r" % (a, b))
449
def assertGreater(a, b):
    """Raise TestFailed unless a > b."""
    if a > b:
        return
    raise TestFailed("%r not greater than %r" % (a, b))
453
def assertLess(a, b):
    """Raise TestFailed unless a < b."""
    if a < b:
        return
    raise TestFailed("%r not less than %r" % (a, b))
457
def assertTrue(a):
    """Raise TestFailed unless a is truthy."""
    if not a:
        # Wrap a in a 1-tuple: if a were itself a tuple (e.g. ()), the %
        # operator would treat it as the argument list and raise TypeError
        # instead of TestFailed.
        raise TestFailed("%r is not True" % (a,))
461
def assertRegexpMatches(text, regexp):
    """Raise TestFailed unless regexp matches somewhere in text (re.search)."""
    match = re.search(regexp, text)
    if match is None:
        raise TestFailed("can't find %r in %r" % (regexp, text))