Merge branch 'master' into clk2ff-better-names
[yosys.git] / backends / smt2 / smtio.py
1 #
2 # yosys -- Yosys Open SYnthesis Suite
3 #
4 # Copyright (C) 2012 Claire Xenia Wolf <claire@yosyshq.com>
5 #
6 # Permission to use, copy, modify, and/or distribute this software for any
7 # purpose with or without fee is hereby granted, provided that the above
8 # copyright notice and this permission notice appear in all copies.
9 #
10 # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13 # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15 # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16 # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 #
18
19 import sys, re, os, signal
20 import subprocess
21 if os.name == "posix":
22 import resource
23 from copy import deepcopy
24 from select import select
25 from time import time
26 from queue import Queue, Empty
27 from threading import Thread
28
29
# The recursive SMT2 S-expression parser needs many nested Python frames
# (and therefore a large C stack) on big expressions, so raise both the
# interpreter recursion limit and the soft stack rlimit up front.
if os.name == "posix":
    smtio_reclimit = 64 * 1024
    if smtio_reclimit > sys.getrecursionlimit():
        sys.setrecursionlimit(smtio_reclimit)

    current_rlimit_stack = resource.getrlimit(resource.RLIMIT_STACK)
    # Nothing to do when the soft stack limit is already unlimited.
    if current_rlimit_stack[0] != resource.RLIM_INFINITY:
        # 128 MiB normally; macOS has rather conservative stack limits, so
        # only 8 MiB is requested there.
        smtio_stacksize = (8 if os.uname().sysname == "Darwin" else 128) * 1024 * 1024
        # Never request more than the hard limit permits.
        if current_rlimit_stack[1] != resource.RLIM_INFINITY:
            smtio_stacksize = min(smtio_stacksize, current_rlimit_stack[1])
        if smtio_stacksize > current_rlimit_stack[0]:
            try:
                resource.setrlimit(resource.RLIMIT_STACK, (smtio_stacksize, current_rlimit_stack[1]))
            except ValueError:
                # Could not get more stack; carry on with what we have.
                pass
51
52
# Currently running solver processes, keyed by SmtIo.p_index, so they can be
# terminated when this process shuts down.
running_solvers = dict()
# Set by the first call to force_shutdown(); later calls only exit.
forced_shutdown = False
# Next p_index handed out by SmtIo.__init__.
solvers_index = 0


def force_shutdown(signum, frame):
    """Send SIGTERM to every running solver, then exit with status 1.

    Installed as the SIGHUP/SIGINT/SIGTERM handler on POSIX and called from
    except_hook(); `signum` is None when not invoked for a signal. Only the
    first call signals the solvers; every call ends in sys.exit(1).
    """
    global forced_shutdown
    if not forced_shutdown:
        forced_shutdown = True
        if signum is not None:
            print("<%s>" % signal.Signals(signum).name)
        # Iterate over a snapshot so the loop never sees the dict change size.
        for p in list(running_solvers.values()):
            # A solver that already exited (and was reaped) no longer owns its
            # pid, which may have been reused by an unrelated process.
            if p.poll() is not None:
                continue
            try:
                # Only the solver process itself is signalled, not its group.
                os.kill(p.pid, signal.SIGTERM)
            except ProcessLookupError:
                # It exited between poll() and kill(); keep terminating the rest.
                pass
    sys.exit(1)


if os.name == "posix":
    signal.signal(signal.SIGHUP, force_shutdown)
    signal.signal(signal.SIGINT, force_shutdown)
    signal.signal(signal.SIGTERM, force_shutdown)


def except_hook(exctype, value, traceback):
    """sys.excepthook replacement: report the exception unless a shutdown is
    already in progress, then terminate the solvers and exit."""
    if not forced_shutdown:
        sys.__excepthook__(exctype, value, traceback)
    force_shutdown(None, None)


sys.excepthook = except_hook
80
81
# Each hexadecimal digit (upper- or lower-case) mapped to its 4-bit binary
# spelling; used to expand "#x..." bit-vector literals into bit strings.
hex_dict = {digit: format(int(digit, 16), "04b") for digit in "0123456789ABCDEFabcdef"}
90
91
class SmtModInfo:
    """Per-module metadata collected from "; yosys-smt2-*" annotation comments."""

    def __init__(self):
        # Net names, grouped by role.
        self.inputs, self.outputs = set(), set()
        self.registers, self.wires = set(), set()
        # Memory name -> (four integers from the annotation, "async" flag).
        self.memories = {}
        # Net name -> bit width, for inputs, outputs, registers and wires.
        self.wsize = {}
        # Clock net name -> edge keyword, or "event" when differing edges are seen.
        self.clocks = {}
        # Cell instance name -> module type of that instance.
        self.cells = {}
        # "<module>_a <id>" / "<module>_c <id>" -> remaining annotation field.
        self.asserts, self.covers = {}, {}
        # Names of objectives to maximize / minimize.
        self.maximize, self.minimize = set(), set()
        # Symbol name -> (annotation field 4, optional field 5 or None).
        self.anyconsts, self.anyseqs = {}, {}
        self.allconsts, self.allseqs = {}, {}
        # Width of each any/all const/seq symbol.
        self.asize = {}
