f3baa1b67257fe34b770822581829d87cda6a4bb
[soc.git] / src / soc / fu / div / formal / proof_main_stage.py
# Proof of correctness for the logical main stage (LogicalMainStage)
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 """
4 Links:
5 * https://bugs.libre-soc.org/show_bug.cgi?id=331
6 * https://libre-soc.org/openpower/isa/fixedlogical/
7 """
8
9 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
10 signed)
11 from nmigen.asserts import Assert, AnyConst, Assume, Cover
12 from nmutil.formaltest import FHDLTestCase
13 from nmigen.lib.coding import PriorityEncoder
14 from nmigen.cli import rtlil
15
16 from soc.fu.logical.main_stage import LogicalMainStage
17 from soc.fu.alu.pipe_data import ALUPipeSpec
18 from soc.fu.alu.alu_input_record import CompALUOpSubset
19 from soc.decoder.power_enums import InternalOp
20 import unittest
21
22
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal-verification harness for ``LogicalMainStage``.

    Every input of the device under test (the operation record, ``a``,
    ``b``, ``xer_ca`` and ``xer_so``) is driven with an unconstrained
    ``AnyConst``.  The harness asserts that the operation record is passed
    through unchanged and, per ``insn_type``, that ``dut.o.o`` matches a
    reference computation.
    """

    def __init__(self):
        # nothing to configure: the DUT and all signals are built in
        # elaborate()
        pass

    def popcount(self, sig, width):
        """Return an expression summing bits ``0 .. width-1`` of ``sig``."""
        return sum(sig[i] for i in range(width))

    def _count_zeros(self, m, name, value, trailing):
        """Return a signal counting the zero bits at one end of ``value``.

        A ``PriorityEncoder`` of width ``len(value)`` is added to ``m`` as
        submodule ``name``; it reports the index of the least-significant
        set bit of its input, and asserts ``n`` when the input is zero.
        When ``trailing`` is true the encoder sees ``value`` unchanged
        (trailing-zero count); otherwise it sees ``value`` bit-reversed
        (leading-zero count).  An all-zero ``value`` yields ``len(value)``.
        """
        comb = m.d.comb
        width = len(value)
        pe = PriorityEncoder(width)
        m.submodules[name] = pe
        count = Signal(range(width + 1), name=f"{name}_count")
        with m.If(pe.n):
            comb += count.eq(width)
        with m.Else():
            comb += count.eq(pe.o)
        with m.If(trailing):
            comb += pe.i.eq(value)
        with m.Else():
            comb += pe.i.eq(value[::-1])
        return count

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # Drive every field of the operation record with an arbitrary
        # constant, totalling the field widths for the pipeline spec.
        rec = CompALUOpSubset()
        recwidth = 0
        for p in rec.ports():
            recwidth += p.width
            comb += p.eq(AnyConst(p.width))

        pspec = ALUPipeSpec(id_wid=2, op_wid=recwidth)
        m.submodules.dut = dut = LogicalMainStage(pspec)

        # convenience variables
        a = dut.i.a
        b = dut.i.b
        xer_ca = dut.i.xer_ca    # bit 0: carry_in, bit 1: carry_in32
        so_in = dut.i.xer_so
        o = dut.o.o

        # setup random inputs.  All bits of xer_ca are driven so that
        # neither carry bit is left stuck at its reset value.
        comb += [a.eq(AnyConst(64)),
                 b.eq(AnyConst(64)),
                 xer_ca.eq(AnyConst(len(xer_ca))),
                 so_in.eq(AnyConst(1))]

        comb += dut.i.ctx.op.eq(rec)

        # Assert that op gets copied from the input to output
        for rec_sig in rec.ports():
            dut_sig = getattr(dut.o.ctx.op, rec_sig.name)
            comb += Assert(dut_sig == rec_sig)

        # main assertions of the logical operations
        with m.Switch(rec.insn_type):
            with m.Case(InternalOp.OP_AND):
                comb += Assert(o == a & b)
            with m.Case(InternalOp.OP_OR):
                comb += Assert(o == a | b)
            with m.Case(InternalOp.OP_XOR):
                comb += Assert(o == a ^ b)

            with m.Case(InternalOp.OP_POPCNT):
                # data_len selects the counting lane: 8 -> the whole
                # 64 bits, 4 -> each 32-bit half, 1 -> each byte
                with m.If(rec.data_len == 8):
                    comb += Assert(o == self.popcount(a, 64))
                with m.If(rec.data_len == 4):
                    for i in range(2):
                        comb += Assert(o[i*32:(i+1)*32] ==
                                       self.popcount(a[i*32:(i+1)*32], 32))
                with m.If(rec.data_len == 1):
                    for i in range(8):
                        comb += Assert(o[i*8:(i+1)*8] ==
                                       self.popcount(a[i*8:(i+1)*8], 8))

            with m.Case(InternalOp.OP_PRTY):
                # parity of bit 0 (LSB0 indexing) of each byte
                with m.If(rec.data_len == 8):
                    result = 0
                    for i in range(8):
                        result = result ^ a[i*8]
                    comb += Assert(o == result)
                with m.If(rec.data_len == 4):
                    # separate parity per 32-bit half
                    result_low = 0
                    result_high = 0
                    for i in range(4):
                        result_low = result_low ^ a[i*8]
                        result_high = result_high ^ a[i*8 + 32]
                    comb += Assert(o[0:32] == result_low)
                    comb += Assert(o[32:64] == result_high)

            with m.Case(InternalOp.OP_CNTZ):
                XO = dut.fields.FormX.XO[0:-1]
                # XO[-1] set: cnttzw/cnttzd; clear: cntlzw/cntlzd
                with m.If(rec.is_32bit):
                    count = self._count_zeros(m, "pe32", a[0:32], XO[-1])
                    comb += Assert(o == count)
                with m.Else():
                    count = self._count_zeros(m, "pe64", a[0:64], XO[-1])
                    comb += Assert(o == count)

        return m
146
147
class LogicalTestCase(FHDLTestCase):
    """Formal checks and RTLIL export for the ``Driver`` harness."""

    def test_formal(self):
        # one harness instance, checked first by BMC and then in cover mode
        harness = Driver()
        for formal_mode in ("bmc", "cover"):
            self.assertFormal(harness, mode=formal_mode, depth=2)

    def test_ilang(self):
        # dump the elaborated harness as RTLIL text for inspection
        il_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as il_file:
            il_file.write(il_text)
158
159
# Run this module's unittest test cases when executed as a script.
if __name__ == '__main__':
    unittest.main()