Various fixes and improvements in yosys-smtbmc
authorClifford Wolf <clifford@clifford.at>
Mon, 29 Aug 2016 11:53:12 +0000 (13:53 +0200)
committerClifford Wolf <clifford@clifford.at>
Mon, 29 Aug 2016 11:53:12 +0000 (13:53 +0200)
backends/smt2/smtbmc.py

index 448a14202e6260a8486058ed9caec864163f9483..1bb9dd93e1e227149b2e353b98ce9f0fe746cee5 100644 (file)
@@ -91,7 +91,7 @@ for o, a in opts:
     if o == "-t":
         a = a.split(":")
         if len(a) == 1:
-            num_steps = int(a[1])
+            num_steps = int(a[0])
         elif len(a) == 2:
             skip_steps = int(a[0])
             num_steps = int(a[1])
@@ -139,9 +139,12 @@ constr_assumes = defaultdict(list)
 
 for fn in inconstr:
     current_states = None
+    current_line = 0
 
     with open(fn, "r") as f:
         for line in f:
+            current_line += 1
+
             if line.startswith("#"):
                 continue
 
@@ -203,7 +206,7 @@ for fn in inconstr:
                 assert current_states is not None
 
                 for state in current_states:
-                    constr_asserts[state].append(" ".join(tokens[1:]))
+                    constr_asserts[state].append(("%s:%d" % (fn, current_line), " ".join(tokens[1:])))
 
                 continue
 
@@ -211,14 +214,14 @@ for fn in inconstr:
                 assert current_states is not None
 
                 for state in current_states:
-                    constr_assumes[state].append(" ".join(tokens[1:]))
+                    constr_assumes[state].append(("%s:%d" % (fn, current_line), " ".join(tokens[1:])))
 
                 continue
 
             assert 0
 
 
-def get_constr_expr(db, state, final=False):
+def get_constr_expr(db, state, final=False, getvalues=False):
     if final:
         if ("final-%d" % state) not in db:
             return "true"
@@ -243,9 +246,17 @@ def get_constr_expr(db, state, final=False):
         return match.group(1) + expr
 
     expr_list = list()
-    for expr in db[("final-%d" % state) if final else state]:
-        expr = netref_regex.sub(replace_netref, expr)
-        expr_list.append(expr)
+    for loc, expr in db[("final-%d" % state) if final else state]:
+        actual_expr = netref_regex.sub(replace_netref, expr)
+        if getvalues:
+            expr_list.append((loc, expr, actual_expr))
+        else:
+            expr_list.append(actual_expr)
+
+    if getvalues:
+        loc_list, expr_list, acual_expr_list = zip(*expr_list)
+        value_list = smt.get_list(acual_expr_list)
+        return loc_list, expr_list, value_list
 
     if len(expr_list) == 0:
         return "true"
@@ -400,41 +411,42 @@ def write_constr_trace(steps):
             width = smt.modinfo[topmod].wsize[name]
             primary_inputs.append((name, width))
 
-        for k in range(steps):
-            if k != 0:
-                print("", file=f)
 
-            print("state %d" % k, file=f)
+        print("initial", file=f)
 
-            if k == 0:
-                regnames = sorted(smt.hiernets(topmod, regs_only=True))
-                regvals = smt.get_net_list(topmod, regnames, "s0")
+        regnames = sorted(smt.hiernets(topmod, regs_only=True))
+        regvals = smt.get_net_list(topmod, regnames, "s0")
 
-                for name, val in zip(regnames, regvals):
-                    print("assume (= [%s] %s)" % (".".join(name), val), file=f)
+        for name, val in zip(regnames, regvals):
+            print("assume (= [%s] %s)" % (".".join(name), val), file=f)
 
-                mems = sorted(smt.hiermems(topmod))
-                for mempath in mems:
-                    abits, width, ports = smt.mem_info(topmod, "s0", mempath)
-                    mem = smt.mem_expr(topmod, "s0", mempath)
+        mems = sorted(smt.hiermems(topmod))
+        for mempath in mems:
+            abits, width, ports = smt.mem_info(topmod, "s0", mempath)
+            mem = smt.mem_expr(topmod, "s0", mempath)
 
-                    addr_expr_list = list()
-                    for i in range(steps):
-                        for j in range(ports):
-                            addr_expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, j))
+            addr_expr_list = list()
+            for i in range(steps):
+                for j in range(ports):
+                    addr_expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, j))
 