111
112
class SmtIo:
    """Talk SMT-LIB2 to an external solver process, or replay a dummy file.

    Statements are sent with write(). Annotation comments of the form
    "; yosys-smt2-*" are parsed by info() into per-module SmtModInfo records
    stored in self.modinfo. setup() runs on the first non-comment statement
    and starts the solver (except in noincr mode, where check_sat() starts a
    fresh solver each time and replays the cached statements into it).

    With `unroll`, applications of functions declared over uninterpreted
    sorts are replaced by fresh "|UNROLL#n|" symbols.
    """

    def __init__(self, opts=None):
        """Initialise settings from `opts` (an SmtOpts-like object) or from
        built-in defaults when `opts` is None. No process is started here."""
        global solvers_index

        # Logic selection flags; setup() derives self.logic from them unless a
        # logic was given explicitly.
        self.logic = None
        self.logic_qf = True
        self.logic_ax = True
        self.logic_uf = True
        self.logic_bv = True
        self.logic_dt = False
        self.forall = False
        self.timeout = 0
        self.produce_models = True
        # noincr mode: statements per push level, replayed by check_sat().
        self.smt2cache = [list()]
        # Options from "yosys-smt2-solver-option" annotations.
        self.smt2_options = dict()
        self.p = None
        # Key of this instance's process in running_solvers.
        self.p_index = solvers_index
        solvers_index += 1

        if opts is not None:
            self.logic = opts.logic
            self.solver = opts.solver
            self.solver_opts = opts.solver_opts
            self.debug_print = opts.debug_print
            self.debug_file = opts.debug_file
            self.dummy_file = opts.dummy_file
            self.timeinfo = opts.timeinfo
            self.timeout = opts.timeout
            self.unroll = opts.unroll
            self.noincr = opts.noincr
            # Copy: write() appends buffered annotation comments to this list,
            # which must not leak back into the caller's options object.
            self.info_stmts = list(opts.info_stmts)
            self.nocomments = opts.nocomments

        else:
            self.solver = "yices"
            self.solver_opts = list()
            self.debug_print = False
            self.debug_file = None
            self.dummy_file = None
            self.timeinfo = os.name != "nt"
            self.timeout = 0
            self.unroll = False
            self.noincr = False
            self.info_stmts = list()
            self.nocomments = False

        self.start_time = time()

        self.modinfo = dict()
        self.curmod = None
        self.topmod = None
        self.setup_done = False

    def __del__(self):
        """Terminate a still-running solver when this object is collected."""
        if self.p is not None and not forced_shutdown:
            # NOTE(review): p_open() does not start the solver in a new process
            # group, so unless this process was itself put in its own group by
            # the caller, this signals every process sharing our group
            # (including ourselves) -- confirm that is intended.
            os.killpg(os.getpgid(self.p.pid), signal.SIGTERM)
            # Presumably guards against module globals already being cleared
            # at interpreter shutdown.
            if running_solvers is not None:
                del running_solvers[self.p_index]

    def setup(self):
        """Build the solver command line and start the solver (or open the
        dummy file), derive the SMT-LIB logic, and emit the preamble.

        The preamble is: buffered info statements, produce-models, mode-start
        options, set-logic, and the remaining options. Runs at most once;
        write() calls it automatically.
        """
        assert not self.setup_done

        if self.forall:
            self.unroll = False

        # Per-solver command line. Solvers without timeout support exit with
        # an error when a timeout was requested.
        if self.solver == "yices":
            if self.noincr or self.forall:
                self.popen_vargs = ['yices-smt2'] + self.solver_opts
            else:
                self.popen_vargs = ['yices-smt2', '--incremental'] + self.solver_opts
            if self.timeout != 0:
                self.popen_vargs.append('-t')
                self.popen_vargs.append('%d' % self.timeout)

        if self.solver == "z3":
            self.popen_vargs = ['z3', '-smt2', '-in'] + self.solver_opts
            if self.timeout != 0:
                self.popen_vargs.append('-T:%d' % self.timeout)

        if self.solver == "cvc4":
            if self.noincr:
                self.popen_vargs = ['cvc4', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts
            else:
                self.popen_vargs = ['cvc4', '--incremental', '--lang', 'smt2.6' if self.logic_dt else 'smt2'] + self.solver_opts
            if self.timeout != 0:
                self.popen_vargs.append('--tlimit=%d000' % self.timeout)

        if self.solver == "mathsat":
            self.popen_vargs = ['mathsat'] + self.solver_opts
            if self.timeout != 0:
                print('timeout option is not supported for mathsat.')
                sys.exit(1)

        if self.solver in ["boolector", "bitwuzla"]:
            if self.noincr:
                self.popen_vargs = [self.solver, '--smt2'] + self.solver_opts
            else:
                self.popen_vargs = [self.solver, '--smt2', '-i'] + self.solver_opts
            self.unroll = True
            if self.timeout != 0:
                print('timeout option is not supported for %s.' % self.solver)
                sys.exit(1)

        if self.solver == "abc":
            if len(self.solver_opts) > 0:
                self.popen_vargs = ['yosys-abc', '-S', '; '.join(self.solver_opts)]
            else:
                self.popen_vargs = ['yosys-abc', '-S', '%blast; &sweep -C 5000; &syn4; &cec -s -m -C 2000']
            self.logic_ax = False
            self.unroll = True
            self.noincr = True
            if self.timeout != 0:
                print('timeout option is not supported for abc.')
                sys.exit(1)

        if self.solver == "dummy":
            assert self.dummy_file is not None
            self.dummy_fd = open(self.dummy_file, "r")
        else:
            # With a real solver the dummy file records the solver's output.
            if self.dummy_file is not None:
                self.dummy_fd = open(self.dummy_file, "w")
            if not self.noincr:
                self.p_open()

        if self.unroll:
            assert not self.forall
            self.logic_uf = False
            self.unroll_idcnt = 0
            self.unroll_buffer = ""
            self.unroll_sorts = set()
            self.unroll_objs = set()
            self.unroll_decls = dict()
            self.unroll_cache = dict()
            self.unroll_stack = list()

        if self.logic is None:
            self.logic = ""
            if self.logic_qf: self.logic += "QF_"
            if self.logic_ax: self.logic += "A"
            if self.logic_uf: self.logic += "UF"
            if self.logic_bv: self.logic += "BV"
            if self.logic_dt: self.logic = "ALL"
            if self.solver == "yices" and self.forall: self.logic = "BV"

        self.setup_done = True

        for stmt in self.info_stmts:
            self.write(stmt)

        if self.produce_models:
            self.write("(set-option :produce-models true)")

        # See the SMT-LIB Standard, Section 4.1.7: these options may only be
        # set before set-logic.
        modestart_options = [":global-declarations", ":interactive-mode", ":produce-assertions", ":produce-assignments", ":produce-models", ":produce-proofs", ":produce-unsat-assumptions", ":produce-unsat-cores", ":random-seed"]
        for key, val in self.smt2_options.items():
            if key in modestart_options:
                self.write("(set-option {} {})".format(key, val))

        self.write("(set-logic %s)" % self.logic)

        if self.forall and self.solver == "yices":
            self.write("(set-option :yices-ef-max-iters 1000000000)")

        for key, val in self.smt2_options.items():
            if key not in modestart_options:
                self.write("(set-option {} {})".format(key, val))

    def timestamp(self):
        """Return the time elapsed since construction as a '## H:MM:SS ' log prefix."""
        secs = int(time() - self.start_time)
        return "## %3d:%02d:%02d " % (secs // (60*60), (secs // 60) % 60, secs % 60)

    def replace_in_stmt(self, stmt, pat, repl):
        """Return `stmt` with every (sub)expression equal to `pat` replaced by
        `repl`. Lists are rebuilt; the input is not modified."""
        if stmt == pat:
            return repl

        if isinstance(stmt, list):
            return [self.replace_in_stmt(s, pat, repl) for s in stmt]

        return stmt

    def unroll_stmt(self, stmt):
        """Replace applications of functions recorded in unroll_decls with
        cached "|UNROLL#n|" symbols.

        For each new application, the instantiated declaration or definition
        is emitted first.
        """
        if not isinstance(stmt, list):
            return stmt

        stmt = [self.unroll_stmt(s) for s in stmt]

        if len(stmt) >= 2 and not isinstance(stmt[0], list) and stmt[0] in self.unroll_decls:
            assert stmt[1] in self.unroll_objs

            key = tuple(stmt)
            if key not in self.unroll_cache:
                decl = deepcopy(self.unroll_decls[key[0]])

                self.unroll_cache[key] = "|UNROLL#%d|" % self.unroll_idcnt
                decl[1] = self.unroll_cache[key]
                self.unroll_idcnt += 1

                if decl[0] == "declare-fun":
                    if isinstance(decl[3], list) or decl[3] not in self.unroll_sorts:
                        self.unroll_objs.add(decl[1])
                        decl[2] = list()
                    else:
                        # Result is itself an unrolled sort: nothing to declare.
                        self.unroll_objs.add(decl[1])
                        decl = list()

                elif decl[0] == "define-fun":
                    # Substitute the actual arguments into the body.
                    for arg_index, (arg_name, arg_sort) in enumerate(decl[2], 1):
                        decl[4] = self.replace_in_stmt(decl[4], arg_name, key[arg_index])
                    decl[2] = list()

                if len(decl) > 0:
                    decl = self.unroll_stmt(decl)
                    self.write(self.unparse(decl), unroll=False)

            return self.unroll_cache[key]

        return stmt

    def p_thread_main(self):
        """Reader thread: forward solver stdout lines into p_queue.

        An empty string is queued to mark EOF; p_running is cleared only
        after that.
        """
        try:
            while True:
                data = self.p.stdout.readline().decode("ascii")
                if data == "": break
                self.p_queue.put(data)
        finally:
            # Always publish EOF, even if reading failed, so the consumer in
            # p_read()/p_poll() can never wait forever on a dead thread.
            self.p_queue.put("")
            self.p_running = False

    def p_open(self):
        """Start the solver (stderr merged into stdout) and its reader thread.

        Exits if the executable is not found.
        """
        assert self.p is None
        try:
            self.p = subprocess.Popen(self.popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        except FileNotFoundError:
            print("%s SMT Solver '%s' not found in path." % (self.timestamp(), self.popen_vargs[0]), flush=True)
            sys.exit(1)
        running_solvers[self.p_index] = self.p
        self.p_running = True
        self.p_next = None
        self.p_queue = Queue()
        self.p_thread = Thread(target=self.p_thread_main)
        self.p_thread.start()

    def p_write(self, data, flush):
        """Send ASCII text `data` to the solver's stdin, flushing if requested."""
        assert self.p is not None
        self.p.stdin.write(bytes(data, "ascii"))
        if flush: self.p.stdin.flush()

    def p_read(self):
        """Return the next line of solver output, or "" once output has ended.

        Lines already queued by the reader thread are still delivered after
        the solver has exited.
        """
        assert self.p is not None
        if self.p_next is not None:
            data = self.p_next
            self.p_next = None
            return data
        while True:
            # Sample the flag before touching the queue: the reader thread
            # clears it only after enqueueing its last item, so once it reads
            # False everything the solver ever wrote is already queued.
            running = self.p_running
            try:
                if running:
                    return self.p_queue.get(True, 0.1)
                return self.p_queue.get_nowait()
            except Empty:
                if not running:
                    return ""

    def p_poll(self, timeout=0.1):
        """Wait up to `timeout` seconds for solver output.

        Returns True if nothing arrived and the solver is still being read
        (keep waiting). Returns False when a line is ready for p_read() or
        when the output has ended.
        """
        assert self.p is not None
        if self.p_next is not None:
            return False
        running = self.p_running
        try:
            if running:
                self.p_next = self.p_queue.get(True, timeout)
            else:
                self.p_next = self.p_queue.get_nowait()
            return False
        except Empty:
            # After EOF stop waiting so the caller's read() reports it.
            return running

    def p_close(self):
        """Close the solver's stdin, wait for the reader thread, and forget the process."""
        assert self.p is not None
        self.p.stdin.close()
        self.p_thread.join()
        assert not self.p_running
        del running_solvers[self.p_index]
        self.p = None
        self.p_next = None
        self.p_queue = None
        self.p_thread = None

    def write(self, stmt, unroll=True):
        """Send one SMT-LIB2 statement.

        Comment statements (leading ';') go to info(). Before setup they are
        also buffered for replay; the first other statement triggers setup().

        With unrolling enabled (and `unroll` True), partial statements are
        accumulated until their parentheses balance. Declarations over
        uninterpreted sorts are held back for unroll_stmt().

        In noincr mode statements are cached per push level for replay by
        check_sat().
        """
        if stmt.startswith(";"):
            self.info(stmt)
            if not self.setup_done:
                self.info_stmts.append(stmt)
                return
        elif not self.setup_done:
            self.setup()

        stmt = stmt.strip()

        if self.nocomments or self.unroll:
            stmt = re.sub(r" *;.*", "", stmt)
            if stmt == "": return

        if unroll and self.unroll:
            stmt = self.unroll_buffer + stmt
            self.unroll_buffer = ""

            # Count parentheses outside |quoted| symbols only.
            s = re.sub(r"\|[^|]*\|", "", stmt)
            if s.count("(") != s.count(")"):
                self.unroll_buffer = stmt + " "
                return

            s = self.parse(stmt)

            if self.debug_print:
                print("-> %s" % s)

            if len(s) == 3 and s[0] == "declare-sort" and s[2] == "0":
                self.unroll_sorts.add(s[1])
                return

            elif len(s) == 4 and s[0] == "declare-fun" and s[2] == [] and s[3] in self.unroll_sorts:
                self.unroll_objs.add(s[1])
                return

            elif len(s) >= 4 and s[0] == "declare-fun":
                for arg_sort in s[2]:
                    if arg_sort in self.unroll_sorts:
                        self.unroll_decls[s[1]] = s
                        return

            elif len(s) >= 4 and s[0] == "define-fun":
                for arg_name, arg_sort in s[2]:
                    if arg_sort in self.unroll_sorts:
                        self.unroll_decls[s[1]] = s
                        return

            stmt = self.unparse(self.unroll_stmt(s))

            # Unrolling state is scoped like the solver's assertion stack.
            if stmt == "(push 1)":
                self.unroll_stack.append((
                    deepcopy(self.unroll_sorts),
                    deepcopy(self.unroll_objs),
                    deepcopy(self.unroll_decls),
                    deepcopy(self.unroll_cache),
                ))

            if stmt == "(pop 1)":
                self.unroll_sorts, self.unroll_objs, self.unroll_decls, self.unroll_cache = self.unroll_stack.pop()

        if self.debug_print:
            print("> %s" % stmt)

        if self.debug_file:
            print(stmt, file=self.debug_file)
            self.debug_file.flush()

        if self.solver != "dummy":
            if self.noincr:
                # Any non-query statement invalidates the running solver; the
                # next check_sat() restarts it from the cache.
                if self.p is not None and not stmt.startswith("(get-"):
                    self.p_close()
                if stmt == "(push 1)":
                    self.smt2cache.append(list())
                elif stmt == "(pop 1)":
                    self.smt2cache.pop()
                else:
                    if self.p is not None:
                        self.p_write(stmt + "\n", True)
                    self.smt2cache[-1].append(stmt)
            else:
                self.p_write(stmt + "\n", True)

    def info(self, stmt):
        """Record a "; yosys-smt2-*" annotation: logic flags, solver options,
        and the current module's SmtModInfo. Other comments are ignored."""
        if not stmt.startswith("; yosys-smt2-"):
            return

        fields = stmt.split()

        if fields[1] == "yosys-smt2-solver-option":
            self.smt2_options[fields[2]] = fields[3]

        elif fields[1] == "yosys-smt2-nomem":
            if self.logic is None:
                self.logic_ax = False

        elif fields[1] == "yosys-smt2-nobv":
            if self.logic is None:
                self.logic_bv = False

        elif fields[1] == "yosys-smt2-stdt":
            if self.logic is None:
                self.logic_dt = True

        elif fields[1] == "yosys-smt2-forall":
            if self.logic is None:
                self.logic_qf = False
            self.forall = True

        elif fields[1] == "yosys-smt2-module":
            self.curmod = fields[2]
            self.modinfo[self.curmod] = SmtModInfo()

        elif fields[1] == "yosys-smt2-cell":
            self.modinfo[self.curmod].cells[fields[3]] = fields[2]

        elif fields[1] == "yosys-smt2-topmod":
            self.topmod = fields[2]

        elif fields[1] == "yosys-smt2-input":
            self.modinfo[self.curmod].inputs.add(fields[2])
            self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3])

        elif fields[1] == "yosys-smt2-output":
            self.modinfo[self.curmod].outputs.add(fields[2])
            self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3])

        elif fields[1] == "yosys-smt2-register":
            self.modinfo[self.curmod].registers.add(fields[2])
            self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3])

        elif fields[1] == "yosys-smt2-memory":
            self.modinfo[self.curmod].memories[fields[2]] = (int(fields[3]), int(fields[4]), int(fields[5]), int(fields[6]), fields[7] == "async")

        elif fields[1] == "yosys-smt2-wire":
            self.modinfo[self.curmod].wires.add(fields[2])
            self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3])

        elif fields[1] == "yosys-smt2-clock":
            # A clock seen with differing edges degrades to "event".
            for edge in fields[3:]:
                if fields[2] not in self.modinfo[self.curmod].clocks:
                    self.modinfo[self.curmod].clocks[fields[2]] = edge
                elif self.modinfo[self.curmod].clocks[fields[2]] != edge:
                    self.modinfo[self.curmod].clocks[fields[2]] = "event"

        elif fields[1] == "yosys-smt2-assert":
            self.modinfo[self.curmod].asserts["%s_a %s" % (self.curmod, fields[2])] = fields[3]

        elif fields[1] == "yosys-smt2-cover":
            self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3]

        elif fields[1] == "yosys-smt2-maximize":
            self.modinfo[self.curmod].maximize.add(fields[2])

        elif fields[1] == "yosys-smt2-minimize":
            self.modinfo[self.curmod].minimize.add(fields[2])

        elif fields[1] == "yosys-smt2-anyconst":
            self.modinfo[self.curmod].anyconsts[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5])
            self.modinfo[self.curmod].asize[fields[2]] = int(fields[3])

        elif fields[1] == "yosys-smt2-anyseq":
            self.modinfo[self.curmod].anyseqs[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5])
            self.modinfo[self.curmod].asize[fields[2]] = int(fields[3])

        elif fields[1] == "yosys-smt2-allconst":
            self.modinfo[self.curmod].allconsts[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5])
            self.modinfo[self.curmod].asize[fields[2]] = int(fields[3])

        elif fields[1] == "yosys-smt2-allseq":
            self.modinfo[self.curmod].allseqs[fields[2]] = (fields[4], None if len(fields) <= 5 else fields[5])
            self.modinfo[self.curmod].asize[fields[2]] = int(fields[3])

    def hiernets(self, top, regs_only=False):
        """Return net paths ([cell, ..., net]) below module `top`, sorted by
        name at each level.

        Only registers are included when `regs_only` is set.
        """
        def hiernets_worker(nets, mod, cursor):
            for netname in sorted(self.modinfo[mod].wsize.keys()):
                if not regs_only or netname in self.modinfo[mod].registers:
                    nets.append(cursor + [netname])
            for cellname, celltype in sorted(self.modinfo[mod].cells.items()):
                hiernets_worker(nets, celltype, cursor + [cellname])

        nets = list()
        hiernets_worker(nets, top, [])
        return nets

    def _hier_symbols(self, top, attr):
        """Collect tuples for the SmtModInfo dict attribute `attr` across the
        hierarchy below `top`.

        Each tuple is (cell path, name, value[0], value[1], width), sorted by
        name at each level.
        """
        def worker(results, mod, cursor):
            for name, value in sorted(getattr(self.modinfo[mod], attr).items()):
                width = self.modinfo[mod].asize[name]
                results.append((cursor, name, value[0], value[1], width))
            for cellname, celltype in sorted(self.modinfo[mod].cells.items()):
                worker(results, celltype, cursor + [cellname])

        results = list()
        worker(results, top, [])
        return results

    def hieranyconsts(self, top):
        """Hierarchical list of anyconst symbols below `top` (see _hier_symbols)."""
        return self._hier_symbols(top, "anyconsts")

    def hieranyseqs(self, top):
        """Hierarchical list of anyseq symbols below `top` (see _hier_symbols)."""
        return self._hier_symbols(top, "anyseqs")

    def hierallconsts(self, top):
        """Hierarchical list of allconst symbols below `top` (see _hier_symbols)."""
        return self._hier_symbols(top, "allconsts")

    def hierallseqs(self, top):
        """Hierarchical list of allseq symbols below `top` (see _hier_symbols)."""
        return self._hier_symbols(top, "allseqs")

    def hiermems(self, top):
        """Return memory paths ([cell, ..., memory]) below module `top`."""
        def hiermems_worker(mems, mod, cursor):
            for memname in sorted(self.modinfo[mod].memories.keys()):
                mems.append(cursor + [memname])
            for cellname, celltype in sorted(self.modinfo[mod].cells.items()):
                hiermems_worker(mems, celltype, cursor + [cellname])

        mems = list()
        hiermems_worker(mems, top, [])
        return mems

    def read(self):
        """Read one complete response and return it with its lines joined.

        The response comes from the solver or the dummy file and may span
        several lines; reading stops once the parentheses balance.

        Exits on an "(error" response, or when the solver process has
        terminated while a response is still incomplete.
        """
        stmt = []
        count_brackets = 0

        while True:
            if self.solver == "dummy":
                line = self.dummy_fd.readline().strip()
            else:
                line = self.p_read().strip()
                if self.dummy_file is not None:
                    self.dummy_fd.write(line + "\n")

            count_brackets += line.count("(")
            count_brackets -= line.count(")")
            stmt.append(line)

            if self.debug_print:
                print("< %s" % line)
            if count_brackets == 0:
                break
            # poll() returns the exit status once the process has ended; a
            # clean exit (status 0) must count as termination too, otherwise
            # this loop would spin forever on "" lines.
            if self.solver != "dummy" and self.p.poll() is not None:
                print("%s Solver terminated unexpectedly: %s" % (self.timestamp(), "".join(stmt)), flush=True)
                sys.exit(1)

        stmt = "".join(stmt)
        if stmt.startswith("(error"):
            print("%s Solver Error: %s" % (self.timestamp(), stmt), flush=True)
            if self.solver != "dummy":
                self.p_close()
            sys.exit(1)

        return stmt

    def check_sat(self, expected=("sat", "unsat", "unknown", "timeout", "interrupted")):
        """Issue (check-sat) and return the solver's answer.

        In noincr mode the solver is restarted and the cached statements are
        replayed first. While waiting, progress is shown on stderr
        (`timeinfo`) or logged every few minutes.

        Exits if the answer is not in `expected`.
        """
        if self.debug_print:
            print("> (check-sat)")
        if self.debug_file and not self.nocomments:
            print("; running check-sat..", file=self.debug_file)
            self.debug_file.flush()

        if self.solver != "dummy":
            if self.noincr:
                if self.p is not None:
                    self.p_close()
                self.p_open()
                for cache_ctx in self.smt2cache:
                    for cache_stmt in cache_ctx:
                        self.p_write(cache_stmt + "\n", False)

            self.p_write("(check-sat)\n", True)

            if self.timeinfo:
                i = 0
                s = "/-\\|"

                # p_poll() waits 0.1s per call, so `count` is tenths of a second.
                count = 0
                num_bs = 0
                while self.p_poll():
                    count += 1

                    if count < 25:
                        continue

                    if count % 10 == 0 or count == 25:
                        secs = count // 10

                        if secs < 60:
                            m = "(%d seconds)" % secs
                        elif secs < 60*60:
                            m = "(%d seconds -- %d:%02d)" % (secs, secs // 60, secs % 60)
                        else:
                            m = "(%d seconds -- %d:%02d:%02d)" % (secs, secs // (60*60), (secs // 60) % 60, secs % 60)

                        # Erase the previous message with backspaces first.
                        print("%s %s %c" % ("\b \b" * num_bs, m, s[i]), end="", file=sys.stderr)
                        num_bs = len(m) + 3

                    else:
                        print("\b" + s[i], end="", file=sys.stderr)

                    sys.stderr.flush()
                    i = (i + 1) % len(s)

                if num_bs != 0:
                    print("\b \b" * num_bs, end="", file=sys.stderr)
                    sys.stderr.flush()

            else:
                # p_poll(60) waits a minute per call, so `count` is minutes.
                count = 0
                while self.p_poll(60):
                    count += 1
                    msg = None

                    if count == 1:
                        msg = "1 minute"

                    elif count in [5, 10, 15, 30]:
                        msg = "%d minutes" % count

                    elif count == 60:
                        msg = "1 hour"

                    elif count % 60 == 0:
                        msg = "%d hours" % (count // 60)

                    if msg is not None:
                        print("%s waiting for solver (%s)" % (self.timestamp(), msg), flush=True)

        if self.forall:
            # Echo any informational output the solver prints before its answer.
            result = self.read()
            while result not in ["sat", "unsat", "unknown", "timeout", "interrupted", ""]:
                print("%s %s: %s" % (self.timestamp(), self.solver, result))
                result = self.read()
        else:
            result = self.read()

        if self.debug_file:
            print("(set-info :status %s)" % result, file=self.debug_file)
            print("(check-sat)", file=self.debug_file)
            self.debug_file.flush()

        if result not in expected:
            if result == "":
                print("%s Unexpected EOF response from solver." % (self.timestamp()), flush=True)
            else:
                print("%s Unexpected response from solver: %s" % (self.timestamp(), result), flush=True)
            if self.solver != "dummy":
                self.p_close()
            sys.exit(1)

        return result

    def parse(self, stmt):
        """Parse one S-expression string into nested lists of token strings.

        "|...|" quoted symbols are kept verbatim, bars included. The parser
        is recursive; deep expressions rely on the raised recursion limit set
        at import time.
        """
        def worker(stmt):
            if stmt[0] == '(':
                expr = []
                cursor = 1
                while stmt[cursor] != ')':
                    el, le = worker(stmt[cursor:])
                    expr.append(el)
                    cursor += le
                return expr, cursor+1

            if stmt[0] == '|':
                expr = "|"
                cursor = 1
                while stmt[cursor] != '|':
                    expr += stmt[cursor]
                    cursor += 1
                expr += "|"
                return expr, cursor+1

            if stmt[0] in [" ", "\t", "\r", "\n"]:
                el, le = worker(stmt[1:])
                return el, le+1

            expr = ""
            cursor = 0
            while stmt[cursor] not in ["(", ")", "|", " ", "\t", "\r", "\n"]:
                expr += stmt[cursor]
                cursor += 1
            return expr, cursor
        return worker(stmt)[0]

    def unparse(self, stmt):
        """Render a parsed S-expression back to text (inverse of parse)."""
        if isinstance(stmt, list):
            return "(" + " ".join([self.unparse(s) for s in stmt]) + ")"
        return stmt

    def bv2hex(self, v):
        """Convert a bit-vector value to lowercase hex digits, MSB first, without a prefix."""
        h = ""
        v = self.bv2bin(v)
        while len(v) > 0:
            d = 0
            if len(v) > 0 and v[-1] == "1": d += 1
            if len(v) > 1 and v[-2] == "1": d += 2
            if len(v) > 2 and v[-3] == "1": d += 4
            if len(v) > 3 and v[-4] == "1": d += 8
            h = hex(d)[2:] + h
            if len(v) < 4: break
            v = v[:-4]
        return h

    def bv2bin(self, v):
        """Convert a model value to a binary digit string, MSB first.

        Accepted forms: "true"/"false", "#b...", "#x...", or a parsed
        ["_", "bvN", "W"] literal.
        """
        if type(v) is list and len(v) == 3 and v[0] == "_" and v[1].startswith("bv"):
            x, n = int(v[1][2:]), int(v[2])
            return "".join("1" if (x & (1 << i)) else "0" for i in range(n-1, -1, -1))
        if v == "true": return "1"
        if v == "false": return "0"
        if v.startswith("#b"):
            return v[2:]
        if v.startswith("#x"):
            return "".join(hex_dict.get(x) for x in v[2:])
        assert False

    def bv2int(self, v):
        """Convert a bit-vector value to a non-negative int."""
        return int(self.bv2bin(v), 2)

    def get(self, expr):
        """Return the parsed model value of `expr` via (get-value)."""
        self.write("(get-value (%s))" % (expr))
        return self.parse(self.read())[0][1]

    def get_list(self, expr_list):
        """Return the parsed model values of all expressions in one (get-value)."""
        if len(expr_list) == 0:
            return []
        self.write("(get-value (%s))" % " ".join(expr_list))
        return [n[1] for n in self.parse(self.read())]

    def get_path(self, mod, path):
        """Split dotted name `path` relative to module `mod` into
        [cell, ..., name].

        Backslashes are replaced by slashes first. Cell-name prefixes are
        tried shortest first at each level.
        """
        assert mod in self.modinfo
        path = path.replace("\\", "/").split(".")

        for i in range(len(path)-1):
            first = ".".join(path[0:i+1])
            second = ".".join(path[i+1:])

            if first in self.modinfo[mod].cells:
                nextmod = self.modinfo[mod].cells[first]
                return [first] + self.get_path(nextmod, second)

        return [".".join(path)]

    def net_expr(self, mod, base, path):
        """Build the SMT expression reading `path` of module `mod` from state expression `base`.

        The last element of `path` may be a cell, net or memory.
        """
        if len(path) == 0:
            return base

        if len(path) == 1:
            assert mod in self.modinfo
            if path[0] == "":
                return base
            if path[0] in self.modinfo[mod].cells:
                return "(|%s_h %s| %s)" % (mod, path[0], base)
            if path[0] in self.modinfo[mod].wsize:
                return "(|%s_n %s| %s)" % (mod, path[0], base)
            if path[0] in self.modinfo[mod].memories:
                return "(|%s_m %s| %s)" % (mod, path[0], base)
            assert 0

        assert mod in self.modinfo
        assert path[0] in self.modinfo[mod].cells

        nextmod = self.modinfo[mod].cells[path[0]]
        nextbase = "(|%s_h %s| %s)" % (mod, path[0], base)
        return self.net_expr(nextmod, nextbase, path[1:])

    def net_width(self, mod, net_path):
        """Return the bit width of the net at `net_path` below module `mod`."""
        for i in range(len(net_path)-1):
            assert mod in self.modinfo
            assert net_path[i] in self.modinfo[mod].cells
            mod = self.modinfo[mod].cells[net_path[i]]

        assert mod in self.modinfo
        assert net_path[-1] in self.modinfo[mod].wsize
        return self.modinfo[mod].wsize[net_path[-1]]

    def net_clock(self, mod, net_path):
        """Return the clock edge recorded for the net at `net_path`, or None."""
        for i in range(len(net_path)-1):
            assert mod in self.modinfo
            assert net_path[i] in self.modinfo[mod].cells
            mod = self.modinfo[mod].cells[net_path[i]]

        assert mod in self.modinfo
        if net_path[-1] not in self.modinfo[mod].clocks:
            return None
        return self.modinfo[mod].clocks[net_path[-1]]

    def net_exists(self, mod, net_path):
        """Return True if `net_path` names a known net below module `mod`."""
        for i in range(len(net_path)-1):
            if mod not in self.modinfo: return False
            if net_path[i] not in self.modinfo[mod].cells: return False
            mod = self.modinfo[mod].cells[net_path[i]]

        if mod not in self.modinfo: return False
        if net_path[-1] not in self.modinfo[mod].wsize: return False
        return True

    def mem_exists(self, mod, mem_path):
        """Return True if `mem_path` names a known memory below module `mod`."""
        for i in range(len(mem_path)-1):
            if mod not in self.modinfo: return False
            if mem_path[i] not in self.modinfo[mod].cells: return False
            mod = self.modinfo[mod].cells[mem_path[i]]

        if mod not in self.modinfo: return False
        if mem_path[-1] not in self.modinfo[mod].memories: return False
        return True

    def mem_expr(self, mod, base, path, port=None, infomode=False):
        """Build the SMT expression accessing memory `path` (optionally a `port`) from `base`.

        With `infomode`, return the memory's SmtModInfo.memories tuple instead.
        """
        if len(path) == 1:
            assert mod in self.modinfo
            assert path[0] in self.modinfo[mod].memories
            if infomode:
                return self.modinfo[mod].memories[path[0]]
            return "(|%s_m%s %s| %s)" % (mod, "" if port is None else ":%s" % port, path[0], base)

        assert mod in self.modinfo
        assert path[0] in self.modinfo[mod].cells

        nextmod = self.modinfo[mod].cells[path[0]]
        nextbase = "(|%s_h %s| %s)" % (mod, path[0], base)
        return self.mem_expr(nextmod, nextbase, path[1:], port=port, infomode=infomode)

    def mem_info(self, mod, path):
        """Return the SmtModInfo.memories tuple for memory `path` below `mod`."""
        return self.mem_expr(mod, "", path, infomode=True)

    def get_net(self, mod_name, net_path, state_name):
        """Model value of one net in state `state_name`."""
        return self.get(self.net_expr(mod_name, state_name, net_path))

    def get_net_list(self, mod_name, net_path_list, state_name):
        """Model values of several nets in state `state_name` (one query)."""
        return self.get_list([self.net_expr(mod_name, state_name, n) for n in net_path_list])

    def get_net_hex(self, mod_name, net_path, state_name):
        """Model value of one net as a hex string."""
        return self.bv2hex(self.get_net(mod_name, net_path, state_name))

    def get_net_hex_list(self, mod_name, net_path_list, state_name):
        """Model values of several nets as hex strings."""
        return [self.bv2hex(v) for v in self.get_net_list(mod_name, net_path_list, state_name)]

    def get_net_bin(self, mod_name, net_path, state_name):
        """Model value of one net as a binary string."""
        return self.bv2bin(self.get_net(mod_name, net_path, state_name))

    def get_net_bin_list(self, mod_name, net_path_list, state_name):
        """Model values of several nets as binary strings."""
        return [self.bv2bin(v) for v in self.get_net_list(mod_name, net_path_list, state_name)]

    def wait(self):
        """Wait for the solver process to exit, then close it."""
        if self.p is not None:
            self.p.wait()
            self.p_close()
961 self.p_close()
962
963
class SmtOpts:
    """Command-line options for the solver interface, parsed via getopt."""

    def __init__(self):
        # getopt specifications for the options handled below.
        self.shortopts = "s:S:v"
        self.longopts = ["unroll", "noincr", "noprogress", "timeout=", "dump-smt2=",
                         "logic=", "dummy=", "info=", "nocomments"]
        # Solver selection.
        self.solver = "yices"
        self.solver_opts = []
        self.timeout = 0
        self.logic = None
        # Debugging and recording.
        self.debug_print = False
        self.debug_file = None
        self.dummy_file = None
        self.info_stmts = []
        self.nocomments = False
        # Solving mode.
        self.unroll = False
        self.noincr = False
        # The progress display defaults to off on Windows.
        self.timeinfo = os.name != "nt"

    def handle(self, o, a):
        """Apply getopt option `o` with argument `a`.

        Returns True if the option was recognised, False otherwise so the
        caller can handle it.
        """
        # Options that just set an attribute to a fixed value.
        switches = {
            "-v": ("debug_print", True),
            "--unroll": ("unroll", True),
            "--noincr": ("noincr", True),
            "--noprogress": ("timeinfo", False),
            "--nocomments": ("nocomments", True),
        }
        # Options whose argument is stored verbatim.
        verbatim = {
            "-s": "solver",
            "--logic": "logic",
            "--dummy": "dummy_file",
        }

        if o in switches:
            attr, value = switches[o]
            setattr(self, attr, value)
        elif o in verbatim:
            setattr(self, verbatim[o], a)
        elif o == "-S":
            self.solver_opts.append(a)
        elif o == "--info":
            self.info_stmts.append(a)
        elif o == "--timeout":
            self.timeout = int(a)
        elif o == "--dump-smt2":
            # The file is opened here and never closed by this class.
            self.debug_file = open(a, "w")
        else:
            return False
        return True

    def helpmsg(self):
        """Return the help text describing these options."""
        return """
    -s <solver>
        set SMT solver: z3, yices, boolector, bitwuzla, cvc4, mathsat, dummy
        default: yices

    -S <opt>
        pass <opt> as command line argument to the solver

    --timeout <value>
        set the solver timeout to the specified value (in seconds).

    --logic <smt2_logic>
        use the specified SMT2 logic (e.g. QF_AUFBV)

    --dummy <filename>
        if solver is "dummy", read solver output from that file
        otherwise: write solver output to that file

    -v
        enable debug output

    --unroll
        unroll uninterpreted functions

    --noincr
        don't use incremental solving, instead restart solver for
        each (check-sat). This also avoids (push) and (pop).

    --noprogress
        disable timer display during solving
        (this option is set implicitly on Windows)

    --dump-smt2 <filename>
        write smt2 statements to file

    --info <smt2-info-stmt>
        include the specified smt2 info statement in the smt2 output

    --nocomments
        strip all comments from the generated smt2 code
"""
1052
1053
class MkVcd:
    """Writer for VCD (value change dump) traces.

    Declare signals with add_net()/add_clock() first. Then call set_time()
    once per step: the header is written on the first call. Call set_net()
    for the values within that step.

    Step t is dumped at time 10*t. Clocks are driven to their inactive level
    at 10*t - 5 and to their active level at 10*t.
    """

    def __init__(self, f):
        self.f = f              # output text stream
        self.t = -1             # current step; -1 until the header is written
        self.nets = dict()      # path tuple -> (VCD identifier, width)
        self.clocks = dict()    # path tuple -> (VCD identifier, edge)

    def add_net(self, path, width):
        """Declare a `width`-bit net at hierarchical `path` (before the first step)."""
        path = tuple(path)
        assert self.t == -1
        key = "n%d" % len(self.nets)
        self.nets[path] = (key, width)

    def add_clock(self, path, edge):
        """Declare a 1-bit clock at `path` driven according to `edge`.

        `edge` is "posedge", "negedge" or "event".
        """
        path = tuple(path)
        assert self.t == -1
        key = "n%d" % len(self.nets)
        self.nets[path] = (key, 1)
        self.clocks[path] = (key, edge)

    def set_net(self, path, bits):
        """Dump binary string `bits` for net `path` at the current step.

        Clocks are ignored here; set_time() drives them.
        """
        path = tuple(path)
        assert self.t >= 0
        assert path in self.nets
        if path not in self.clocks:
            print("b%s %s" % (bits, self.nets[path][0]), file=self.f)

    def escape_name(self, name):
        """Turn "[ident]" subscripts into "<ident>" and backslash-escape a
        name that starts with a bracket."""
        name = re.sub(r"\[([0-9a-zA-Z_]*[a-zA-Z_][0-9a-zA-Z_]*)\]", r"<\1>", name)
        # Raw string: "[\[\]]" in a plain literal is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+); the regex itself is unchanged.
        if re.match(r"[\[\]]", name) and name[0] != "\\":
            name = "\\" + name
        return name

    def set_time(self, t):
        """Advance to step `t` (non-decreasing).

        On the first call this writes the VCD header and scope tree. Each new
        step dumps clock edges and the step counter.
        """
        assert t >= self.t
        if t != self.t:
            if self.t == -1:
                print("$version Generated by Yosys-SMTBMC $end", file=self.f)
                print("$timescale 1ns $end", file=self.f)
                print("$var integer 32 t smt_step $end", file=self.f)
                print("$var event 1 ! smt_clock $end", file=self.f)

                def vcdescape(n):
                    if n.startswith("$") or ":" in n:
                        return "\\" + n
                    return n

                scope = []
                for path in sorted(self.nets):
                    key, width = self.nets[path]

                    # Split dotted leaf names into scopes and turn [x] into <x>.
                    uipath = list(path)
                    if "." in uipath[-1] and not uipath[-1].startswith("$"):
                        uipath = uipath[0:-1] + uipath[-1].split(".")
                    for i in range(len(uipath)):
                        uipath[i] = re.sub(r"\[([^\]]*)\]", r"<\1>", uipath[i])

                    # Leave scopes that are not a prefix of this path ...
                    while uipath[:len(scope)] != scope:
                        print("$upscope $end", file=self.f)
                        scope = scope[:-1]

                    # ... and enter the missing ones.
                    while uipath[:-1] != scope:
                        scopename = uipath[len(scope)]
                        print("$scope module %s $end" % vcdescape(scopename), file=self.f)
                        scope.append(uipath[len(scope)])

                    if path in self.clocks and self.clocks[path][1] == "event":
                        print("$var event 1 %s %s $end" % (key, vcdescape(uipath[-1])), file=self.f)
                    else:
                        print("$var wire %d %s %s $end" % (width, key, vcdescape(uipath[-1])), file=self.f)

                for i in range(len(scope)):
                    print("$upscope $end", file=self.f)

                print("$enddefinitions $end", file=self.f)

            self.t = t
            assert self.t >= 0

            if self.t > 0:
                # Half a step earlier: return clocks to their inactive level.
                print("#%d" % (10 * self.t - 5), file=self.f)
                for path in sorted(self.clocks.keys()):
                    if self.clocks[path][1] == "posedge":
                        print("b0 %s" % self.nets[path][0], file=self.f)
                    elif self.clocks[path][1] == "negedge":
                        print("b1 %s" % self.nets[path][0], file=self.f)

            print("#%d" % (10 * self.t), file=self.f)
            print("1!", file=self.f)
            print("b%s t" % format(self.t, "032b"), file=self.f)

            for path in sorted(self.clocks.keys()):
                if self.clocks[path][1] == "negedge":
                    print("b0 %s" % self.nets[path][0], file=self.f)
                else:
                    print("b1 %s" % self.nets[path][0], file=self.f)