Improved smtbmc vcd generation performance
authorClifford Wolf <clifford@clifford.at>
Thu, 18 Aug 2016 09:17:45 +0000 (11:17 +0200)
committerClifford Wolf <clifford@clifford.at>
Thu, 18 Aug 2016 09:17:45 +0000 (11:17 +0200)
backends/smt2/smtbmc.py
backends/smt2/smtio.py

index cb491b80050e95502f93c5e10321569dc318ea1c..75bc51abb87cb19ce707a84ac5fc7b170fdbcd7a 100644 (file)
@@ -111,14 +111,16 @@ def write_vcd_model(steps):
     print("%s Writing model to VCD file." % smt.timestamp())
 
     vcd = mkvcd(open(vcdfile, "w"))
+
     for netpath in sorted(smt.hiernets(topmod)):
-        width = len(smt.get_net_bin(topmod, netpath, "s0"))
-        vcd.add_net([topmod] + netpath, width)
+        vcd.add_net([topmod] + netpath, smt.net_width(topmod, netpath))
 
     for i in range(steps):
         vcd.set_time(i)
-        for netpath in sorted(smt.hiernets(topmod)):
-            vcd.set_net([topmod] + netpath, smt.get_net_bin(topmod, netpath, "s%d" % i))
+        path_list = sorted(smt.hiernets(topmod))
+        value_list = smt.get_net_bin_list(topmod, path_list, "s%d" % i)
+        for path, value in zip(path_list, value_list):
+            vcd.set_net([topmod] + path, value)
 
     vcd.set_time(steps)
 
index 53d2ec57bfdbb105a594dc7906281767f84a4b8b..1b3944ebf3a569b5d68aedd55b27c58b5374deec 100644 (file)
@@ -157,7 +157,7 @@ class smtio:
                 print("< %s" % line)
             if count_brackets == 0:
                 break
-            if not self.p.poll():
+            if self.p.poll():
                 print("SMT Solver terminated unexpectedly: %s" % "".join(stmt))
                 sys.exit(1)
 
@@ -297,33 +297,51 @@ class smtio:
         self.write("(get-value (%s))" % (expr))
         return self.parse(self.read())[0][1]
 
-    def get_net(self, mod_name, net_path, state_name):
-        def mkexpr(mod, base, path):
-            if len(path) == 1:
-                assert mod in self.modinfo
-                assert path[0] in self.modinfo[mod].wsize
-                return "(|%s_n %s| %s)" % (mod, path[0], base)
+    def get_list(self, expr_list):
+         self.write("(get-value (%s))" % " ".join(expr_list))
+         return [n[1] for n in self.parse(self.read())]
+
+    def net_expr(self, mod, base, path):
+        if len(path) == 1:
+            assert mod in self.modinfo
+            assert path[0] in self.modinfo[mod].wsize
+            return "(|%s_n %s| %s)" % (mod, path[0], base)
+
+        assert mod in self.modinfo
+        assert path[0] in self.modinfo[mod].cells
 
+        nextmod = self.modinfo[mod].cells[path[0]]
+        nextbase = "(|%s_h %s| %s)" % (mod, path[0], base)
+        return self.net_expr(nextmod, nextbase, path[1:])
+
+    def net_width(self, mod, net_path):
+        for i in range(len(net_path)-1):
             assert mod in self.modinfo
-            assert path[0] in self.modinfo[mod].cells
+            assert net_path[i] in self.modinfo[mod].cells
+            mod = self.modinfo[mod].cells[net_path[i]]
 
-            nextmod = self.modinfo[mod].cells[path[0]]
-            nextbase = "(|%s_h %s| %s)" % (mod, path[0], base)
-            return mkexpr(nextmod, nextbase, path[1:])
+        assert mod in self.modinfo
+        assert net_path[-1] in self.modinfo[mod].wsize
+        return self.modinfo[mod].wsize[net_path[-1]]
 
-        return self.get(mkexpr(mod_name, state_name, net_path))
+    def get_net(self, mod_name, net_path, state_name):
+        return self.get(self.net_expr(mod_name, state_name, net_path))
 
-    def get_net_bool(self, mod_name, net_path, state_name):
-        v = self.get_net(mod_name, net_path, state_name)
-        assert v in ["true", "false"]
-        return 1 if v == "true" else 0
+    def get_net_list(self, mod_name, net_path_list, state_name):
+        return self.get_list([self.net_expr(mod_name, state_name, n) for n in net_path_list])
 
     def get_net_hex(self, mod_name, net_path, state_name):
         return self.bv2hex(self.get_net(mod_name, net_path, state_name))
 
+    def get_net_hex_list(self, mod_name, net_path_list, state_name):
+        return [self.bv2hex(v) for v in self.get_net_list(mod_name, net_path_list, state_name)]
+
     def get_net_bin(self, mod_name, net_path, state_name):
         return self.bv2bin(self.get_net(mod_name, net_path, state_name))
 
+    def get_net_bin_list(self, mod_name, net_path_list, state_name):
+        return [self.bv2bin(v) for v in self.get_net_list(mod_name, net_path_list, state_name)]
+
     def wait(self):
         self.p.wait()