Improve SMT2 encoding of $reduce_{and,or,bool}
[yosys.git] / backends / smt2 / smtio.py
1 #
2 # yosys -- Yosys Open SYnthesis Suite
3 #
4 # Copyright (C) 2012 Clifford Wolf <clifford@clifford.at>
5 #
6 # Permission to use, copy, modify, and/or distribute this software for any
7 # purpose with or without fee is hereby granted, provided that the above
8 # copyright notice and this permission notice appear in all copies.
9 #
10 # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 #
18
19 import sys, re, os, signal
20 import resource, subprocess
21 from copy import deepcopy
22 from select import select
23 from time import time
24 from queue import Queue, Empty
25 from threading import Thread
26
27
# The recursive SMT2 S-expression parser needs a deep call stack when parsing
# large expressions, so raise the interpreter recursion limit and, where
# possible, the process stack size.
smtio_reclimit = 64 * 1024
smtio_stacksize = 128 * 1024 * 1024
if sys.getrecursionlimit() < smtio_reclimit:
    sys.setrecursionlimit(smtio_reclimit)

_smtio_stack_soft, _smtio_stack_hard = resource.getrlimit(resource.RLIMIT_STACK)
# An unlimited soft limit is already big enough. RLIM_INFINITY may be a
# negative number, so it must not be compared numerically (that would *lower*
# an unlimited stack to smtio_stacksize).
if _smtio_stack_soft != resource.RLIM_INFINITY and _smtio_stack_soft < smtio_stacksize:
    _smtio_stack_want = smtio_stacksize
    if _smtio_stack_hard != resource.RLIM_INFINITY:
        # The soft limit may never exceed the hard limit.
        _smtio_stack_want = min(_smtio_stack_want, _smtio_stack_hard)
    try:
        # Keep the existing hard limit; an unprivileged process cannot raise it.
        resource.setrlimit(resource.RLIMIT_STACK, (_smtio_stack_want, _smtio_stack_hard))
    except (ValueError, OSError):
        # Could not get more stack: run with what we have.
        pass
36
37
# Solver subprocesses that are currently alive, keyed by the p_index of the
# SmtIo instance that started them, so they can be killed on a signal.
running_solvers = {}
got_term_signal = False
solvers_index = 0
42
def sig_handler(signum, frame):
    """Signal handler: terminate every registered solver, then exit.

    The signal name is printed only for the first signal received.  SIGTERM
    is sent to the process group of each entry in running_solvers and the
    interpreter exits with status 1.
    """
    global got_term_signal
    if not got_term_signal:
        print("<%s>" % signal.Signals(signum).name)
        got_term_signal = True
    for p in list(running_solvers.values()):
        try:
            os.killpg(os.getpgid(p.pid), signal.SIGTERM)
        except ProcessLookupError:
            # This solver is already gone; keep going so the remaining
            # solvers are still signalled and the exit below is reached.
            pass
    sys.exit(1)
51
# Route interrupt, hang-up and terminate requests through sig_handler.
for _handled_signal in (signal.SIGINT, signal.SIGHUP, signal.SIGTERM):
    signal.signal(_handled_signal, sig_handler)
55
56
# Maps one hexadecimal digit (upper or lower case) to its 4-character binary string.
hex_dict = {
    digit: format(int(digit, 16), "04b")
    for digit in "0123456789ABCDEFabcdef"
}
65
66
class SmtModInfo:
    """Container for the per-module facts gathered from yosys-smt2 info comments."""

    def __init__(self):
        # Name sets, one per kind of net.
        for set_name in ("inputs", "outputs", "registers", "wires"):
            setattr(self, set_name, set())

        # Name-keyed tables (memories, net widths, clocks, cells, ...).
        for table_name in (
            "memories", "wsize", "clocks", "cells", "asserts", "covers",
            "anyconsts", "anyseqs", "allconsts", "allseqs", "asize",
        ):
            setattr(self, table_name, {})
84
85
86 class SmtIo:
def __init__(self, opts=None):
    """Initialise solver-independent state; the solver itself is started by setup().

    opts, when given, supplies the solver name, solver options and the
    debug/unroll/incremental settings; otherwise built-in defaults are used.
    """
    global solvers_index

    # Building blocks of the logic name derived in setup().
    self.logic = None
    self.logic_qf = self.logic_ax = self.logic_uf = self.logic_bv = True
    self.logic_dt = False
    self.forall = False
    self.produce_models = True

    # One statement list per (push) level, replayed in non-incremental mode.
    self.smt2cache = [[]]

    # No solver process yet; reserve a unique registry index for it.
    self.p = None
    self.p_index = solvers_index
    solvers_index += 1

    if opts is None:
        self.solver = "yices"
        self.solver_opts = []
        self.debug_print = False
        self.debug_file = None
        self.dummy_file = None
        self.timeinfo = os.name != "nt"
        self.unroll = False
        self.noincr = False
        self.info_stmts = []
        self.nocomments = False
    else:
        self.logic = opts.logic
        self.solver = opts.solver
        self.solver_opts = opts.solver_opts
        self.debug_print = opts.debug_print
        self.debug_file = opts.debug_file
        self.dummy_file = opts.dummy_file
        self.timeinfo = opts.timeinfo
        self.unroll = opts.unroll
        self.noincr = opts.noincr
        self.info_stmts = opts.info_stmts
        self.nocomments = opts.nocomments

    self.start_time = time()

    self.modinfo = {}
    self.curmod = None
    self.topmod = None
    self.setup_done = False
