Merge pull request #941 from Wren6991/sim_lib_io_clke
[yosys.git] / backends / smt2 / smtio.py
1 #
2 # yosys -- Yosys Open SYnthesis Suite
3 #
4 # Copyright (C) 2012 Clifford Wolf <clifford@clifford.at>
5 #
6 # Permission to use, copy, modify, and/or distribute this software for any
7 # purpose with or without fee is hereby granted, provided that the above
8 # copyright notice and this permission notice appear in all copies.
9 #
10 # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 #
18
19 import sys, re, os, signal
20 import subprocess
21 if os.name == "posix":
22 import resource
23 from copy import deepcopy
24 from select import select
25 from time import time
26 from queue import Queue, Empty
27 from threading import Thread
28
29
# This is needed so that the recursive SMT2 S-expression parser
# does not run out of stack frames when parsing large expressions.
# (Only done on POSIX, where the `resource` module is available.)
if os.name == "posix":
    smtio_reclimit = 64 * 1024
    if smtio_reclimit > sys.getrecursionlimit():
        sys.setrecursionlimit(smtio_reclimit)

    # (soft, hard) limits for the process stack size
    current_rlimit_stack = resource.getrlimit(resource.RLIMIT_STACK)
    if current_rlimit_stack[0] != resource.RLIM_INFINITY:
        # Desired soft limit: 128 MiB, but only 16 MiB on MacOS, which has
        # rather conservative stack limits; never above a finite hard limit.
        smtio_stacksize = (16 if os.uname().sysname == "Darwin" else 128) * 1024 * 1024
        if current_rlimit_stack[1] != resource.RLIM_INFINITY:
            smtio_stacksize = min(smtio_stacksize, current_rlimit_stack[1])
        if smtio_stacksize > current_rlimit_stack[0]:
            resource.setrlimit(resource.RLIMIT_STACK, (smtio_stacksize, current_rlimit_stack[1]))
47
48
# Solver processes that are currently running, keyed by SmtIo.p_index,
# so that force_shutdown() can kill them.
running_solvers = {}
# Set once force_shutdown() has started tearing things down.
forced_shutdown = False
# Next SmtIo.p_index to hand out.
solvers_index = 0
53
def force_shutdown(signum, frame):
    """Terminate all running solvers and exit with status 1.

    Installed as the SIGHUP/SIGINT/SIGTERM handler on POSIX and called by
    except_hook() with signum=None.  The solvers are signalled only on the
    first call; every call ends in sys.exit(1).

    :param signum: signal number that triggered the shutdown, or None
    :param frame: current stack frame (unused; signal-handler signature)
    """
    global forced_shutdown
    if not forced_shutdown:
        forced_shutdown = True
        if signum is not None:
            print("<%s>" % signal.Signals(signum).name)
        # Iterate over a snapshot so the loop is unaffected by any concurrent
        # removal of entries from running_solvers.
        for p in list(running_solvers.values()):
            # Signal only the solver itself; its process group is shared with
            # this process (the solver is started without a new session).
            try:
                os.kill(p.pid, signal.SIGTERM)
            except ProcessLookupError:
                # The solver already exited and was reaped; keep going so the
                # remaining solvers are still terminated.
                pass
    sys.exit(1)
64
# On POSIX, route hangup, interrupt and termination signals through
# force_shutdown() so running solvers are killed before exiting.
if os.name == "posix":
    for _signum in (signal.SIGHUP, signal.SIGINT, signal.SIGTERM):
        signal.signal(_signum, force_shutdown)
    del _signum
69
def except_hook(exctype, value, traceback):
    """Uncaught-exception hook: report the exception, then force a shutdown.

    The traceback is printed (via the default hook) only if no forced
    shutdown is already in progress; force_shutdown() then kills the
    running solvers and exits with status 1.
    """
    show_traceback = not forced_shutdown
    if show_traceback:
        sys.__excepthook__(exctype, value, traceback)
    force_shutdown(None, None)


sys.excepthook = except_hook
76
77
# Each hexadecimal digit (upper and lower case) mapped to its 4-bit binary
# string, e.g. "a" -> "1010"; used to expand "#x..." bit-vector literals.
hex_dict = {digit: format(int(digit, 16), "04b") for digit in "0123456789ABCDEFabcdef"}
86
87
class SmtModInfo:
    """Per-module metadata gathered from "; yosys-smt2-*" comment lines.

    Every container starts out empty and is filled by SmtIo.info().
    """

    def __init__(self):
        # net-name sets
        self.inputs = set()
        self.outputs = set()
        self.registers = set()
        self.wires = set()

        # name -> metadata maps
        self.memories = {}
        self.wsize = {}
        self.clocks = {}
        self.cells = {}
        self.asserts = {}
        self.covers = {}
        self.anyconsts = {}
        self.anyseqs = {}
        self.allconsts = {}
        self.allseqs = {}
        self.asize = {}
