cleanup logical pipe formal proof
[soc.git] / src / soc / fu / logical / formal / proof_main_stage.py
1 # Proof of correctness for the logical pipeline main stage (LogicalMainStage)
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 """
4 Links:
5 * https://bugs.libre-soc.org/show_bug.cgi?id=331
6 * https://libre-soc.org/openpower/isa/fixedlogical/
7 """
8
9 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
10 signed)
11 from nmigen.asserts import Assert, AnyConst, Assume, Cover
12 from nmigen.test.utils import FHDLTestCase
13 from nmigen.lib.coding import PriorityEncoder
14 from nmigen.cli import rtlil
15
16 from soc.fu.logical.main_stage import LogicalMainStage
17 from soc.fu.alu.pipe_data import ALUPipeSpec
18 from soc.fu.alu.alu_input_record import CompALUOpSubset
19 from soc.decoder.power_enums import InternalOp
20 import unittest
21
22
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal harness for LogicalMainStage.

    Drives the DUT's operation record and operands from AnyConst values and
    asserts that:

    * the operation context (``ctx.op``) is copied from input to output, and
    * the result ``o`` matches a reference model for OP_AND, OP_OR, OP_XOR,
      OP_POPCNT, OP_PRTY and OP_CNTZ.
    """

    def __init__(self):
        # the harness has no ports of its own: every input is an AnyConst
        pass

    def popcount(self, sig, width):
        """Return the sum of bits ``sig[0]`` .. ``sig[width - 1]``.

        :param sig: indexable bit source (e.g. an nmigen Value)
        :param width: number of low-order bits to count
        :returns: an expression (or ``0`` when ``width`` is 0)
        """
        return sum(sig[i] for i in range(width))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        rec = CompALUOpSubset()
        # Setup random inputs for dut.op
        for p in rec.ports():
            comb += p.eq(AnyConst(p.width))

        pspec = ALUPipeSpec(id_wid=2)
        m.submodules.dut = dut = LogicalMainStage(pspec)

        # convenience variables
        a = dut.i.a
        b = dut.i.b
        o = dut.o.o

        # setup random inputs.  xer_ca is driven as a whole (both bit 0 and
        # bit 1) from an AnyConst of the same width, so no carry bit is
        # left undriven and no AnyConst bits are silently truncated.
        comb += [a.eq(AnyConst(64)),
                 b.eq(AnyConst(64)),
                 dut.i.xer_ca.eq(AnyConst(len(dut.i.xer_ca))),
                 ]

        comb += dut.i.ctx.op.eq(rec)

        # Assert that op gets copied from the input to output
        for rec_sig in rec.ports():
            dut_sig = getattr(dut.o.ctx.op, rec_sig.name)
            comb += Assert(dut_sig == rec_sig)

        # main assertion of logical operations
        with m.Switch(rec.insn_type):
            with m.Case(InternalOp.OP_AND):
                comb += Assert(o == a & b)
            with m.Case(InternalOp.OP_OR):
                comb += Assert(o == a | b)
            with m.Case(InternalOp.OP_XOR):
                comb += Assert(o == a ^ b)

            with m.Case(InternalOp.OP_POPCNT):
                # data_len selects the counting granularity:
                # 8 -> whole 64-bit value, 4 -> each 32-bit half,
                # 1 -> each 8-bit byte
                with m.If(rec.data_len == 8):
                    comb += Assert(o == self.popcount(a, 64))
                with m.Elif(rec.data_len == 4):
                    for i in range(2):
                        comb += Assert(o[i*32:(i+1)*32] ==
                                       self.popcount(a[i*32:(i+1)*32], 32))
                with m.Elif(rec.data_len == 1):
                    for i in range(8):
                        comb += Assert(o[i*8:(i+1)*8] ==
                                       self.popcount(a[i*8:(i+1)*8], 8))

            with m.Case(InternalOp.OP_PRTY):
                # parity is the XOR of bit 0 of each byte, taken over the
                # whole 64-bit value (data_len 8) or per 32-bit half (4)
                with m.If(rec.data_len == 8):
                    result = 0
                    for i in range(8):
                        result = result ^ a[i*8]
                    comb += Assert(o == result)
                with m.Elif(rec.data_len == 4):
                    result_low = 0
                    result_high = 0
                    for i in range(4):
                        result_low = result_low ^ a[i*8]
                        result_high = result_high ^ a[i*8 + 32]
                    comb += Assert(o[0:32] == result_low)
                    comb += Assert(o[32:64] == result_high)

            with m.Case(InternalOp.OP_CNTZ):
                # NOTE(review): XO comes from the DUT's own decoder fields,
                # which this harness never drives -- confirm that both the
                # leading- and trailing-zero paths are actually exercised.
                XO = dut.fields.FormX.XO[0:-1]
                with m.If(rec.is_32bit):
                    m.submodules.pe32 = pe32 = PriorityEncoder(32)
                    peo = Signal(range(0, 32+1))
                    # pe32.n flags an all-zero input: that counts as 32 zeros
                    with m.If(pe32.n):
                        comb += peo.eq(32)
                    with m.Else():
                        comb += peo.eq(pe32.o)
                    with m.If(XO[-1]):  # cnttzw
                        comb += pe32.i.eq(a[0:32])
                    with m.Else():  # cntlzw: bit-reverse to count from the top
                        comb += pe32.i.eq(a[0:32][::-1])
                    comb += Assert(o == peo)
                with m.Else():
                    m.submodules.pe64 = pe64 = PriorityEncoder(64)
                    peo64 = Signal(7)
                    # pe64.n flags an all-zero input: that counts as 64 zeros
                    with m.If(pe64.n):
                        comb += peo64.eq(64)
                    with m.Else():
                        comb += peo64.eq(pe64.o)
                    with m.If(XO[-1]):  # cnttzd
                        comb += pe64.i.eq(a[0:64])
                    with m.Else():  # cntlzd: bit-reverse to count from the top
                        comb += pe64.i.eq(a[0:64][::-1])
                    comb += Assert(o == peo64)

        return m
143
144
class LogicalTestCase(FHDLTestCase):
    """Formal checks and RTLIL export for the logical main-stage harness."""

    def test_formal(self):
        """Run a bounded model check, then a cover pass, both at depth 2."""
        harness = Driver()
        for proof_mode in ("bmc", "cover"):
            self.assertFormal(harness, mode=proof_mode, depth=2)

    def test_ilang(self):
        """Convert the harness to RTLIL and write it to main_stage.il."""
        il_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as il_file:
            il_file.write(il_text)
155
156
if __name__ == "__main__":
    # executed directly as a script: run every test case in this module
    unittest.main()