134
def __del__(self):
    """Terminate a still-registered solver process when the object is collected."""
    if self.p is None:
        return
    try:
        os.killpg(os.getpgid(self.p.pid), signal.SIGTERM)
    except ProcessLookupError:
        # The solver has already exited; there is nothing left to signal.
        pass
    # The module global may already be None during interpreter shutdown.
    if running_solvers is not None:
        # pop() instead of del: the entry may already have been removed.
        running_solvers.pop(self.p_index, None)
140
141 def setup(self):
142 assert not self.setup_done
143
144 if self.forall:
145 self.unroll = False
146
147 if self.solver == "yices":
148 self.popen_vargs = ['yices-smt2', '--incremental'] + self.solver_opts
149
150 if self.solver == "z3":
151 self.popen_vargs = ['z3', '-smt2', '-in'] + self.solver_opts
152
153 if self.solver == "cvc4":
154 self.popen_vargs = ['cvc4', '--incremental', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts
155
156 if self.solver == "mathsat":
157 self.popen_vargs = ['mathsat'] + self.solver_opts
158
159 if self.solver == "boolector":
160 self.popen_vargs = ['boolector', '--smt2', '-i'] + self.solver_opts
161 self.unroll = True
162
163 if self.solver == "abc":
164 if len(self.solver_opts) > 0:
165 self.popen_vargs = ['yosys-abc', '-S', '; '.join(self.solver_opts)]
166 else:
167 self.popen_vargs = ['yosys-abc', '-S', '%blast; &sweep -C 5000; &syn4; &cec -s -m -C 2000']
168 self.logic_ax = False
169 self.unroll = True
170 self.noincr = True
171
172 if self.solver == "dummy":
173 assert self.dummy_file is not None
174 self.dummy_fd = open(self.dummy_file, "r")
175 else:
176 if self.dummy_file is not None:
177 self.dummy_fd = open(self.dummy_file, "w")
178 if not self.noincr:
179 self.p_open()
180
181 if self.unroll:
182 assert not self.forall
183 self.logic_uf = False
184 self.unroll_idcnt = 0
185 self.unroll_buffer = ""
186 self.unroll_sorts = set()
187 self.unroll_objs = set()
188 self.unroll_decls = dict()
189 self.unroll_cache = dict()
190 self.unroll_stack = list()
191
192 if self.logic is None:
193 self.logic = ""
194 if self.logic_qf: self.logic += "QF_"
195 if self.logic_ax: self.logic += "A"
196 if self.logic_uf: self.logic += "UF"
197 if self.logic_bv: self.logic += "BV"
198 if self.logic_dt: self.logic = "ALL"
199
200 self.setup_done = True
201
202 for stmt in self.info_stmts:
203 self.write(stmt)
204
205 if self.produce_models:
206 self.write("(set-option :produce-models true)")
207
208 self.write("(set-logic %s)" % self.logic)
209
def timestamp(self):
    """Return a log prefix with the seconds elapsed since start_time, also as h:mm:ss."""
    elapsed = int(time() - self.start_time)
    total_minutes, seconds = divmod(elapsed, 60)
    hours, minutes = divmod(total_minutes, 60)
    return "## %6d %3d:%02d:%02d " % (elapsed, hours, minutes, seconds)
213
def replace_in_stmt(self, stmt, pat, repl):
    """Return stmt with every sub-expression equal to pat replaced by repl.

    Lists are rebuilt recursively; anything else is returned unchanged.
    """
    if stmt == pat:
        return repl
    if not isinstance(stmt, list):
        return stmt
    return [self.replace_in_stmt(child, pat, repl) for child in stmt]
222
def unroll_stmt(self, stmt):
    """Replace applications of recorded functions by per-call unrolled symbols.

    Sub-expressions are processed first.  An application whose head is in
    unroll_decls is mapped to a fresh "|UNROLL#n|" symbol (cached per distinct
    argument tuple); the specialised declaration or definition is written to
    the solver the first time the application is seen.
    """
    if not isinstance(stmt, list):
        return stmt

    stmt = [self.unroll_stmt(child) for child in stmt]

    applies_decl = (
        len(stmt) >= 2
        and not isinstance(stmt[0], list)
        and stmt[0] in self.unroll_decls
    )
    if not applies_decl:
        return stmt

    assert stmt[1] in self.unroll_objs

    key = tuple(stmt)
    if key in self.unroll_cache:
        return self.unroll_cache[key]

    decl = deepcopy(self.unroll_decls[key[0]])

    fresh_name = "|UNROLL#%d|" % self.unroll_idcnt
    self.unroll_cache[key] = fresh_name
    decl[1] = fresh_name
    self.unroll_idcnt += 1

    if decl[0] == "declare-fun":
        keeps_decl = isinstance(decl[3], list) or decl[3] not in self.unroll_sorts
        self.unroll_objs.add(fresh_name)
        if keeps_decl:
            # Result is an ordinary value: declare it as a constant.
            decl[2] = list()
        else:
            # Result is itself an unrolled object: nothing to declare.
            decl = list()

    elif decl[0] == "define-fun":
        # Substitute the actual arguments for the formal parameters.
        for position, (arg_name, arg_sort) in enumerate(decl[2], start=1):
            decl[4] = self.replace_in_stmt(decl[4], arg_name, key[position])
        decl[2] = list()

    if len(decl) > 0:
        decl = self.unroll_stmt(decl)
        self.write(self.unparse(decl), unroll=False)

    return self.unroll_cache[key]
262
def p_thread_main(self):
    """Reader thread body: forward solver output lines to p_queue until EOF.

    An empty string is queued as end-of-stream marker, after which p_running
    is cleared.
    """
    for raw_line in iter(lambda: self.p.stdout.readline(), b""):
        self.p_queue.put(raw_line.decode("ascii"))
    self.p_queue.put("")
    self.p_running = False
270
def p_open(self):
    """Start the solver subprocess and the thread that reads its output."""
    assert self.p is None
    solver_proc = subprocess.Popen(
        self.popen_vargs,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
    )
    self.p = solver_proc
    # Register the process so the signal handler can terminate it.
    running_solvers[self.p_index] = solver_proc
    self.p_running = True
    self.p_next = None
    self.p_queue = Queue()
    reader = Thread(target=self.p_thread_main)
    self.p_thread = reader
    reader.start()
280
def p_write(self, data, flush):
    """Send data (ASCII text) to the solver's stdin, flushing if requested."""
    assert self.p is not None
    solver_stdin = self.p.stdin
    solver_stdin.write(bytes(data, "ascii"))
    if flush:
        solver_stdin.flush()
285
def p_read(self):
    """Return the next solver output line, blocking until one is available.

    A line previously fetched by p_poll() is handed out first.
    """
    assert self.p is not None
    assert self.p_running
    if self.p_next is None:
        return self.p_queue.get()
    pending, self.p_next = self.p_next, None
    return pending
294
def p_poll(self):
    """Wait up to 0.1 s for solver output.

    Returns True while no output is available, False once a line is ready; the
    line is stashed in p_next for the following p_read().
    """
    assert self.p is not None
    assert self.p_running
    if self.p_next is not None:
        return False
    try:
        self.p_next = self.p_queue.get(True, 0.1)
    except Empty:
        return True
    else:
        return False
305
def p_close(self):
    """Close the solver's stdin, wait for the reader thread and drop all process state."""
    assert self.p is not None
    self.p.stdin.close()
    # The reader thread finishes once the solver closes its output.
    self.p_thread.join()
    assert not self.p_running
    del running_solvers[self.p_index]
    self.p = self.p_next = self.p_queue = self.p_thread = None
316
def write(self, stmt, unroll=True):
    """Send one SMT2 statement (or comment line) towards the solver.

    Comment lines are first passed to info(); before setup they are only
    queued.  The first non-comment statement triggers setup().  With
    unrolling enabled, statements are buffered until their parentheses
    balance, sort/function declarations over unrolled sorts are recorded
    instead of emitted, and the rest is rewritten by unroll_stmt().  In
    non-incremental mode statements are cached for replay by check_sat().
    """
    if stmt.startswith(";"):
        self.info(stmt)
        if not self.setup_done:
            self.info_stmts.append(stmt)
            return
    elif not self.setup_done:
        self.setup()

    stmt = stmt.strip()

    if self.nocomments or self.unroll:
        stmt = re.sub(r" *;.*", "", stmt)
        if stmt == "":
            return

    if unroll and self.unroll:
        stmt = self.unroll_buffer + stmt
        self.unroll_buffer = ""

        # Count parentheses outside of |quoted| symbols only.
        unquoted = re.sub(r"\|[^|]*\|", "", stmt)
        if unquoted.count("(") != unquoted.count(")"):
            self.unroll_buffer = stmt + " "
            return

        sexpr = self.parse(stmt)

        if self.debug_print:
            print("-> %s" % sexpr)

        size = len(sexpr)

        if size == 3 and sexpr[0] == "declare-sort" and sexpr[2] == "0":
            self.unroll_sorts.add(sexpr[1])
            return

        if size == 4 and sexpr[0] == "declare-fun" and sexpr[2] == [] and sexpr[3] in self.unroll_sorts:
            self.unroll_objs.add(sexpr[1])
            return

        if size >= 4 and sexpr[0] == "declare-fun":
            if any(arg_sort in self.unroll_sorts for arg_sort in sexpr[2]):
                self.unroll_decls[sexpr[1]] = sexpr
                return

        elif size >= 4 and sexpr[0] == "define-fun":
            if any(arg_sort in self.unroll_sorts for arg_name, arg_sort in sexpr[2]):
                self.unroll_decls[sexpr[1]] = sexpr
                return

        stmt = self.unparse(self.unroll_stmt(sexpr))

        if stmt == "(push 1)":
            snapshot = (
                deepcopy(self.unroll_sorts),
                deepcopy(self.unroll_objs),
                deepcopy(self.unroll_decls),
                deepcopy(self.unroll_cache),
            )
            self.unroll_stack.append(snapshot)
        elif stmt == "(pop 1)":
            (self.unroll_sorts, self.unroll_objs,
             self.unroll_decls, self.unroll_cache) = self.unroll_stack.pop()

    if self.debug_print:
        print("> %s" % stmt)

    if self.debug_file:
        print(stmt, file=self.debug_file)
        self.debug_file.flush()

    if self.solver == "dummy":
        return

    if not self.noincr:
        self.p_write(stmt + "\n", True)
        return

    # Non-incremental mode: anything but a (get-...) query invalidates the
    # running solver; push/pop only manage the replay cache.
    if self.p is not None and not stmt.startswith("(get-"):
        self.p_close()
    if stmt == "(push 1)":
        self.smt2cache.append(list())
    elif stmt == "(pop 1)":
        self.smt2cache.pop()
    else:
        if self.p is not None:
            self.p_write(stmt + "\n", True)
        self.smt2cache[-1].append(stmt)
400
def info(self, stmt):
    """Record the metadata carried by a "; yosys-smt2-*" comment line.

    Other lines are ignored.  Depending on the tag, this adjusts the logic
    flags (only while no explicit logic is set), switches the current module,
    or stores nets, cells, memories, clocks, assertions, covers and
    any/all-const/seq entries in the current module's info record.
    """
    if not stmt.startswith("; yosys-smt2-"):
        return

    fields = stmt.split()
    tag = fields[1]

    # Tags that add a named net with a width, and the set that receives it.
    net_sets = {
        "yosys-smt2-input": "inputs",
        "yosys-smt2-output": "outputs",
        "yosys-smt2-register": "registers",
        "yosys-smt2-wire": "wires",
    }
    # Tags that add an any/all const/seq entry, and the table that receives it.
    any_tables = {
        "yosys-smt2-anyconst": "anyconsts",
        "yosys-smt2-anyseq": "anyseqs",
        "yosys-smt2-allconst": "allconsts",
        "yosys-smt2-allseq": "allseqs",
    }

    if tag == "yosys-smt2-nomem":
        if self.logic is None:
            self.logic_ax = False

    elif tag == "yosys-smt2-nobv":
        if self.logic is None:
            self.logic_bv = False

    elif tag == "yosys-smt2-stdt":
        if self.logic is None:
            self.logic_dt = True

    elif tag == "yosys-smt2-forall":
        if self.logic is None:
            self.logic_qf = False
        self.forall = True

    elif tag == "yosys-smt2-module":
        self.curmod = fields[2]
        self.modinfo[self.curmod] = SmtModInfo()

    elif tag == "yosys-smt2-cell":
        self.modinfo[self.curmod].cells[fields[3]] = fields[2]

    elif tag == "yosys-smt2-topmod":
        self.topmod = fields[2]

    elif tag in net_sets:
        mod = self.modinfo[self.curmod]
        getattr(mod, net_sets[tag]).add(fields[2])
        mod.wsize[fields[2]] = int(fields[3])

    elif tag == "yosys-smt2-memory":
        self.modinfo[self.curmod].memories[fields[2]] = (
            int(fields[3]), int(fields[4]), int(fields[5]), int(fields[6]),
            fields[7] == "async",
        )

    elif tag == "yosys-smt2-clock":
        for edge in fields[3:]:
            clocks = self.modinfo[self.curmod].clocks
            if fields[2] not in clocks:
                clocks[fields[2]] = edge
            elif clocks[fields[2]] != edge:
                # Conflicting edges for the same clock.
                clocks[fields[2]] = "event"

    elif tag == "yosys-smt2-assert":
        self.modinfo[self.curmod].asserts["%s_a %s" % (self.curmod, fields[2])] = fields[3]

    elif tag == "yosys-smt2-cover":
        self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3]

    elif tag in any_tables:
        mod = self.modinfo[self.curmod]
        getattr(mod, any_tables[tag])[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5])
        mod.asize[fields[2]] = int(fields[3])
