Added "yosys-smtbmc --aig"
authorClifford Wolf <clifford@clifford.at>
Thu, 1 Dec 2016 11:57:26 +0000 (12:57 +0100)
committerClifford Wolf <clifford@clifford.at>
Thu, 1 Dec 2016 12:16:57 +0000 (13:16 +0100)
backends/smt2/smtbmc.py

index 3384789ee8b7dacc1df12966dd43288c5183220b..56c7bccc15c4811cbc6fdecb1cbb0863499d33f7 100644 (file)
@@ -22,12 +22,14 @@ import os, sys, getopt, re
 from smtio import SmtIo, SmtOpts, MkVcd
 from collections import defaultdict
 
+got_topt = False
 skip_steps = 0
 step_size = 1
 num_steps = 20
 append_steps = 0
 vcdfile = None
 cexfile = None
+aigprefix = None
 vlogtbfile = None
 inconstr = list()
 outconstr = None
@@ -66,6 +68,11 @@ yosys-smtbmc [options] <yosys_smt2_output>
     --cex <cex_filename>
         read cex file as written by ABC's "write_cex -n"
 
+    --aig <prefix>
+        read AIGER map file (as written by Yosys' "write_aiger -map")
+        and AIGER witness file. The file names are <prefix>.aim for
+        the map file and <prefix>.aiw for the witness file.
+
     --noinfo
         only run the core proof, do not collect and print any
         additional information (e.g. which assert failed)
@@ -104,12 +111,14 @@ yosys-smtbmc [options] <yosys_smt2_output>
 
 try:
     opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:igm:", so.longopts +
-            ["final-only", "assume-skipped=", "smtc=", "cex=", "dump-vcd=", "dump-vlogtb=", "dump-smtc=", "dump-all", "noinfo", "append="])
+            ["final-only", "assume-skipped=", "smtc=", "cex=", "aig=",
+             "dump-vcd=", "dump-vlogtb=", "dump-smtc=", "dump-all", "noinfo", "append="])
 except:
     usage()
 
 for o, a in opts:
     if o == "-t":
+        got_topt = True
         a = a.split(":")
         if len(a) == 1:
             num_steps = int(a[0])
@@ -121,7 +130,7 @@ for o, a in opts:
             step_size = int(a[1])
             num_steps = int(a[2])
         else:
-            assert 0
+            assert False
     elif o == "--assume-skipped":
         assume_skipped = int(a)
     elif o == "--final-only":
@@ -130,6 +139,8 @@ for o, a in opts:
         inconstr.append(a)
     elif o == "--cex":
         cexfile = a
+    elif o == "--aig":
+        aigprefix = a
     elif o == "--dump-vcd":
         vcdfile = a
     elif o == "--dump-vlogtb":
@@ -195,7 +206,7 @@ for fn in inconstr:
                     current_states = set(["final-%d" % i for i in range(-i, num_steps+1)])
                     constr_final_start = -i if constr_final_start is None else min(constr_final_start, -i)
                 else:
-                    assert 0
+                    assert False
                 continue
 
             if tokens[0] == "state":
@@ -214,7 +225,7 @@ for fn in inconstr:
                             for i in range(lower, upper+1):
                                 current_states.add(i)
                         else:
-                            assert 0
+                            assert False
                 continue
 
             if tokens[0] == "always":
@@ -225,7 +236,7 @@ for fn in inconstr:
                     assert i < 0
                     current_states = set(range(-i, num_steps+1))
                 else:
-                    assert 0
+                    assert False
                 continue
 
             if tokens[0] == "assert":
@@ -252,7 +263,7 @@ for fn in inconstr:
                 so.logic = " ".join(tokens[1:])
                 continue
 
-            assert 0
+            assert False
 
 
 def get_constr_expr(db, state, final=False, getvalues=False):
@@ -357,6 +368,116 @@ if cexfile is not None:
             # print("cex@%d: %s" % (step, smtexpr))
             constr_assumes[step].append((cexfile, smtexpr))
 
+if aigprefix is not None:
+    input_map = dict()
+    init_map = dict()
+    latch_map = dict()
+
+    with open(aigprefix + ".aim", "r") as f:
+        for entry in f.read().splitlines():
+            entry = entry.split()
+
+            if entry[0] == "input":
+                input_map[int(entry[1])] = (entry[3], int(entry[2]))
+                continue
+
+            if entry[0] == "init":
+                init_map[int(entry[1])] = (entry[3], int(entry[2]))
+                continue
+
+            if entry[0] in ["latch", "invlatch"]:
+                latch_map[int(entry[1])] = (entry[3], int(entry[2]), entry[0] == "invlatch")
+                continue
+
+            if entry[0] in ["output", "wire"]:
+                continue
+
+            assert False
+
+    with open(aigprefix + ".aiw", "r") as f:
+        got_state = False
+        got_ffinit = False
+        step = 0
+
+        for entry in f.read().splitlines():
+            if len(entry) == 0 or entry[0] in "bcjfu.":
+                continue
+
+            if not got_state:
+                got_state = True
+                assert entry == "1"
+                continue
+
+            if not got_ffinit:
+                got_ffinit = True
+                if len(init_map) == 0:
+                    for i in range(len(entry)):
+                        if entry[i] == "x":
+                            continue
+
+                        if i in latch_map:
+                            value = int(entry[i])
+                            name = latch_map[i][0]
+                            bitidx = latch_map[i][1]
+                            invert = latch_map[i][2]
+
+                            if invert:
+                                value = 1 - value
+
+                            path = smt.get_path(topmod, name)
+                            width = smt.net_width(topmod, path)
+
+                            if width == 1:
+                                assert bitidx == 0
+                                smtexpr = "(= [%s] %s)" % (name, "true" if value else "false")
+                            else:
+                                smtexpr = "(= ((_ extract %d %d) [%s]) #b%d)" % (bitidx, bitidx, name, value)
+
+                            constr_assumes[0].append((cexfile, smtexpr))
+                continue
+
+            for i in range(len(entry)):
+                if entry[i] == "x":
+                    continue
+
+                if (step == 0) and (i in init_map):
+                    value = int(entry[i])
+                    name = init_map[i][0]
+                    bitidx = init_map[i][1]
+
+                    path = smt.get_path(topmod, name)
+                    width = smt.net_width(topmod, path)
+
+                    if width == 1:
+                        assert bitidx == 0
+                        smtexpr = "(= [%s] %s)" % (name, "true" if value else "false")
+                    else:
+                        smtexpr = "(= ((_ extract %d %d) [%s]) #b%d)" % (bitidx, bitidx, name, value)
+
+                    constr_assumes[0].append((cexfile, smtexpr))
+
+                if i in input_map:
+                    value = int(entry[i])
+                    name = input_map[i][0]
+                    bitidx = input_map[i][1]
+
+                    path = smt.get_path(topmod, name)
+                    width = smt.net_width(topmod, path)
+
+                    if width == 1:
+                        assert bitidx == 0
+                        smtexpr = "(= [%s] %s)" % (name, "true" if value else "false")
+                    else:
+                        smtexpr = "(= ((_ extract %d %d) [%s]) #b%d)" % (bitidx, bitidx, name, value)
+
+                    constr_assumes[step].append((cexfile, smtexpr))
+
+            if not got_topt:
+                skip_steps = step
+                assume_skipped = 0
+                num_steps = max(num_steps, step+1)
+            step += 1
+
 def write_vcd_trace(steps_start, steps_stop, index):
     filename = vcdfile.replace("%", index)
     print_msg("Writing trace to VCD file: %s" % (filename))