assign and test on Data, TODO add Data.ok checking in CR proof
[soc.git] / src / soc / fu / cr / formal / proof_main_stage.py
index 79476cf5a515d7733e5f0348e8d78ba9363a9436..26ebc79ebb1785c50b187c19937963f30e11253a 100644 (file)
@@ -48,8 +48,8 @@ class Driver(Elaboratable):
         a = dut.i.a
         b = dut.i.b
         cr = full_cr_in
-        full_cr_out = dut.o.full_cr
-        o = dut.o.o
+        full_cr_out = dut.o.full_cr.data
+        o = dut.o.o.data
 
         # setup random inputs
         comb += [a.eq(AnyConst(64)),
@@ -72,7 +72,13 @@ class Driver(Elaboratable):
         with m.Switch(rec.insn_type):
             # CR_ISEL takes cr_a
             with m.Case(InternalOp.OP_ISEL):
-                comb += dut.i.cr_a.eq(cr_a_in)
+                # grab the MSBs of the cr bit selector
+                bc = Signal(3, reset_less=True)
+                comb += bc.eq(a_fields.BC[2:5])
+
+                # Use the MSBs to select which CR register to feed
+                # into cr_a
+                comb += dut.i.cr_a.eq(cr_input_arr[bc])
 
             # For OP_CROP, we need to input the corresponding CR
             # registers for BA, BB, and BT
@@ -97,7 +103,7 @@ class Driver(Elaboratable):
                     with m.If(i != bt):
                         comb += cr_output_arr[i].eq(cr_input_arr[i])
                     with m.Else():
-                        comb += cr_output_arr[i].eq(dut.o.cr_o)
+                        comb += cr_output_arr[i].eq(dut.o.cr_o.data)
 
             with m.Case(InternalOp.OP_MCRF):
                 # This does a similar thing to OP_CROP above, with
@@ -117,16 +123,24 @@ class Driver(Elaboratable):
                     with m.If(i != bf):
                         comb += cr_output_arr[i].eq(cr_input_arr[i])
                     with m.Else():
-                        comb += cr_output_arr[i].eq(dut.o.cr_o)
+                        comb += cr_output_arr[i].eq(dut.o.cr_o.data)
 
             # For the other two, they take the full CR as input, and
             # output a full CR. This handles that
             with m.Default():
                 comb += dut.i.full_cr.eq(full_cr_in)
-                comb += cr_o.eq(dut.o.full_cr)
+                comb += cr_o.eq(full_cr_out)
 
         comb += dut.i.ctx.op.eq(rec)
 
+        # test signals for output conditions.  these must only be enabled for
+        # specific instructions, indicating that they generated the output.
+        # this is critically important because the "ok" signals are used by
+        # MultiCompUnit to request a write to the regfile.
+        o_ok = Signal()
+        cr_o_ok = Signal()
+        full_cr_o_ok = Signal()
+
         # Assert that op gets copied from the input to output
         for rec_sig in rec.ports():
             name = rec_sig.name
@@ -143,6 +157,8 @@ class Driver(Elaboratable):
                 for i in range(8):
                     with m.If(FXM[i]):
                         comb += Assert(cr_o[4*i:4*i+4] == a[4*i:4*i+4])
+                comb += cr_o_ok.eq(1)
+
             with m.Case(InternalOp.OP_MFCR):
                 with m.If(rec.insn[20]):  # mfocrf
                     for i in range(8):
@@ -152,6 +168,8 @@ class Driver(Elaboratable):
                             comb += Assert(o[4*i:4*i+4] == 0)
                 with m.Else(): # mfcrf
                     comb += Assert(o == cr)
+                comb += o_ok.eq(1)
+
             with m.Case(InternalOp.OP_MCRF):
                 BF = xl_fields.BF[0:-1]
                 BFA = xl_fields.BFA[0:-1]
@@ -160,6 +178,7 @@ class Driver(Elaboratable):
                 for i in range(8):
                     with m.If(BF != 7-i):
                         comb += Assert(cr_o[i*4:i*4+4] == cr[i*4:i*4+4])
+                comb += cr_o_ok.eq(1)
 
             with m.Case(InternalOp.OP_CROP):
                 bt = Signal(xl_fields.BT[0:-1].shape(), reset_less=True)
@@ -196,17 +215,20 @@ class Driver(Elaboratable):
                     comb += Assert(bit_o == bit_a ^ bit_b)
 
             with m.Case(InternalOp.OP_ISEL):
-                # just like in branch, CR0-7 is incoming into cr_a, we
-                # need to select from the last 2 bits of BC
-                BC = a_fields.BC[0:-1][0:2]
-                cr_bits = Array([cr_a_in[3-i] for i in range(4)])
+                # Extract the bit selector of the CR
+                bc = Signal(a_fields.BC[0:-1].shape(), reset_less=True)
+                comb += bc.eq(a_fields.BC[0:-1])
 
-                # The bit of (cr_a=CR0-7) selected by BC
+                # Extract the bit from CR
                 cr_bit = Signal(reset_less=True)
-                comb += cr_bit.eq(cr_bits[BC])
+                comb += cr_bit.eq(cr_arr[bc])
 
                 # select a or b as output
                 comb += Assert(o == Mux(cr_bit, a, b))
+                comb += o_ok.eq(1)
+
+        # check that data ok was only enabled when op actioned
+        comb += Assert(dut.o.o.ok == o_ok)
 
         return m