481
def hiernets(self, top, regs_only=False):
    """List the hierarchical paths of all nets below module top.

    Each path is a list of cell names followed by the net name.  Nets of a
    module come first (sorted by name), then those of its cells (sorted by
    cell name).  With regs_only, only registers are listed.
    """
    def walk(mod, prefix):
        info = self.modinfo[mod]
        for netname in sorted(info.wsize.keys()):
            if regs_only and netname not in info.registers:
                continue
            yield prefix + [netname]
        for cellname, celltype in sorted(info.cells.items()):
            yield from walk(celltype, prefix + [cellname])

    return list(walk(top, []))
493
def hieranyconsts(self, top):
    """Collect the anyconst entries of module top and all cells below it.

    Returns tuples (cell path, name, first value field, second value field,
    width), ordered by name within a module, own entries before cells.
    """
    def walk(mod, cursor):
        info = self.modinfo[mod]
        for name, value in sorted(info.anyconsts.items()):
            yield (cursor, name, value[0], value[1], info.asize[name])
        for cellname, celltype in sorted(info.cells.items()):
            yield from walk(celltype, cursor + [cellname])

    return list(walk(top, []))
505
def hieranyseqs(self, top):
    """Collect the anyseq entries of module top and all cells below it.

    Returns tuples (cell path, name, first value field, second value field,
    width), ordered by name within a module, own entries before cells.
    """
    def walk(mod, cursor):
        info = self.modinfo[mod]
        for name, value in sorted(info.anyseqs.items()):
            yield (cursor, name, value[0], value[1], info.asize[name])
        for cellname, celltype in sorted(info.cells.items()):
            yield from walk(celltype, cursor + [cellname])

    return list(walk(top, []))
