Added "yosys-smtbmc --unroll"
authorClifford Wolf <clifford@clifford.at>
Wed, 7 Sep 2016 18:57:56 +0000 (20:57 +0200)
committerClifford Wolf <clifford@clifford.at>
Wed, 7 Sep 2016 18:57:56 +0000 (20:57 +0200)
backends/smt2/smtio.py

index 9c4a0225eea52b9edb90a7b866dcc7c0ac87668a..9bb934a461cb623f5dbdb44abe515fbaa19d30a9 100644 (file)
@@ -18,6 +18,7 @@
 #
 
 import sys, subprocess, re
+from copy import deepcopy
 from select import select
 from time import time
 
@@ -46,18 +47,20 @@ class SmtModInfo:
 
 
 class SmtIo:
-    def __init__(self, solver=None, debug_print=None, debug_file=None, timeinfo=None, opts=None):
+    def __init__(self, solver=None, debug_print=None, debug_file=None, timeinfo=None, unroll=None, opts=None):
         if opts is not None:
             self.solver = opts.solver
             self.debug_print = opts.debug_print
             self.debug_file = opts.debug_file
             self.timeinfo = opts.timeinfo
+            self.unroll = opts.unroll
 
         else:
             self.solver = "z3"
             self.debug_print = False
             self.debug_file = None
             self.timeinfo = True
+            self.unroll = False
 
         if solver is not None:
             self.solver = solver
@@ -71,6 +74,9 @@ class SmtIo:
         if timeinfo is not None:
             self.timeinfo = timeinfo
 
+        if unroll is not None:
+            self.unroll = unroll
+
         if self.solver == "yices":
             popen_vargs = ['yices-smt2', '--incremental']
 
@@ -84,8 +90,17 @@ class SmtIo:
             popen_vargs = ['mathsat']
 
         if self.solver == "boolector":
