replace PartitionedSignal with SimdSignal
[soc.git] / src / soc / fu / logical / formal / proof_main_stage.py
1 # Proof of correctness for the logical pipeline main stage (LogicalMainStage)
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 """
4 Links:
5 * https://bugs.libre-soc.org/show_bug.cgi?id=331
6 * https://libre-soc.org/openpower/isa/fixedlogical/
7 """
8
9 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
10 signed)
11 from nmigen.asserts import Assert, AnyConst, Assume, Cover
12 from nmutil.formaltest import FHDLTestCase
13 from nmigen.lib.coding import PriorityEncoder
14 from nmigen.cli import rtlil
15
16 from soc.fu.logical.main_stage import LogicalMainStage
17 from soc.fu.alu.pipe_data import ALUPipeSpec
18 from soc.fu.alu.alu_input_record import CompALUOpSubset
19 from soc.decoder.power_enums import MicrOp
20 import unittest
21
22
def simple_popcount(sig, width):
    """Add up bits 0 .. width-1 of ``sig`` and return the total.

    Intentionally the most naive formulation possible: formal
    verification does not need to be fast, but it does have to be
    obviously correct.  With ``width == 0`` the plain integer 0 is
    returned.
    """
    return sum(sig[bit] for bit in range(width))
31
32
33 # This defines a module to drive the device under test and assert
34 # properties about its outputs
class Driver(Elaboratable):
    """Formal-verification harness wrapped around ``LogicalMainStage``.

    The stage is fed unconstrained (``AnyConst``) operands and operation
    fields; for each ``MicrOp`` handled in the switch below, the stage's
    output is asserted equal to a simple reference expression built here.
    """

    def __init__(self):
        # the harness has no ports of its own: everything is created
        # inside elaborate()
        pass

    def elaborate(self, platform):
        """Build the DUT together with its assertions.

        ``platform`` is accepted but never used.  Returns the populated
        ``Module``.
        """
        m = Module()
        comb = m.d.comb

        # every field of the operation record gets an arbitrary constant
        rec = CompALUOpSubset()
        for port in rec.ports():
            comb += port.eq(AnyConst(port.width))

        m.submodules.dut = dut = LogicalMainStage(ALUPipeSpec(id_wid=2))

        # shorthand for the two operands and the result
        a = dut.i.a
        b = dut.i.b
        o = dut.o.o.data

        # arbitrary 64-bit operands
        comb += a.eq(AnyConst(64))
        comb += b.eq(AnyConst(64))

        comb += dut.i.ctx.op.eq(rec)

        # the operation record must come out of the stage unchanged
        for field in rec.ports():
            comb += Assert(getattr(dut.o.ctx.op, field.name) == field)

        # signed views of a (all 64 bits, and the low 32 bits only);
        # none of the assertions below read them
        a_signed = Signal(signed(64))
        a_signed_32 = Signal(signed(32))
        comb += a_signed.eq(a)
        comb += a_signed_32.eq(a[0:32])

        # expected value of the "ok" flag: 1 unless the Default case of
        # the switch below overrides it with 0
        o_ok = Signal()
        comb += o_ok.eq(1)

        # per-operation reference checks
        with m.Switch(rec.insn_type):
            with m.Case(MicrOp.OP_AND):
                comb += Assert(o == a & b)
            with m.Case(MicrOp.OP_OR):
                comb += Assert(o == a | b)
            with m.Case(MicrOp.OP_XOR):
                comb += Assert(o == a ^ b)

            with m.Case(MicrOp.OP_POPCNT):
                # data_len selects the lane size: whole doubleword,
                # two words, or eight bytes
                with m.If(rec.data_len == 8):
                    comb += Assert(o == simple_popcount(a, 64))
                with m.If(rec.data_len == 4):
                    for lo in range(0, 64, 32):
                        word = slice(lo, lo + 32)
                        expected = simple_popcount(a[word], 32)
                        comb += Assert(o[word] == expected)
                with m.If(rec.data_len == 1):
                    for lo in range(0, 64, 8):
                        byte = slice(lo, lo + 8)
                        expected = simple_popcount(a[byte], 8)
                        comb += Assert(o[byte] == expected)

            with m.Case(MicrOp.OP_PRTY):
                # parity is the XOR of the lowest bit of each byte
                with m.If(rec.data_len == 8):
                    parity = 0
                    for bit in range(0, 64, 8):
                        parity = parity ^ a[bit]
                    comb += Assert(o == parity)
                with m.If(rec.data_len == 4):
                    # one parity per 32-bit half
                    parity_lo = 0
                    parity_hi = 0
                    for bit in range(0, 32, 8):
                        parity_lo = parity_lo ^ a[bit]
                        parity_hi = parity_hi ^ a[bit + 32]
                    comb += Assert(o[0:32] == parity_lo)
                    comb += Assert(o[32:64] == parity_hi)

            with m.Case(MicrOp.OP_CNTZ):
                # the last bit of this XO slice picks trailing-zero
                # versus leading-zero counting
                XO = dut.fields.FormX.XO[0:-1]
                with m.If(rec.is_32bit):
                    m.submodules.pe32 = pe32 = PriorityEncoder(32)
                    peo = Signal(range(32 + 1))
                    # when the encoder raises n the count is the full width
                    with m.If(pe32.n):
                        comb += peo.eq(32)
                    with m.Else():
                        comb += peo.eq(pe32.o)
                    with m.If(XO[-1]):  # cnttzw
                        comb += pe32.i.eq(a[0:32])
                        comb += Assert(o == peo)
                    with m.Else():  # cntlzw: same encoder, bits reversed
                        comb += pe32.i.eq(a[0:32][::-1])
                        comb += Assert(o == peo)
                with m.Else():
                    m.submodules.pe64 = pe64 = PriorityEncoder(64)
                    peo64 = Signal(7)
                    with m.If(pe64.n):
                        comb += peo64.eq(64)
                    with m.Else():
                        comb += peo64.eq(pe64.o)
                    with m.If(XO[-1]):  # cnttzd
                        comb += pe64.i.eq(a[0:64])
                        comb += Assert(o == peo64)
                    with m.Else():  # cntlzd: same encoder, bits reversed
                        comb += pe64.i.eq(a[0:64][::-1])
                        comb += Assert(o == peo64)

            with m.Case(MicrOp.OP_CMPB):
                # each output byte is all-ones where the operand bytes
                # match and all-zeros where they differ
                for lo in range(0, 64, 8):
                    byte = slice(lo, lo + 8)
                    with m.If(a[byte] == b[byte]):
                        comb += Assert(o[byte] == 0xff)
                    with m.Else():
                        comb += Assert(o[byte] == 0)

            with m.Case(MicrOp.OP_BPERM):
                # note that this is a copy of the beautifully-documented
                # proof_bpermd.py
                comb += Assert(o[8:] == 0)
                for out_bit in range(8):
                    # byte out_bit of a selects which bit of b is copied
                    selector = a[out_bit*8:(out_bit + 1)*8]
                    with m.If(selector >= 64):
                        # out-of-range selector yields a zero bit
                        comb += Assert(o[out_bit] == 0)
                    with m.Else():
                        for pos in range(64):
                            with m.If(selector == pos):
                                comb += Assert(o[out_bit] == b[63 - pos])

            with m.Default():
                # no operation recognised: ok is expected to be low
                comb += o_ok.eq(0)

        # check that data ok was only enabled when op actioned
        comb += Assert(dut.o.o.ok == o_ok)

        return m
176
177
class LogicalTestCase(FHDLTestCase):
    """Runs the formal proof of the Driver and dumps it as RTLIL."""

    def test_formal(self):
        # the same Driver instance is checked twice, in this order:
        # bounded model check first, then cover mode, both at depth 2
        driver = Driver()
        for mode in ("bmc", "cover"):
            self.assertFormal(driver, mode=mode, depth=2)

    def test_ilang(self):
        # convert a fresh Driver to RTLIL text and save it to disk
        rtlil_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as out:
            out.write(rtlil_text)
188
189
# run every test case in this file when it is executed as a script
if __name__ == "__main__":
    unittest.main()