517
def hierallconsts(self, top):
    """Collect the allconst entries of module top and all cells below it.

    Returns tuples (cell path, name, first value field, second value field,
    width), ordered by name within a module, own entries before cells.
    """
    def walk(mod, cursor):
        info = self.modinfo[mod]
        for name, value in sorted(info.allconsts.items()):
            yield (cursor, name, value[0], value[1], info.asize[name])
        for cellname, celltype in sorted(info.cells.items()):
            yield from walk(celltype, cursor + [cellname])

    return list(walk(top, []))
529
def hierallseqs(self, top):
    """Collect the allseq entries of module top and all cells below it.

    Returns tuples (cell path, name, first value field, second value field,
    width), ordered by name within a module, own entries before cells.
    """
    def walk(mod, cursor):
        info = self.modinfo[mod]
        for name, value in sorted(info.allseqs.items()):
            yield (cursor, name, value[0], value[1], info.asize[name])
        for cellname, celltype in sorted(info.cells.items()):
            yield from walk(celltype, cursor + [cellname])

    return list(walk(top, []))
541
def hiermems(self, top):
    """List the hierarchical paths of all memories below module top.

    Each path is a list of cell names followed by the memory name; a module's
    own memories (sorted) precede those of its cells (sorted by cell name).
    """
    def walk(mod, prefix):
        info = self.modinfo[mod]
        for memname in sorted(info.memories.keys()):
            yield prefix + [memname]
        for cellname, celltype in sorted(info.cells.items()):
            yield from walk(celltype, prefix + [cellname])

    return list(walk(top, []))
