rework branch proof to use br_input_record
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Wed, 22 Jul 2020 10:14:32 +0000 (11:14 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Wed, 22 Jul 2020 10:14:32 +0000 (11:14 +0100)
src/soc/fu/branch/formal/proof_main_stage.py

index 3e1879aef10213831bebd6573343f801079d1de3..626ecf504f0c9eaacdf83f6bb93ed0151ca3d14c 100644 (file)
@@ -14,8 +14,8 @@ from nmutil.extend import exts
 from nmigen.cli import rtlil
 
 from soc.fu.branch.main_stage import BranchMainStage
-from soc.fu.alu.pipe_data import ALUPipeSpec
-from soc.fu.alu.alu_input_record import CompALUOpSubset
+from soc.fu.branch.pipe_data import BranchPipeSpec
+from soc.fu.branch.br_input_record import CompBROpSubset
 from soc.decoder.power_enums import MicrOp
 import unittest
 
@@ -31,7 +31,7 @@ class Driver(Elaboratable):
         m = Module()
         comb = m.d.comb
 
-        rec = CompALUOpSubset()
+        rec = CompBROpSubset()
         recwidth = 0
         # Setup random inputs for dut.op
         for p in rec.ports():
@@ -39,15 +39,21 @@ class Driver(Elaboratable):
             recwidth += width
             comb += p.eq(AnyConst(width))
 
-        pspec = ALUPipeSpec(id_wid=2)
+        pspec = BranchPipeSpec(id_wid=2)
         m.submodules.dut = dut = BranchMainStage(pspec)
 
-        comb += dut.i.ctx.op.eq(rec)
+        # convenience aliases
+        op = dut.i.ctx.op
+        cia, cr_in, fast1, fast2 = op.cia, dut.i.cr, dut.i.fast1, dut.i.fast2
+        ctr = fast1
+        lr_o, nia_o = dut.o.lr, dut.o.nia
+
+        comb += op.eq(rec)
 
         # Assert that op gets copied from the input to output
         for rec_sig in rec.ports():
             name = rec_sig.name
-            dut_sig = getattr(dut.o.ctx.op, name)
+            dut_sig = getattr(op, name)
             comb += Assert(dut_sig == rec_sig)
 
         # Full width CR register. Will have bitfield extracted for
@@ -57,13 +63,9 @@ class Driver(Elaboratable):
         cr_arr = Array([cr[(7-i)*4:(7-i)*4+4] for i in range(8)])
         cr_bit_arr = Array([cr[31-i] for i in range(32)])
 
-        cia, cr_in, fast1, fast2 = dut.i.cia, dut.i.cr, dut.i.fast1, dut.i.fast2
-        ctr = fast1
-        lr_o, nia_o = dut.o.lr, dut.o.nia
-
         comb += [fast2.eq(AnyConst(64)),
                  ctr.eq(AnyConst(64)),
-                 cia.eq(AnyConst(64))]
+                 ]
 
         i_fields = dut.fields.FormI
         b_fields = dut.fields.FormB