-                    addr_list = set()
-                    for val in smt.get_list(addr_expr_list):
-                        addr_list.add(smt.bv2int(val))
+            addr_list = set()
+            for val in smt.get_list(addr_expr_list):
+                addr_list.add(smt.bv2int(val))
+
+            expr_list = list()
+            for i in addr_list:
+                expr_list.append("(select %s #b%s)" % (mem, format(i, "0%db" % abits)))
+
+            for i, val in zip(addr_list, smt.get_list(expr_list)):
+                print("assume (= (select [%s] #b%s) %s)" % (".".join(mempath), format(i, "0%db" % abits), val), file=f)
 
-                    expr_list = list()
-                    for i in addr_list:
-                        expr_list.append("(select %s #b%s)" % (mem, format(i, "0%db" % abits)))
 
-                    for i, val in zip(addr_list, smt.get_list(expr_list)):
-                        print("assume (= (select [%s] #b%s) %s)" % (".".join(mempath), format(i, "0%db" % abits), val), file=f)
+        for k in range(steps):
+            print("", file=f)
+            print("state %d" % k, file=f)
 
-            pi_names = [[name] for name, _ in primary_inputs]
+            pi_names = [[name] for name, _ in sorted(primary_inputs)]
             pi_values = smt.get_net_list(topmod, pi_names, "s%d" % k)
 
             for name, val in zip(pi_names, pi_values):
@@ -452,20 +464,31 @@ def write_trace(steps):
         write_constr_trace(steps)
 
 
-def print_failed_asserts(mod, state, path):
+def print_failed_asserts_worker(mod, state, path):
     assert mod in smt.modinfo
 
-    if smt.get("(|%s_a| %s)" % (mod, state)) == "true":
+    if smt.get("(|%s_a| s%d)" % (mod, state)) == "true":
         return
 
     for cellname, celltype in smt.modinfo[mod].cells.items():
-        print_failed_asserts(celltype, "(|%s_h %s| %s)" % (mod, cellname, state), path + "." + cellname)
+        print_failed_asserts_worker(celltype, "(|%s_h %s| s%d)" % (mod, cellname, state), path + "." + cellname)
 
     for assertfun, assertinfo in smt.modinfo[mod].asserts.items():
-        if smt.get("(|%s| %s)" % (assertfun, state)) == "false":
+        if smt.get("(|%s| s%d)" % (assertfun, state)) == "false":
             print("%s Assert failed in %s: %s" % (smt.timestamp(), path, assertinfo))
 
 
+def print_failed_asserts(state, final=False):
+    loc_list, expr_list, value_list = get_constr_expr(constr_asserts, state, final=final, getvalues=True)
+
+    for loc, expr, value in zip(loc_list, expr_list, value_list):
+        if smt.bv2int(value) == 0:
+            print("%s Assert %s failed: %s" % (smt.timestamp(), loc, expr))
+
+    if not final:
+        print_failed_asserts_worker(topmod, state, topmod)
+
+
 if tempind:
     retstatus = False
     skip_counter = step_size
@@ -497,7 +520,7 @@ if tempind:
         if smt.check_sat() == "sat":
             if step == 0:
                 print("%s Temporal induction failed!" % smt.timestamp())
-                print_failed_asserts(topmod, "s%d" % step, topmod)
+                print_failed_asserts(num_steps)
                 write_trace(num_steps+1)
 
         else:
@@ -556,8 +579,9 @@ else: # not tempind
 
                 if smt.check_sat() == "sat":
                     print("%s BMC failed!" % smt.timestamp())
-                    print_failed_asserts(topmod, "s%d" % step, topmod)
-                    write_trace(step+step_size)
+                    for i in range(step, last_check_step+1):
+                        print_failed_asserts(i)
+                    write_trace(last_check_step+1)
                     retstatus = False
                     break
 
@@ -580,8 +604,8 @@ else: # not tempind
 
                     if smt.check_sat() == "sat":
                         print("%s BMC failed!" % smt.timestamp())
-                        print_failed_asserts(topmod, "s%d" % i, topmod)
-                        write_trace(i)
+                        print_failed_asserts(i, final=True)
+                        write_trace(i+1)
                         retstatus = False
                         break