552
def read(self):
    """Read one complete solver response and return it as a single string.

    Lines are read (from the dummy file or from the solver) until the
    parentheses balance; the stripped lines are concatenated.  Exits the
    program if the solver process reports a non-zero exit status while a
    response is incomplete, or if the response starts with "(error".
    """
    from_dummy = self.solver == "dummy"
    pieces = []
    depth = 0

    while True:
        if from_dummy:
            line = self.dummy_fd.readline().strip()
        else:
            line = self.p_read().strip()
            if self.dummy_file is not None:
                # Record the live response for later replay.
                self.dummy_fd.write(line + "\n")

        depth += line.count("(") - line.count(")")
        pieces.append(line)

        if self.debug_print:
            print("< %s" % line)
        if depth == 0:
            break
        if not from_dummy and self.p.poll():
            print("SMT Solver terminated unexpectedly: %s" % "".join(pieces), flush=True)
            sys.exit(1)

    response = "".join(pieces)
    if response.startswith("(error"):
        print("SMT Solver Error: %s" % response, file=sys.stderr, flush=True)
        if not from_dummy:
            self.p_close()
        sys.exit(1)

    return response
585
def check_sat(self):
    """Issue (check-sat), show a progress timer while waiting, return the answer.

    In non-incremental mode the solver is restarted and all cached statements
    are replayed first.  With timeinfo enabled, a spinner with the elapsed
    time is drawn on stderr while polling and erased afterwards.  The result
    is also logged to the debug file, if any.
    """
    if self.debug_print:
        print("> (check-sat)")
    if self.debug_file and not self.nocomments:
        print("; running check-sat..", file=self.debug_file)
        self.debug_file.flush()

    if self.solver != "dummy":
        if self.noincr:
            if self.p is not None:
                self.p_close()
            self.p_open()
            for cache_ctx in self.smt2cache:
                for cache_stmt in cache_ctx:
                    self.p_write(cache_stmt + "\n", False)

        self.p_write("(check-sat)\n", True)

        if self.timeinfo:
            i = 0
            # Spinner frames.  The backslash is escaped explicitly: "\|" is
            # an invalid escape sequence and triggers a warning on import.
            s = "/-\\|"

            count = 0
            num_bs = 0
            # Each p_poll() waits up to 0.1 s, so count is in tenths of a second.
            while self.p_poll():
                count += 1

                # Stay silent for the first 2.5 seconds.
                if count < 25:
                    continue

                if count % 10 == 0 or count == 25:
                    secs = count // 10

                    if secs < 60:
                        m = "(%d seconds)" % secs
                    elif secs < 60*60:
                        m = "(%d seconds -- %d:%02d)" % (secs, secs // 60, secs % 60)
                    else:
                        m = "(%d seconds -- %d:%02d:%02d)" % (secs, secs // (60*60), (secs // 60) % 60, secs % 60)

                    # Erase the previous message, then print the new one.
                    print("%s %s %c" % ("\b \b" * num_bs, m, s[i]), end="", file=sys.stderr)
                    num_bs = len(m) + 3

                else:
                    # Only advance the spinner character.
                    print("\b" + s[i], end="", file=sys.stderr)

                sys.stderr.flush()
                i = (i + 1) % len(s)

            if num_bs != 0:
                print("\b \b" * num_bs, end="", file=sys.stderr)
                sys.stderr.flush()

    result = self.read()
    if self.debug_file:
        print("(set-info :status %s)" % result, file=self.debug_file)
        print("(check-sat)", file=self.debug_file)
        self.debug_file.flush()
    return result
645
def parse(self, stmt):
    """Parse an SMT2 S-expression string into nested Python lists of strings.

    Parenthesised expressions become lists, |quoted symbols| are kept with
    their bars, everything else is split at whitespace, parentheses and bars.
    Whitespace directly before a closing parenthesis yields an empty-string
    element (kept for compatibility).  Raises IndexError on empty input or on
    an unterminated list or quoted symbol.

    The parser walks the input by index instead of slicing off the remaining
    text at every step, so the work is linear in the input length.
    """
    blanks = " \t\r\n"
    terminators = "()|" + blanks
    end = len(stmt)

    def worker(cursor):
        # Returns (parsed element, index just past it).
        while stmt[cursor] in blanks:
            cursor += 1

        head = stmt[cursor]

        if head == '(':
            expr = []
            cursor += 1
            while stmt[cursor] != ')':
                el, cursor = worker(cursor)
                expr.append(el)
            return expr, cursor + 1

        if head == '|':
            close = stmt.find('|', cursor + 1)
            if close < 0:
                raise IndexError("unterminated quoted symbol")
            return stmt[cursor:close + 1], close + 1

        # Plain atom: runs up to the next terminator or the end of the input.
        start = cursor
        while cursor < end and stmt[cursor] not in terminators:
            cursor += 1
        return stmt[start:cursor], cursor

    return worker(0)[0]
677
def unparse(self, stmt):
    """Turn a parsed S-expression (nested lists of strings) back into SMT2 text."""
    if not isinstance(stmt, list):
        return stmt
    return "(%s)" % " ".join(self.unparse(item) for item in stmt)
682
def bv2hex(self, v):
    """Convert an SMT2 value to lower-case hex digits (no prefix).

    The value is first converted with bv2bin(); bits are grouped in fours
    from the right, and a shorter leading group is treated as zero-padded.
    """
    bits = self.bv2bin(v)
    digits = []
    while len(bits) > 0:
        nibble = bits[-4:]
        value = sum(
            weight
            for weight, bit in zip((1, 2, 4, 8), reversed(nibble))
            if bit == "1"
        )
        digits.append(hex(value)[2:])
        bits = bits[:-4]
    return "".join(reversed(digits))
696
def bv2bin(self, v):
    """Convert an SMT2 value ("true", "false", "#b..." or "#x...") to a bit string."""
    if v == "true":
        return "1"
    if v == "false":
        return "0"

    prefix, payload = v[:2], v[2:]
    if prefix == "#b":
        return payload
    if prefix == "#x":
        # Expand every hex digit to its four bits.
        return "".join(hex_dict.get(digit) for digit in payload)

    assert False
705
def bv2int(self, v):
    """Convert an SMT2 value to an unsigned integer via its bit string."""
    bits = self.bv2bin(v)
    return int(bits, 2)
708
def get(self, expr):
    """Ask the solver for the value of one expression and return it as parsed."""
    self.write("(get-value (%s))" % expr)
    reply = self.parse(self.read())
    # The reply is a list of (expression, value) pairs; take the only value.
    return reply[0][1]
712
def get_list(self, expr_list):
    """Ask the solver for the values of several expressions in one query.

    Returns the values in reply order; an empty request yields an empty list
    without contacting the solver.
    """
    if len(expr_list) == 0:
        return []
    self.write("(get-value (%s))" % " ".join(expr_list))
    pairs = self.parse(self.read())
    return [pair[1] for pair in pairs]
718
def get_path(self, mod, path):
    """Split a dotted hierarchical name into cell names plus a final name.

    Because names may themselves contain dots, the shortest dotted prefix
    that names a cell of mod is taken as the first component and the rest is
    resolved in that cell's module.  If no prefix names a cell, the whole
    string is returned as a single component.
    """
    assert mod in self.modinfo
    parts = path.split(".")

    for cut in range(1, len(parts)):
        head = ".".join(parts[:cut])
        tail = ".".join(parts[cut:])

        cells = self.modinfo[mod].cells
        if head in cells:
            return [head] + self.get_path(cells[head], tail)

    return [".".join(parts)]
732
def net_expr(self, mod, base, path):
    """Build the SMT2 expression that selects the object at path from state base.

    Cell components are wrapped as "(|<mod>_h <cell>| ...)"; the final
    component is accessed with the _h, _n or _m accessor depending on whether
    it names a cell, a net or a memory.  An empty path, or a single empty
    component, yields base itself.
    """
    if len(path) == 0:
        return base

    head = path[0]

    if len(path) == 1:
        assert mod in self.modinfo
        if head == "":
            return base
        info = self.modinfo[mod]
        for suffix, table in (("h", info.cells), ("n", info.wsize), ("m", info.memories)):
            if head in table:
                return "(|%s_%s %s| %s)" % (mod, suffix, head, base)
        assert 0

    assert mod in self.modinfo
    assert head in self.modinfo[mod].cells

    # Descend into the cell and resolve the remaining path in its module.
    child_mod = self.modinfo[mod].cells[head]
    child_base = "(|%s_h %s| %s)" % (mod, head, base)
    return self.net_expr(child_mod, child_base, path[1:])
755
def net_width(self, mod, net_path):
    """Return the recorded width of the net at net_path, starting from module mod."""
    # Walk down through the cells named by all but the last component.
    for cellname in net_path[:-1]:
        assert mod in self.modinfo
        assert cellname in self.modinfo[mod].cells
        mod = self.modinfo[mod].cells[cellname]

    assert mod in self.modinfo
    netname = net_path[-1]
    assert netname in self.modinfo[mod].wsize
    return self.modinfo[mod].wsize[netname]
765
def net_clock(self, mod, net_path):
    """Return the clock edge recorded for the net at net_path, or None if it is no clock."""
    for cellname in net_path[:-1]:
        assert mod in self.modinfo
        assert cellname in self.modinfo[mod].cells
        mod = self.modinfo[mod].cells[cellname]

    assert mod in self.modinfo
    return self.modinfo[mod].clocks.get(net_path[-1])
776
def net_exists(self, mod, net_path):
    """Tell whether net_path names a known net, starting from module mod."""
    for cellname in net_path[:-1]:
        if mod not in self.modinfo or cellname not in self.modinfo[mod].cells:
            return False
        mod = self.modinfo[mod].cells[cellname]

    if mod not in self.modinfo:
        return False
    return net_path[-1] in self.modinfo[mod].wsize
786
def mem_exists(self, mod, mem_path):
    """Tell whether mem_path names a known memory, starting from module mod."""
    for cellname in mem_path[:-1]:
        if mod not in self.modinfo or cellname not in self.modinfo[mod].cells:
            return False
        mod = self.modinfo[mod].cells[cellname]

    if mod not in self.modinfo:
        return False
    return mem_path[-1] in self.modinfo[mod].memories
796
def mem_expr(self, mod, base, path, port=None, infomode=False):
    """Build the SMT2 expression for the memory at path in state base.

    With port set, the accessor name carries a ":<port>" suffix.  With
    infomode, the tuple recorded for the memory is returned instead of an
    expression.
    """
    assert mod in self.modinfo
    head = path[0]

    if len(path) == 1:
        memories = self.modinfo[mod].memories
        assert head in memories
        if infomode:
            return memories[head]
        port_suffix = "" if port is None else ":%s" % port
        return "(|%s_m%s %s| %s)" % (mod, port_suffix, head, base)

    assert head in self.modinfo[mod].cells

    # Descend into the cell and resolve the remaining path in its module.
    child_mod = self.modinfo[mod].cells[head]
    child_base = "(|%s_h %s| %s)" % (mod, head, base)
    return self.mem_expr(child_mod, child_base, path[1:], port=port, infomode=infomode)
811
def mem_info(self, mod, path):
    """Return the tuple recorded for the memory at path (mem_expr in info mode)."""
    no_base = ""
    return self.mem_expr(mod, no_base, path, infomode=True)
814
def get_net(self, mod_name, net_path, state_name):
    """Query the solver for the value of one net in the given state."""
    expr = self.net_expr(mod_name, state_name, net_path)
    return self.get(expr)
817
def get_net_list(self, mod_name, net_path_list, state_name):
    """Query the solver for the values of several nets in the given state."""
    exprs = [
        self.net_expr(mod_name, state_name, net_path)
        for net_path in net_path_list
    ]
    return self.get_list(exprs)
820
def get_net_hex(self, mod_name, net_path, state_name):
    """Query the solver for one net and return its value as hex digits."""
    value = self.get_net(mod_name, net_path, state_name)
    return self.bv2hex(value)
823
def get_net_hex_list(self, mod_name, net_path_list, state_name):
    """Query the solver for several nets and return their values as hex digits."""
    values = self.get_net_list(mod_name, net_path_list, state_name)
    return [self.bv2hex(value) for value in values]
826
def get_net_bin(self, mod_name, net_path, state_name):
    """Query the solver for one net and return its value as a bit string."""
    value = self.get_net(mod_name, net_path, state_name)
    return self.bv2bin(value)
829
def get_net_bin_list(self, mod_name, net_path_list, state_name):
    """Query the solver for several nets and return their values as bit strings."""
    values = self.get_net_list(mod_name, net_path_list, state_name)
    return [self.bv2bin(value) for value in values]
832
def wait(self):
    """Wait for a running solver process to exit, then release its resources."""
    if self.p is None:
        return
    self.p.wait()
    self.p_close()
837
838
class SmtOpts:
    """Command line options shared by the tools that drive an SMT solver.

    shortopts/longopts are option specifications in getopt form; handle()
    applies one parsed option to this object.
    """

    # Options whose argument is stored in an attribute.
    _VALUE_OPTS = {"-s": "solver", "--logic": "logic", "--dummy": "dummy_file"}
    # Options whose argument is appended to a list attribute.
    _APPEND_OPTS = {"-S": "solver_opts", "--info": "info_stmts"}
    # Options without argument that set an attribute to a fixed value.
    _FLAG_OPTS = {
        "-v": ("debug_print", True),
        "--unroll": ("unroll", True),
        "--noincr": ("noincr", True),
        "--noprogress": ("timeinfo", False),
        "--nocomments": ("nocomments", True),
    }

    def __init__(self):
        self.shortopts = "s:S:v"
        self.longopts = ["unroll", "noincr", "noprogress", "dump-smt2=", "logic=", "dummy=", "info=", "nocomments"]
        self.solver = "yices"
        self.solver_opts = []
        self.debug_print = False
        self.debug_file = None
        self.dummy_file = None
        self.unroll = False
        self.noincr = False
        self.timeinfo = os.name != "nt"
        self.logic = None
        self.info_stmts = []
        self.nocomments = False

    def handle(self, o, a):
        """Apply option o with argument a; return False if o is not recognised."""
        if o in self._VALUE_OPTS:
            setattr(self, self._VALUE_OPTS[o], a)
        elif o in self._APPEND_OPTS:
            getattr(self, self._APPEND_OPTS[o]).append(a)
        elif o in self._FLAG_OPTS:
            attr, value = self._FLAG_OPTS[o]
            setattr(self, attr, value)
        elif o == "--dump-smt2":
            self.debug_file = open(a, "w")
        else:
            return False
        return True

    def helpmsg(self):
        """Return the help text describing the options handled by this class."""
        return """
    -s <solver>
        set SMT solver: z3, yices, boolector, cvc4, mathsat, dummy
        default: yices

    -S <opt>
        pass <opt> as command line argument to the solver

    --logic <smt2_logic>
        use the specified SMT2 logic (e.g. QF_AUFBV)

    --dummy <filename>
        if solver is "dummy", read solver output from that file
        otherwise: write solver output to that file

    -v
        enable debug output

    --unroll
        unroll uninterpreted functions

    --noincr
        don't use incremental solving, instead restart solver for
        each (check-sat). This also avoids (push) and (pop).

    --noprogress
        disable timer display during solving
        (this option is set implicitly on Windows)

    --dump-smt2 <filename>
        write smt2 statements to file

    --info <smt2-info-stmt>
        include the specified smt2 info statement in the smt2 output

    --nocomments
        strip all comments from the generated smt2 code
"""
921
922
class MkVcd:
    """Writer that emits a VCD trace to an open text file.

    Nets and clocks are registered with add_net()/add_clock() before the
    first set_time() call, which writes the variable declarations.  After
    that, set_time() starts a time step and set_net() writes net values.
    """

    def __init__(self, f):
        self.f = f
        self.t = -1            # current step; -1 until the header is written
        self.nets = dict()     # path tuple -> (identifier, width)
        self.clocks = dict()   # path tuple -> (identifier, edge)

    def add_net(self, path, width):
        """Register a net of the given width; only allowed before the first step."""
        path = tuple(path)
        assert self.t == -1
        key = "n%d" % len(self.nets)
        self.nets[path] = (key, width)

    def add_clock(self, path, edge):
        """Register a 1-bit clock net with the given edge; only allowed before the first step."""
        path = tuple(path)
        assert self.t == -1
        key = "n%d" % len(self.nets)
        self.nets[path] = (key, 1)
        self.clocks[path] = (key, edge)

    def set_net(self, path, bits):
        """Write the value of a registered net for the current step.

        Values for clocks are ignored; their waveform is generated by set_time().
        """
        path = tuple(path)
        assert self.t >= 0
        assert path in self.nets
        if path not in self.clocks:
            print("b%s %s" % (bits, self.nets[path][0]), file=self.f)

    def escape_name(self, name):
        """Rewrite "[identifier]" as ".identifier" and backslash-escape names starting with a bracket."""
        name = re.sub(r"\[([a-zA-Z_][0-9a-zA-Z_]*)\]", r".\1", name)
        # Raw string: "\[" and "\]" are invalid escapes in a normal literal.
        # NOTE(review): re.match only tests the start of the name, so names
        # with brackets further in are not escaped; kept as-is, confirm intent.
        if re.match(r"[\[\]]", name) and name[0] != "\\":
            name = "\\" + name
        return name

    def set_time(self, t):
        """Advance the trace to step t (which must not go backwards).

        The first call writes the header: the step counter and step event,
        then every registered net inside nested $scope blocks that follow
        its path.  Each new step is written at time 10*t, with the inactive
        clock phase at 10*t-5 for steps after the first.
        """
        assert t >= self.t
        if t != self.t:
            if self.t == -1:
                print("$var integer 32 t smt_step $end", file=self.f)
                print("$var event 1 ! smt_clock $end", file=self.f)
                scope = []
                for path in sorted(self.nets):
                    parents = list(path[:-1])
                    # Leave scopes until the open ones are a prefix of this
                    # net's parent scopes.
                    while scope != parents[:len(scope)]:
                        print("$upscope $end", file=self.f)
                        scope.pop()
                    # Enter the parent scopes that are not open yet.
                    while len(scope) < len(parents):
                        print("$scope module %s $end" % parents[len(scope)], file=self.f)
                        scope.append(parents[len(scope)])
                    key, width = self.nets[path]
                    if path in self.clocks and self.clocks[path][1] == "event":
                        print("$var event 1 %s %s $end" % (key, self.escape_name(path[-1])), file=self.f)
                    else:
                        print("$var wire %d %s %s $end" % (width, key, self.escape_name(path[-1])), file=self.f)
                for i in range(len(scope)):
                    print("$upscope $end", file=self.f)
                print("$enddefinitions $end", file=self.f)

            self.t = t
            assert self.t >= 0

            if self.t > 0:
                # Half a step earlier: put edge-triggered clocks in their
                # inactive phase.
                print("#%d" % (10 * self.t - 5), file=self.f)
                for path in sorted(self.clocks.keys()):
                    if self.clocks[path][1] == "posedge":
                        print("b0 %s" % self.nets[path][0], file=self.f)
                    elif self.clocks[path][1] == "negedge":
                        print("b1 %s" % self.nets[path][0], file=self.f)

            print("#%d" % (10 * self.t), file=self.f)
            print("1!", file=self.f)
            print("b%s t" % format(self.t, "032b"), file=self.f)

            for path in sorted(self.clocks.keys()):
                if self.clocks[path][1] == "negedge":
                    print("b0 %s" % self.nets[path][0], file=self.f)
                else:
                    print("b1 %s" % self.nets[path][0], file=self.f)
999