105
106
class SmtIo:
    """Front end to an SMT-LIB2 solver subprocess.

    Statements are sent with write(), which lazily calls setup() on the
    first non-comment statement.  Answers are read back with read(),
    check_sat() and the get*() helpers.  Comment lines of the form
    "; yosys-smt2-*" are parsed by info() into self.modinfo
    (module name -> SmtModInfo).

    With solver "dummy" no process is started and solver output is read
    from dummy_file.  With any other solver, dummy_file (if set) records
    the solver's output.
    """

    def __init__(self, opts=None):
        """Create a solver front end; no process is started here.

        :param opts: SmtOpts-like object with the command-line settings,
            or None for the built-in defaults (yices, incremental mode).
        """
        global solvers_index

        self.logic = None
        self.logic_qf = True
        self.logic_ax = True
        self.logic_uf = True
        self.logic_bv = True
        self.logic_dt = False
        self.forall = False
        self.produce_models = True
        self.smt2cache = [list()]
        self.p = None
        self.p_index = solvers_index
        solvers_index += 1

        if opts is not None:
            self.logic = opts.logic
            self.solver = opts.solver
            # Copy the lists: write() appends to self.info_stmts, and that
            # must not leak back into the (possibly shared) opts object.
            self.solver_opts = list(opts.solver_opts)
            self.debug_print = opts.debug_print
            self.debug_file = opts.debug_file
            self.dummy_file = opts.dummy_file
            self.timeinfo = opts.timeinfo
            self.unroll = opts.unroll
            self.noincr = opts.noincr
            self.info_stmts = list(opts.info_stmts)
            self.nocomments = opts.nocomments

        else:
            self.solver = "yices"
            self.solver_opts = list()
            self.debug_print = False
            self.debug_file = None
            self.dummy_file = None
            self.timeinfo = os.name != "nt"
            self.unroll = False
            self.noincr = False
            self.info_stmts = list()
            self.nocomments = False

        self.start_time = time()

        self.modinfo = dict()
        self.curmod = None
        self.topmod = None
        self.setup_done = False

    def __del__(self):
        """Terminate a still-running solver when this object is collected."""
        p = getattr(self, "p", None)
        if p is not None and not forced_shutdown:
            # Signal only the solver's own PID.  It is started without a new
            # session, so its process group is this process's group, and
            # os.killpg() would signal this process (and its parents' group
            # members) as well.  force_shutdown() also signals only p.pid.
            try:
                os.kill(p.pid, signal.SIGTERM)
            except ProcessLookupError:
                pass
            # running_solvers may already be None during interpreter teardown
            if running_solvers is not None:
                running_solvers.pop(self.p_index, None)

    def setup(self):
        """Pick the solver command line and SMT logic, then emit the preamble.

        Called once, from the first non-comment write().  Starts the solver
        right away unless running in noincr mode or with the dummy solver.
        """
        assert not self.setup_done

        if self.forall:
            self.unroll = False

        if self.solver == "yices":
            if self.noincr:
                self.popen_vargs = ['yices-smt2'] + self.solver_opts
            else:
                self.popen_vargs = ['yices-smt2', '--incremental'] + self.solver_opts

        if self.solver == "z3":
            self.popen_vargs = ['z3', '-smt2', '-in'] + self.solver_opts

        if self.solver == "cvc4":
            if self.noincr:
                self.popen_vargs = ['cvc4', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts
            else:
                self.popen_vargs = ['cvc4', '--incremental', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts

        if self.solver == "mathsat":
            self.popen_vargs = ['mathsat'] + self.solver_opts

        if self.solver == "boolector":
            if self.noincr:
                self.popen_vargs = ['boolector', '--smt2'] + self.solver_opts
            else:
                self.popen_vargs = ['boolector', '--smt2', '-i'] + self.solver_opts
            self.unroll = True

        if self.solver == "abc":
            if len(self.solver_opts) > 0:
                self.popen_vargs = ['yosys-abc', '-S', '; '.join(self.solver_opts)]
            else:
                self.popen_vargs = ['yosys-abc', '-S', '%blast; &sweep -C 5000; &syn4; &cec -s -m -C 2000']
            self.logic_ax = False
            self.unroll = True
            self.noincr = True

        if self.solver == "dummy":
            assert self.dummy_file is not None
            self.dummy_fd = open(self.dummy_file, "r")
        else:
            if self.dummy_file is not None:
                self.dummy_fd = open(self.dummy_file, "w")
            if not self.noincr:
                self.p_open()

        if self.unroll:
            assert not self.forall
            self.logic_uf = False
            self.unroll_idcnt = 0
            self.unroll_buffer = ""
            self.unroll_sorts = set()
            self.unroll_objs = set()
            self.unroll_decls = dict()
            self.unroll_cache = dict()
            self.unroll_stack = list()

        if self.logic is None:
            self.logic = ""
            if self.logic_qf: self.logic += "QF_"
            if self.logic_ax: self.logic += "A"
            if self.logic_uf: self.logic += "UF"
            if self.logic_bv: self.logic += "BV"
            if self.logic_dt: self.logic = "ALL"

        self.setup_done = True

        # replay comments/info statements that arrived before setup
        for stmt in self.info_stmts:
            self.write(stmt)

        if self.produce_models:
            self.write("(set-option :produce-models true)")

        self.write("(set-logic %s)" % self.logic)

    def timestamp(self):
        """Return a "## H:MM:SS " prefix with the time elapsed since __init__."""
        secs = int(time() - self.start_time)
        return "## %3d:%02d:%02d " % (secs // (60*60), (secs // 60) % 60, secs % 60)

    def replace_in_stmt(self, stmt, pat, repl):
        """Return a copy of parsed `stmt` with every sub-expression equal to `pat` replaced by `repl`."""
        if stmt == pat:
            return repl

        if isinstance(stmt, list):
            return [self.replace_in_stmt(s, pat, repl) for s in stmt]

        return stmt

    def unroll_stmt(self, stmt):
        """Replace applications of unrolled functions in parsed `stmt`.

        Functions recorded in self.unroll_decls (see write()) are instantiated
        once per distinct application as a fresh "|UNROLL#n|" symbol, whose
        declaration/definition is emitted to the solver the first time it is
        needed.
        """
        if not isinstance(stmt, list):
            return stmt

        stmt = [self.unroll_stmt(s) for s in stmt]

        if len(stmt) >= 2 and not isinstance(stmt[0], list) and stmt[0] in self.unroll_decls:
            assert stmt[1] in self.unroll_objs

            key = tuple(stmt)
            if key not in self.unroll_cache:
                decl = deepcopy(self.unroll_decls[key[0]])

                self.unroll_cache[key] = "|UNROLL#%d|" % self.unroll_idcnt
                decl[1] = self.unroll_cache[key]
                self.unroll_idcnt += 1

                if decl[0] == "declare-fun":
                    if isinstance(decl[3], list) or decl[3] not in self.unroll_sorts:
                        self.unroll_objs.add(decl[1])
                        decl[2] = list()
                    else:
                        # result is itself an uninterpreted-sort object: no
                        # declaration is emitted for it
                        self.unroll_objs.add(decl[1])
                        decl = list()

                elif decl[0] == "define-fun":
                    # substitute the actual arguments into the body
                    for arg_index, (arg_name, arg_sort) in enumerate(decl[2], 1):
                        decl[4] = self.replace_in_stmt(decl[4], arg_name, key[arg_index])
                    decl[2] = list()

                if len(decl) > 0:
                    decl = self.unroll_stmt(decl)
                    self.write(self.unparse(decl), unroll=False)

            return self.unroll_cache[key]

        return stmt

    def p_thread_main(self):
        """Reader thread: forward solver output lines into self.p_queue.

        Enqueues "" as an end-of-stream marker at EOF and only then clears
        self.p_running, so p_read()/p_poll() can rely on that ordering.
        """
        while True:
            # errors="replace": an undecodable byte must not kill this thread,
            # otherwise p_read() would block forever on the queue.
            data = self.p.stdout.readline().decode("ascii", errors="replace")
            if data == "": break
            self.p_queue.put(data)
        self.p_queue.put("")
        self.p_running = False

    def p_open(self):
        """Start the solver process and its output reader thread."""
        assert self.p is None
        try:
            self.p = subprocess.Popen(self.popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        except FileNotFoundError:
            print("%s SMT Solver '%s' not found in path." % (self.timestamp(), self.popen_vargs[0]), flush=True)
            sys.exit(1)
        running_solvers[self.p_index] = self.p
        self.p_running = True
        self.p_next = None
        self.p_queue = Queue()
        self.p_thread = Thread(target=self.p_thread_main)
        self.p_thread.start()

    def p_write(self, data, flush):
        """Send `data` to the solver's stdin, flushing if `flush` is true."""
        assert self.p is not None
        self.p.stdin.write(bytes(data, "ascii"))
        if flush: self.p.stdin.flush()

    def p_read(self):
        """Return the next solver output line, or "" once the output is exhausted."""
        assert self.p is not None
        if self.p_next is not None:
            data = self.p_next
            self.p_next = None
            return data
        if not self.p_running:
            # The reader thread has finished, but lines it queued may not have
            # been consumed yet; hand those out before reporting EOF.
            try:
                return self.p_queue.get_nowait()
            except Empty:
                return ""
        # Still running: the thread will enqueue at least the "" marker.
        return self.p_queue.get()

    def p_poll(self, timeout=0.1):
        """Wait up to `timeout` seconds for solver output.

        Returns True if nothing arrived and more output may still come (keep
        waiting).  Returns False if a line is ready (stashed for p_read()) or
        the reader thread has finished.
        """
        assert self.p is not None
        if self.p_next is not None:
            return False
        try:
            self.p_next = self.p_queue.get(True, timeout)
            return False
        except Empty:
            return self.p_running

    def p_close(self):
        """Close the solver's stdin, wait for the reader thread and forget the process."""
        assert self.p is not None
        self.p.stdin.close()
        self.p_thread.join()
        assert not self.p_running
        del running_solvers[self.p_index]
        self.p = None
        self.p_next = None
        self.p_queue = None
        self.p_thread = None

    def write(self, stmt, unroll=True):
        """Send one SMT-LIB2 statement to the solver.

        Lines starting with ";" go to info().  Before setup they are buffered
        for replay; afterwards they are passed through like any statement.
        In unroll mode, statements are buffered until their parentheses
        balance and then rewritten by unroll_stmt().  In noincr mode,
        statements are cached in self.smt2cache and replayed by check_sat().
        """
        if stmt.startswith(";"):
            self.info(stmt)
            if not self.setup_done:
                self.info_stmts.append(stmt)
                return
        elif not self.setup_done:
            self.setup()

        stmt = stmt.strip()

        if self.nocomments or self.unroll:
            stmt = re.sub(r" *;.*", "", stmt)
            if stmt == "": return

        if unroll and self.unroll:
            stmt = self.unroll_buffer + stmt
            self.unroll_buffer = ""

            # wait for a complete S-expression (ignoring parens in |symbols|)
            s = re.sub(r"\|[^|]*\|", "", stmt)
            if s.count("(") != s.count(")"):
                self.unroll_buffer = stmt + " "
                return

            s = self.parse(stmt)

            if self.debug_print:
                print("-> %s" % s)

            if len(s) == 3 and s[0] == "declare-sort" and s[2] == "0":
                self.unroll_sorts.add(s[1])
                return

            elif len(s) == 4 and s[0] == "declare-fun" and s[2] == [] and s[3] in self.unroll_sorts:
                self.unroll_objs.add(s[1])
                return

            elif len(s) >= 4 and s[0] == "declare-fun":
                for arg_sort in s[2]:
                    if arg_sort in self.unroll_sorts:
                        self.unroll_decls[s[1]] = s
                        return

            elif len(s) >= 4 and s[0] == "define-fun":
                for arg_name, arg_sort in s[2]:
                    if arg_sort in self.unroll_sorts:
                        self.unroll_decls[s[1]] = s
                        return

            stmt = self.unparse(self.unroll_stmt(s))

            if stmt == "(push 1)":
                self.unroll_stack.append((
                    deepcopy(self.unroll_sorts),
                    deepcopy(self.unroll_objs),
                    deepcopy(self.unroll_decls),
                    deepcopy(self.unroll_cache),
                ))

            if stmt == "(pop 1)":
                self.unroll_sorts, self.unroll_objs, self.unroll_decls, self.unroll_cache = self.unroll_stack.pop()

        if self.debug_print:
            print("> %s" % stmt)

        if self.debug_file:
            print(stmt, file=self.debug_file)
            self.debug_file.flush()

        if self.solver != "dummy":
            if self.noincr:
                if self.p is not None and not stmt.startswith("(get-"):
                    self.p_close()
                if stmt == "(push 1)":
                    self.smt2cache.append(list())
                elif stmt == "(pop 1)":
                    self.smt2cache.pop()
                else:
                    if self.p is not None:
                        self.p_write(stmt + "\n", True)
                    # Queries only produce output.  Replaying them into the
                    # fresh solver that check_sat() starts would inject extra
                    # responses ahead of the (check-sat) answer.
                    if not stmt.startswith("(get-"):
                        self.smt2cache[-1].append(stmt)
            else:
                self.p_write(stmt + "\n", True)

    def info(self, stmt):
        """Record the metadata carried by a "; yosys-smt2-*" comment line.

        Other comments are ignored.  Logic-selecting tags only take effect
        while no logic has been chosen yet.
        """
        if not stmt.startswith("; yosys-smt2-"):
            return

        fields = stmt.split()

        if fields[1] == "yosys-smt2-nomem":
            if self.logic is None:
                self.logic_ax = False

        if fields[1] == "yosys-smt2-nobv":
            if self.logic is None:
                self.logic_bv = False

        if fields[1] == "yosys-smt2-stdt":
            if self.logic is None:
                self.logic_dt = True

        if fields[1] == "yosys-smt2-forall":
            if self.logic is None:
                self.logic_qf = False
            self.forall = True

        if fields[1] == "yosys-smt2-module":
            self.curmod = fields[2]
            self.modinfo[self.curmod] = SmtModInfo()

        if fields[1] == "yosys-smt2-cell":
            self.modinfo[self.curmod].cells[fields[3]] = fields[2]

        if fields[1] == "yosys-smt2-topmod":
            self.topmod = fields[2]

        if fields[1] == "yosys-smt2-input":
            self.modinfo[self.curmod].inputs.add(fields[2])
            self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3])

        if fields[1] == "yosys-smt2-output":
            self.modinfo[self.curmod].outputs.add(fields[2])
            self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3])

        if fields[1] == "yosys-smt2-register":
            self.modinfo[self.curmod].registers.add(fields[2])
            self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3])

        if fields[1] == "yosys-smt2-memory":
            self.modinfo[self.curmod].memories[fields[2]] = (int(fields[3]), int(fields[4]), int(fields[5]), int(fields[6]), fields[7] == "async")

        if fields[1] == "yosys-smt2-wire":
            self.modinfo[self.curmod].wires.add(fields[2])
            self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3])

        if fields[1] == "yosys-smt2-clock":
            # a clock seen with more than one edge kind becomes "event"
            for edge in fields[3:]:
                if fields[2] not in self.modinfo[self.curmod].clocks:
                    self.modinfo[self.curmod].clocks[fields[2]] = edge
                elif self.modinfo[self.curmod].clocks[fields[2]] != edge:
                    self.modinfo[self.curmod].clocks[fields[2]] = "event"

        if fields[1] == "yosys-smt2-assert":
            self.modinfo[self.curmod].asserts["%s_a %s" % (self.curmod, fields[2])] = fields[3]

        if fields[1] == "yosys-smt2-cover":
            self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3]

        if fields[1] == "yosys-smt2-anyconst":
            self.modinfo[self.curmod].anyconsts[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5])
            self.modinfo[self.curmod].asize[fields[2]] = int(fields[3])

        if fields[1] == "yosys-smt2-anyseq":
            self.modinfo[self.curmod].anyseqs[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5])
            self.modinfo[self.curmod].asize[fields[2]] = int(fields[3])

        if fields[1] == "yosys-smt2-allconst":
            self.modinfo[self.curmod].allconsts[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5])
            self.modinfo[self.curmod].asize[fields[2]] = int(fields[3])

        if fields[1] == "yosys-smt2-allseq":
            self.modinfo[self.curmod].allseqs[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5])
            self.modinfo[self.curmod].asize[fields[2]] = int(fields[3])

    def hiernets(self, top, regs_only=False):
        """Return hierarchical paths (lists of names) of all nets below `top`.

        With regs_only=True, only nets recorded as registers are returned.
        """
        def hiernets_worker(nets, mod, cursor):
            for netname in sorted(self.modinfo[mod].wsize.keys()):
                if not regs_only or netname in self.modinfo[mod].registers:
                    nets.append(cursor + [netname])
            for cellname, celltype in sorted(self.modinfo[mod].cells.items()):
                hiernets_worker(nets, celltype, cursor + [cellname])

        nets = list()
        hiernets_worker(nets, top, [])
        return nets

    def _hier_attr(self, top, attr):
        """Walk the hierarchy below `top` over the SmtModInfo dict named `attr`.

        Returns a list of (cell path, name, value[0], value[1], width) tuples
        in sorted depth-first order.  Shared by hier{any,all}{consts,seqs}().
        """
        def worker(results, mod, cursor):
            for name, value in sorted(getattr(self.modinfo[mod], attr).items()):
                width = self.modinfo[mod].asize[name]
                results.append((cursor, name, value[0], value[1], width))
            for cellname, celltype in sorted(self.modinfo[mod].cells.items()):
                worker(results, celltype, cursor + [cellname])

        results = list()
        worker(results, top, [])
        return results

    def hieranyconsts(self, top):
        """Return all anyconst entries below `top` (see _hier_attr)."""
        return self._hier_attr(top, "anyconsts")

    def hieranyseqs(self, top):
        """Return all anyseq entries below `top` (see _hier_attr)."""
        return self._hier_attr(top, "anyseqs")

    def hierallconsts(self, top):
        """Return all allconst entries below `top` (see _hier_attr)."""
        return self._hier_attr(top, "allconsts")

    def hierallseqs(self, top):
        """Return all allseq entries below `top` (see _hier_attr)."""
        return self._hier_attr(top, "allseqs")

    def hiermems(self, top):
        """Return hierarchical paths (lists of names) of all memories below `top`."""
        def hiermems_worker(mems, mod, cursor):
            for memname in sorted(self.modinfo[mod].memories.keys()):
                mems.append(cursor + [memname])
            for cellname, celltype in sorted(self.modinfo[mod].cells.items()):
                hiermems_worker(mems, celltype, cursor + [cellname])

        mems = list()
        hiermems_worker(mems, top, [])
        return mems

    def read(self):
        """Read one complete S-expression (or atom) from the solver.

        Lines are stripped and concatenated until the parentheses balance.
        Returns "" on EOF.  Exits with status 1 on an "(error" response, or
        if the output ends in the middle of an expression.
        """
        stmt = []
        count_brackets = 0

        while True:
            if self.solver == "dummy":
                raw = self.dummy_fd.readline()
                line = raw.strip()
            else:
                raw = self.p_read()
                line = raw.strip()
                if self.dummy_file is not None:
                    self.dummy_fd.write(line + "\n")

            count_brackets += line.count("(")
            count_brackets -= line.count(")")
            stmt.append(line)

            if self.debug_print:
                print("< %s" % line)
            if count_brackets == 0:
                break
            if raw == "":
                # EOF inside an unfinished expression.  Checking for EOF
                # (rather than a non-zero exit status) also catches solvers
                # that exit with status 0 and truncated dummy transcripts,
                # either of which would otherwise loop forever.
                print("%s Solver terminated unexpectedly: %s" % (self.timestamp(), "".join(stmt)), flush=True)
                sys.exit(1)

        stmt = "".join(stmt)
        if stmt.startswith("(error"):
            print("%s Solver Error: %s" % (self.timestamp(), stmt), flush=True)
            if self.solver != "dummy":
                self.p_close()
            sys.exit(1)

        return stmt

    def check_sat(self):
        """Issue (check-sat) and return "sat" or "unsat"; exit on anything else.

        In noincr mode a fresh solver is started and all cached statements
        are replayed first.  While waiting, a progress indicator is printed
        to stderr (timeinfo) or a periodic "waiting for solver" message.
        """
        if self.debug_print:
            print("> (check-sat)")
        if self.debug_file and not self.nocomments:
            print("; running check-sat..", file=self.debug_file)
            self.debug_file.flush()

        if self.solver != "dummy":
            if self.noincr:
                if self.p is not None:
                    self.p_close()
                self.p_open()
                for cache_ctx in self.smt2cache:
                    for cache_stmt in cache_ctx:
                        self.p_write(cache_stmt + "\n", False)

            self.p_write("(check-sat)\n", True)

            if self.timeinfo:
                i = 0
                s = "/-\\|"  # spinner characters (backslash escaped explicitly)

                count = 0   # polls of 0.1 s each
                num_bs = 0  # characters to erase before redrawing
                while self.p_poll():
                    count += 1

                    if count < 25:
                        continue

                    if count % 10 == 0 or count == 25:
                        secs = count // 10

                        if secs < 60:
                            m = "(%d seconds)" % secs
                        elif secs < 60*60:
                            m = "(%d seconds -- %d:%02d)" % (secs, secs // 60, secs % 60)
                        else:
                            m = "(%d seconds -- %d:%02d:%02d)" % (secs, secs // (60*60), (secs // 60) % 60, secs % 60)

                        print("%s %s %c" % ("\b \b" * num_bs, m, s[i]), end="", file=sys.stderr)
                        num_bs = len(m) + 3

                    else:
                        print("\b" + s[i], end="", file=sys.stderr)

                    sys.stderr.flush()
                    i = (i + 1) % len(s)

                if num_bs != 0:
                    print("\b \b" * num_bs, end="", file=sys.stderr)
                    sys.stderr.flush()

            else:
                count = 0  # minutes waited
                while self.p_poll(60):
                    count += 1
                    msg = None

                    if count == 1:
                        msg = "1 minute"

                    elif count in [5, 10, 15, 30]:
                        msg = "%d minutes" % count

                    elif count == 60:
                        msg = "1 hour"

                    elif count % 60 == 0:
                        msg = "%d hours" % (count // 60)

                    if msg is not None:
                        print("%s waiting for solver (%s)" % (self.timestamp(), msg), flush=True)

        result = self.read()

        if self.debug_file:
            print("(set-info :status %s)" % result, file=self.debug_file)
            print("(check-sat)", file=self.debug_file)
            self.debug_file.flush()

        if result not in ["sat", "unsat"]:
            if result == "":
                print("%s Unexpected EOF response from solver." % (self.timestamp()), flush=True)
            else:
                print("%s Unexpected response from solver: %s" % (self.timestamp(), result), flush=True)
            if self.solver != "dummy":
                self.p_close()
            sys.exit(1)

        return result

    def parse(self, stmt):
        """Parse one parenthesised SMT-LIB2 S-expression into nested lists of atom strings.

        |quoted| symbols are kept verbatim including the bars.  The parser is
        recursive, so deep expressions need a high recursion limit.
        """
        def worker(stmt):
            if stmt[0] == '(':
                expr = []
                cursor = 1
                while stmt[cursor] != ')':
                    el, le = worker(stmt[cursor:])
                    expr.append(el)
                    cursor += le
                return expr, cursor+1

            if stmt[0] == '|':
                expr = "|"
                cursor = 1
                while stmt[cursor] != '|':
                    expr += stmt[cursor]
                    cursor += 1
                expr += "|"
                return expr, cursor+1

            if stmt[0] in [" ", "\t", "\r", "\n"]:
                el, le = worker(stmt[1:])
                return el, le+1

            expr = ""
            cursor = 0
            while stmt[cursor] not in ["(", ")", "|", " ", "\t", "\r", "\n"]:
                expr += stmt[cursor]
                cursor += 1
            return expr, cursor
        return worker(stmt)[0]

    def unparse(self, stmt):
        """Inverse of parse(): render nested lists back into S-expression text."""
        if isinstance(stmt, list):
            return "(" + " ".join([self.unparse(s) for s in stmt]) + ")"
        return stmt

    def bv2hex(self, v):
        """Convert a bit-vector value to hex, one digit per started group of 4 bits."""
        h = ""
        v = self.bv2bin(v)
        while len(v) > 0:
            d = 0
            if len(v) > 0 and v[-1] == "1": d += 1
            if len(v) > 1 and v[-2] == "1": d += 2
            if len(v) > 2 and v[-3] == "1": d += 4
            if len(v) > 3 and v[-4] == "1": d += 8
            h = hex(d)[2:] + h
            if len(v) < 4: break
            v = v[:-4]
        return h

    def bv2bin(self, v):
        """Convert a solver value to a binary string (MSB first).

        Accepts (_ bvX N) lists, true/false, #b... and #x... literals.
        """
        if isinstance(v, list) and len(v) == 3 and v[0] == "_" and v[1].startswith("bv"):
            x, n = int(v[1][2:]), int(v[2])
            return "".join("1" if (x & (1 << i)) else "0" for i in range(n-1, -1, -1))
        if v == "true": return "1"
        if v == "false": return "0"
        if v.startswith("#b"):
            return v[2:]
        if v.startswith("#x"):
            return "".join(hex_dict.get(x) for x in v[2:])
        assert False

    def bv2int(self, v):
        """Convert a solver bit-vector value to an unsigned int."""
        return int(self.bv2bin(v), 2)

    def get(self, expr):
        """Query the model value of a single expression (raw parsed form)."""
        self.write("(get-value (%s))" % (expr))
        return self.parse(self.read())[0][1]

    def get_list(self, expr_list):
        """Query the model values of several expressions in one (get-value)."""
        if len(expr_list) == 0:
            return []
        self.write("(get-value (%s))" % " ".join(expr_list))
        return [n[1] for n in self.parse(self.read())]

    def get_path(self, mod, path):
        """Split a dotted hierarchical name into [cell, ..., net] for module `mod`.

        Cell names may themselves contain dots; the shortest matching prefix wins.
        """
        assert mod in self.modinfo
        path = path.replace("\\", "/").split(".")

        for i in range(len(path)-1):
            first = ".".join(path[0:i+1])
            second = ".".join(path[i+1:])

            if first in self.modinfo[mod].cells:
                nextmod = self.modinfo[mod].cells[first]
                return [first] + self.get_path(nextmod, second)

        return [".".join(path)]

    def net_expr(self, mod, base, path):
        """Build the SMT accessor expression for hierarchical `path` on state `base`."""
        if len(path) == 0:
            return base

        if len(path) == 1:
            assert mod in self.modinfo
            if path[0] == "":
                return base
            if path[0] in self.modinfo[mod].cells:
                return "(|%s_h %s| %s)" % (mod, path[0], base)
            if path[0] in self.modinfo[mod].wsize:
                return "(|%s_n %s| %s)" % (mod, path[0], base)
            if path[0] in self.modinfo[mod].memories:
                return "(|%s_m %s| %s)" % (mod, path[0], base)
            assert 0

        assert mod in self.modinfo
        assert path[0] in self.modinfo[mod].cells

        nextmod = self.modinfo[mod].cells[path[0]]
        nextbase = "(|%s_h %s| %s)" % (mod, path[0], base)
        return self.net_expr(nextmod, nextbase, path[1:])

    def net_width(self, mod, net_path):
        """Return the bit width of the net at hierarchical `net_path`."""
        for i in range(len(net_path)-1):
            assert mod in self.modinfo
            assert net_path[i] in self.modinfo[mod].cells
            mod = self.modinfo[mod].cells[net_path[i]]

        assert mod in self.modinfo
        assert net_path[-1] in self.modinfo[mod].wsize
        return self.modinfo[mod].wsize[net_path[-1]]

    def net_clock(self, mod, net_path):
        """Return the clock edge kind recorded for the net, or None if it is not a clock."""
        for i in range(len(net_path)-1):
            assert mod in self.modinfo
            assert net_path[i] in self.modinfo[mod].cells
            mod = self.modinfo[mod].cells[net_path[i]]

        assert mod in self.modinfo
        if net_path[-1] not in self.modinfo[mod].clocks:
            return None
        return self.modinfo[mod].clocks[net_path[-1]]

    def net_exists(self, mod, net_path):
        """Return True if hierarchical `net_path` names a known net."""
        for i in range(len(net_path)-1):
            if mod not in self.modinfo: return False
            if net_path[i] not in self.modinfo[mod].cells: return False
            mod = self.modinfo[mod].cells[net_path[i]]

        if mod not in self.modinfo: return False
        if net_path[-1] not in self.modinfo[mod].wsize: return False
        return True

    def mem_exists(self, mod, mem_path):
        """Return True if hierarchical `mem_path` names a known memory."""
        for i in range(len(mem_path)-1):
            if mod not in self.modinfo: return False
            if mem_path[i] not in self.modinfo[mod].cells: return False
            mod = self.modinfo[mod].cells[mem_path[i]]

        if mod not in self.modinfo: return False
        if mem_path[-1] not in self.modinfo[mod].memories: return False
        return True

    def mem_expr(self, mod, base, path, port=None, infomode=False):
        """Build the SMT accessor for a memory (or, with infomode, return its info tuple)."""
        if len(path) == 1:
            assert mod in self.modinfo
            assert path[0] in self.modinfo[mod].memories
            if infomode:
                return self.modinfo[mod].memories[path[0]]
            return "(|%s_m%s %s| %s)" % (mod, "" if port is None else ":%s" % port, path[0], base)

        assert mod in self.modinfo
        assert path[0] in self.modinfo[mod].cells

        nextmod = self.modinfo[mod].cells[path[0]]
        nextbase = "(|%s_h %s| %s)" % (mod, path[0], base)
        return self.mem_expr(nextmod, nextbase, path[1:], port=port, infomode=infomode)

    def mem_info(self, mod, path):
        """Return the memory info tuple recorded by info() for hierarchical `path`."""
        return self.mem_expr(mod, "", path, infomode=True)

    def get_net(self, mod_name, net_path, state_name):
        """Query a net's raw value in state `state_name`."""
        return self.get(self.net_expr(mod_name, state_name, net_path))

    def get_net_list(self, mod_name, net_path_list, state_name):
        """Query several nets' raw values in state `state_name`."""
        return self.get_list([self.net_expr(mod_name, state_name, n) for n in net_path_list])

    def get_net_hex(self, mod_name, net_path, state_name):
        """Query a net's value as a hex string."""
        return self.bv2hex(self.get_net(mod_name, net_path, state_name))

    def get_net_hex_list(self, mod_name, net_path_list, state_name):
        """Query several nets' values as hex strings."""
        return [self.bv2hex(v) for v in self.get_net_list(mod_name, net_path_list, state_name)]

    def get_net_bin(self, mod_name, net_path, state_name):
        """Query a net's value as a binary string."""
        return self.bv2bin(self.get_net(mod_name, net_path, state_name))

    def get_net_bin_list(self, mod_name, net_path_list, state_name):
        """Query several nets' values as binary strings."""
        return [self.bv2bin(v) for v in self.get_net_list(mod_name, net_path_list, state_name)]

    def wait(self):
        """Wait for a running solver process to exit, then clean up after it."""
        if self.p is not None:
            self.p.wait()
            self.p_close()
903
904
class SmtOpts:
    """Command-line options for the SMT front end (consumed by SmtIo)."""

    def __init__(self):
        # getopt specifications for the options handled by handle()
        self.shortopts = "s:S:v"
        self.longopts = ["unroll", "noincr", "noprogress", "dump-smt2=", "logic=", "dummy=", "info=", "nocomments"]
        self.solver = "yices"
        self.solver_opts = list()
        self.debug_print = False
        self.debug_file = None
        self.dummy_file = None
        self.unroll = False
        self.noincr = False
        self.timeinfo = os.name != "nt"  # progress display is off on Windows
        self.logic = None
        self.info_stmts = list()
        self.nocomments = False

    def handle(self, o, a):
        """Apply one getopt-style option `o` with argument `a`.

        Returns True if the option was recognised, False otherwise.
        --dump-smt2 opens the named file for writing immediately.
        """
        if o == "-s":
            self.solver = a
        elif o == "-S":
            self.solver_opts.append(a)
        elif o == "-v":
            self.debug_print = True
        elif o == "--unroll":
            self.unroll = True
        elif o == "--noincr":
            self.noincr = True
        elif o == "--noprogress":
            self.timeinfo = False
        elif o == "--dump-smt2":
            self.debug_file = open(a, "w")
        elif o == "--logic":
            self.logic = a
        elif o == "--dummy":
            self.dummy_file = a
        elif o == "--info":
            self.info_stmts.append(a)
        elif o == "--nocomments":
            self.nocomments = True
        else:
            return False
        return True

    def helpmsg(self):
        """Return the help text describing the options accepted by handle()."""
        # "abc" is listed because SmtIo.setup() accepts it as a solver.
        return """
    -s <solver>
        set SMT solver: z3, yices, boolector, cvc4, mathsat, abc, dummy
        default: yices

    -S <opt>
        pass <opt> as command line argument to the solver

    --logic <smt2_logic>
        use the specified SMT2 logic (e.g. QF_AUFBV)

    --dummy <filename>
        if solver is "dummy", read solver output from that file
        otherwise: write solver output to that file

    -v
        enable debug output

    --unroll
        unroll uninterpreted functions

    --noincr
        don't use incremental solving, instead restart solver for
        each (check-sat). This also avoids (push) and (pop).

    --noprogress
        disable timer display during solving
        (this option is set implicitly on Windows)

    --dump-smt2 <filename>
        write smt2 statements to file

    --info <smt2-info-stmt>
        include the specified smt2 info statement in the smt2 output

    --nocomments
        strip all comments from the generated smt2 code
"""
987
988
class MkVcd:
    """Minimal VCD (value change dump) writer for solver traces.

    Register all nets and clocks with add_net()/add_clock() before the first
    set_time() call, which writes the VCD header.  Step t is written at VCD
    time 10*t.  For t > 0, posedge/negedge clocks are driven to their
    inactive level at 10*t-5, and every clock is driven at 10*t (b0 for
    negedge clocks, b1 otherwise).
    """

    def __init__(self, f):
        self.f = f              # text stream the VCD is written to
        self.t = -1             # current step; -1 until the header is written
        self.nets = dict()      # path tuple -> (VCD identifier, width)
        self.clocks = dict()    # path tuple -> (VCD identifier, edge kind)
        self._next_key = 0      # counter for unique "n<k>" identifiers

    def _new_key(self):
        # Allocate a fresh VCD identifier.  A counter is used instead of
        # len(self.nets): re-registering a path does not grow self.nets, so
        # len() could hand out an identifier that is already taken.
        key = "n%d" % self._next_key
        self._next_key += 1
        return key

    def add_net(self, path, width):
        """Register a `width`-bit net at hierarchical `path` (before the first set_time())."""
        path = tuple(path)
        assert self.t == -1
        self.nets[path] = (self._new_key(), width)

    def add_clock(self, path, edge):
        """Register a 1-bit clock at `path` with edge kind posedge, negedge or event."""
        path = tuple(path)
        assert self.t == -1
        key = self._new_key()
        self.nets[path] = (key, 1)
        self.clocks[path] = (key, edge)

    def set_net(self, path, bits):
        """Write binary value `bits` for a registered net; clocks are driven by set_time()."""
        path = tuple(path)
        assert self.t >= 0
        assert path in self.nets
        if path not in self.clocks:
            print("b%s %s" % (bits, self.nets[path][0]), file=self.f)

    def escape_name(self, name):
        """Rewrite [ident] subscripts as <ident>.

        Also prefix a backslash if the name starts with "[" or "]"; re.match
        only looks at the start of the name.
        """
        name = re.sub(r"\[([0-9a-zA-Z_]*[a-zA-Z_][0-9a-zA-Z_]*)\]", r"<\1>", name)
        if re.match(r"[\[\]]", name) and name[0] != "\\":
            name = "\\" + name
        return name

    def _write_header(self):
        # Emit the $var declarations for the step counter, the global clock
        # event and all registered nets (grouped into $scope blocks), then
        # $enddefinitions.
        print("$var integer 32 t smt_step $end", file=self.f)
        print("$var event 1 ! smt_clock $end", file=self.f)

        scope = []
        for path in sorted(self.nets):
            key, width = self.nets[path]

            # a dotted last component is split into further scope levels
            uipath = list(path)
            if "." in uipath[-1]:
                uipath = uipath[0:-1] + uipath[-1].split(".")
            uipath = [re.sub(r"\[([^\]]*)\]", r"<\1>", part) for part in uipath]

            while uipath[:len(scope)] != scope:
                print("$upscope $end", file=self.f)
                scope = scope[:-1]

            while uipath[:-1] != scope:
                print("$scope module %s $end" % uipath[len(scope)], file=self.f)
                scope.append(uipath[len(scope)])

            if path in self.clocks and self.clocks[path][1] == "event":
                print("$var event 1 %s %s $end" % (key, uipath[-1]), file=self.f)
            else:
                print("$var wire %d %s %s $end" % (width, key, uipath[-1]), file=self.f)

        for _ in scope:
            print("$upscope $end", file=self.f)

        print("$enddefinitions $end", file=self.f)

    def set_time(self, t):
        """Advance to step `t` (non-decreasing); repeating the current step is a no-op."""
        assert t >= self.t
        if t != self.t:
            if self.t == -1:
                self._write_header()

            self.t = t
            assert self.t >= 0

            if self.t > 0:
                # half a period earlier: clocks return to their inactive level
                print("#%d" % (10 * self.t - 5), file=self.f)
                for path in sorted(self.clocks.keys()):
                    if self.clocks[path][1] == "posedge":
                        print("b0 %s" % self.nets[path][0], file=self.f)
                    elif self.clocks[path][1] == "negedge":
                        print("b1 %s" % self.nets[path][0], file=self.f)

            print("#%d" % (10 * self.t), file=self.f)
            print("1!", file=self.f)
            print("b%s t" % format(self.t, "032b"), file=self.f)

            # active clock edge at the step boundary
            for path in sorted(self.clocks.keys()):
                if self.clocks[path][1] == "negedge":
                    print("b0 %s" % self.nets[path][0], file=self.f)
                else:
                    print("b1 %s" % self.nets[path][0], file=self.f)
1075 else:
1076 print("b1 %s" % self.nets[path][0], file=self.f)