-            self.declared_sorts = list()
-            popen_vargs = ['boolector', '--smt2', '-i']
+            popen_vargs = ['boolector', '--smt2', '-i', '-m']
+            self.unroll = True
+
+        if self.unroll:
+            self.unroll_idcnt = 0
+            self.unroll_buffer = ""
+            self.unroll_sorts = set()
+            self.unroll_objs = set()
+            self.unroll_decls = dict()
+            self.unroll_cache = dict()
+            self.unroll_stack = list()
 
         self.p = subprocess.Popen(popen_vargs, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
         self.start_time = time()
@@ -105,15 +120,108 @@ class SmtIo:
         secs = int(time() - self.start_time)
         return "## %6d %3d:%02d:%02d " % (secs, secs // (60*60), (secs // 60) % 60, secs % 60)
 
-    def write(self, stmt):
+    def replace_in_stmt(self, stmt, pat, repl):
+        if stmt == pat:
+            return repl
+
+        if isinstance(stmt, list):
+            return [self.replace_in_stmt(s, pat, repl) for s in stmt]
+
+        return stmt
+
+    def unroll_stmt(self, stmt):
+        if not isinstance(stmt, list):
+            return stmt
+
+        stmt = [self.unroll_stmt(s) for s in stmt]
+
+        if len(stmt) >= 2 and not isinstance(stmt[0], list) and stmt[0] in self.unroll_decls:
+            assert stmt[1] in self.unroll_objs
+
+            key = tuple(stmt)
+            if key not in self.unroll_cache:
+                decl = deepcopy(self.unroll_decls[key[0]])
+
+                self.unroll_cache[key] = "|UNROLL#%d|" % self.unroll_idcnt
+                decl[1] = self.unroll_cache[key]
+                self.unroll_idcnt += 1
+
+                if decl[0] == "declare-fun":
+                    if isinstance(decl[3], list) or decl[3] not in self.unroll_sorts:
+                        self.unroll_objs.add(decl[1])
+                        decl[2] = list()
+                    else:
+                        self.unroll_objs.add(decl[1])
+                        decl = list()
+
+                elif decl[0] == "define-fun":
+                    arg_index = 1
+                    for arg_name, arg_sort in decl[2]:
+                        decl[4] = self.replace_in_stmt(decl[4], arg_name, key[arg_index])
+                        arg_index += 1
+                    decl[2] = list()
+
+                if len(decl) > 0:
+                    decl = self.unroll_stmt(decl)
+                    self.write(self.unparse(decl), unroll=False)
+
+            return self.unroll_cache[key]
+
+        return stmt
+
+    def write(self, stmt, unroll=True):
         stmt = stmt.strip()
 
-        if self.solver == "boolector":
-            if stmt.startswith("(declare-sort"):
-                self.declared_sorts.append(stmt.split()[1])
+        if unroll and self.unroll:
+            if stmt.startswith(";"):
                 return
-            for n in self.declared_sorts:
-                stmt = stmt.replace(n, "(_ BitVec 16)")
+
+            stmt = re.sub(r" ;.*", "", stmt)
+            stmt = self.unroll_buffer + stmt
+            self.unroll_buffer = ""
+
+            s = re.sub(r"\|[^|]*\|", "", stmt)
+            if s.count("(") != s.count(")"):
+                self.unroll_buffer = stmt + " "
+                return
+
+            s = self.parse(stmt)
+
+            if self.debug_print:
+                print("-> %s" % s)
+
+            if len(s) == 3 and s[0] == "declare-sort" and s[2] == "0":
+                self.unroll_sorts.add(s[1])
+                return
+
+            elif len(s) == 4 and s[0] == "declare-fun" and s[2] == [] and s[3] in self.unroll_sorts:
+                self.unroll_objs.add(s[1])
+                return
+
+            elif len(s) >= 4 and s[0] == "declare-fun":
+                for arg_sort in s[2]:
+                    if arg_sort in self.unroll_sorts:
+                        self.unroll_decls[s[1]] = s
+                        return
+
+            elif len(s) >= 4 and s[0] == "define-fun":
+                for arg_name, arg_sort in s[2]:
+                    if arg_sort in self.unroll_sorts:
+                        self.unroll_decls[s[1]] = s
+                        return
+
+            stmt = self.unparse(self.unroll_stmt(s))
+
+            if stmt == "(push 1)":
+                self.unroll_stack.append((
+                    deepcopy(self.unroll_sorts),
+                    deepcopy(self.unroll_objs),
+                    deepcopy(self.unroll_decls),
+                    deepcopy(self.unroll_cache),
+                ))
+
+            if stmt == "(pop 1)":
+                self.unroll_sorts, self.unroll_objs, self.unroll_decls, self.unroll_cache = self.unroll_stack.pop()
 
         if self.debug_print:
             print("> %s" % stmt)
@@ -297,6 +405,11 @@ class SmtIo:
             return expr, cursor
         return worker(stmt)[0]
 
+    def unparse(self, stmt):
+        if isinstance(stmt, list):
+            return "(" + " ".join([self.unparse(s) for s in stmt]) + ")"
+        return stmt
+
     def bv2hex(self, v):
         h = ""
         v = self.bv2bin(v)
@@ -416,10 +529,11 @@ class SmtIo:
 class SmtOpts:
     def __init__(self):
         self.shortopts = "s:v"
-        self.longopts = ["no-progress", "dump-smt2="]
+        self.longopts = ["unroll", "no-progress", "dump-smt2="]
         self.solver = "z3"
         self.debug_print = False
         self.debug_file = None
+        self.unroll = True
         self.timeinfo = True
 
     def handle(self, o, a):
@@ -427,6 +541,8 @@ class SmtOpts:
             self.solver = a
         elif o == "-v":
             self.debug_print = True
+        elif o == "--unroll":
+            self.unroll = True
         elif o == "--no-progress":
             self.timeinfo = True
         elif o == "--dump-smt2":
@@ -444,8 +560,11 @@ class SmtOpts:
     -v
         enable debug output
 
+    --unroll
+        unroll uninterpreted functions
+
     --no-progress
-        disable running timer display during solving
+        disable timer display during solving
 
     --dump-smt2 <filename>
         write smt2 statements to file