MUL pipeline: account for overflow flags. WIP
[soc.git] / src / soc / fu / mul / formal / proof_main_stage.py
1 # Proof of correctness for the chained MUL pipeline stages
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3
4 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
5 signed)
6 from nmigen.asserts import Assert, AnyConst, Assume, Cover
7 from nmutil.formaltest import FHDLTestCase
8 from nmutil.stageapi import StageChain
9 from nmigen.cli import rtlil
10
11 from soc.fu.mul.pipe_data import CompMULOpSubset, MulPipeSpec
12 from soc.fu.mul.pre_stage import MulMainStage1
13 from soc.fu.mul.main_stage import MulMainStage2
14 from soc.fu.mul.post_stage import MulMainStage3
15
16 from soc.decoder.power_enums import MicrOp
17 import unittest
18
19
20 # This defines a module to drive the device under test and assert
21 # properties about its outputs
class Driver(Elaboratable):
    """Formal-verification harness around the three chained MUL stages.

    The operands and several fields of the operation record are left
    unconstrained (AnyConst); elaborate() then asserts properties that
    the output of the last stage must satisfy.
    """

    def __init__(self):
        # no ports of its own: every signal is created in elaborate()
        pass

    def elaborate(self, platform):
        """Build the stage chain, drive it, and attach the assertions.

        Returns the Module holding the chain set up by StageChain plus
        all Assume/Assert statements.
        """
        m = Module()
        comb = m.d.comb

        rec = CompMULOpSubset()

        # Unconstrained values for the operation fields fed to the chain.
        comb += [
            rec.insn_type.eq(AnyConst(rec.insn_type.width)),
            rec.fn_unit.eq(AnyConst(rec.fn_unit.width)),
            rec.is_signed.eq(AnyConst(rec.is_signed.width)),
            rec.is_32bit.eq(AnyConst(rec.is_32bit.width)),
            rec.imm_data.imm.eq(AnyConst(64)),
            rec.imm_data.imm_ok.eq(AnyConst(1)),
        ]
        # TODO: the remaining fields of rec are not driven yet.

        # The stages must NOT be added to m.submodules by hand:
        # StageChain.setup() takes care of that.
        pspec = MulPipeSpec(id_wid=2)
        pipe1 = MulMainStage1(pspec)
        pipe2 = MulMainStage2(pspec)
        pipe3 = MulMainStage3(pspec)

        class Dummy:
            """Plain attribute holder for the chain's input and output."""

        dut = Dummy()
        dut.i = pipe1.ispec()
        stages = [pipe1, pipe2, pipe3]

        # dut.i is the input of the first stage; the output is taken
        # from whichever stage ends the chain.
        StageChain(stages).setup(m, dut.i)
        dut.o = stages[-1].o

        # shorthand for the operands, their slices and the result
        a = dut.i.ra
        b = dut.i.rb
        a32, b32 = a[0:32], b[0:32]
        a64, b64 = a[0:64], b[0:64]
        o_data = dut.o.o.data[0:64]

        # magnitudes of the low 32 bits, sign taken from bit 31
        abs32_a = Signal(32)
        abs32_b = Signal(32)
        comb += [
            abs32_a.eq(Mux(a[31], -a32, a32)),
            abs32_b.eq(Mux(b[31], -b32, b32)),
        ]

        # magnitudes of the low 64 bits, sign taken from bit 63
        abs64_a = Signal(64)
        abs64_b = Signal(64)
        comb += [
            abs64_a.eq(Mux(a[63], -a64, a64)),
            abs64_b.eq(Mux(b[63], -b64, b64)),
        ]

        # unconstrained operands
        comb += [
            a.eq(AnyConst(64)),
            b.eq(AnyConst(64)),
        ]

        comb += dut.i.ctx.op.eq(rec)

        # the operation and muxid must reach the output unchanged
        comb += [
            Assert(dut.o.ctx.op == dut.i.ctx.op),
            Assert(dut.o.ctx.muxid == dut.i.ctx.muxid),
        ]

        # XER_SO must propagate through as well; this says nothing
        # about whether its ok signal is set.
        comb += Assert(dut.o.xer_so.data == dut.i.xer_so)

        # arithmetic checks, one case per operation type
        with m.Switch(rec.insn_type):

            with m.Case(MicrOp.OP_MUL_H32):
                # OP_MUL_H32 is a 32-bit op
                comb += Assume(rec.is_32bit)

                expected_product = Signal(64)
                expected_o = Signal.like(expected_product)

                # unsigned hi32 - mulhwu
                with m.If(~rec.is_signed):
                    comb += expected_product.eq(a32 * b32)

                # signed hi32 - mulhw: multiply the magnitudes, then
                # negate when the operand signs differ
                with m.Else():
                    prod = Signal.like(expected_product)
                    comb += prod.eq(abs32_a * abs32_b)
                    comb += expected_product.eq(
                        Mux(a[31] ^ b[31], -prod, prod))

                # both variants expect the high word of the product,
                # replicated into both halves of the 64-bit result
                comb += expected_o.eq(Repl(expected_product[32:64], 2))
                comb += Assert(o_data == expected_o)

            with m.Case(MicrOp.OP_MUL_H64):
                comb += Assume(~rec.is_32bit)

                expected_product = Signal(128)

                # unsigned hi64 - mulhdu
                with m.If(~rec.is_signed):
                    comb += expected_product.eq(a64 * b64)

                # signed hi64 - mulhd: multiply the magnitudes, then
                # negate when the operand signs differ
                with m.Else():
                    prod = Signal.like(expected_product)
                    comb += prod.eq(abs64_a * abs64_b)
                    comb += expected_product.eq(
                        Mux(a[63] ^ b[63], -prod, prod))

                # both variants expect the upper 64 bits of the product
                comb += Assert(o_data == expected_product[64:128])

            # mulli, mullw(o)
            with m.Case(MicrOp.OP_MUL_L64):
                expected_product = Signal(64)
                expected_ov = Signal()

                with m.If(rec.is_32bit):
                    # unsigned, 32-bit operands
                    with m.If(~rec.is_signed):
                        comb += expected_product.eq(a32 * b32)
                        comb += Assert(o_data == expected_product[0:64])

                        # overflow expected when bits 31..63 of the
                        # product are neither all zero nor all one
                        upper_bits = expected_product[31:64]
                        comb += expected_ov.eq(
                            upper_bits.bool() & ~upper_bits.all())
                        comb += Assert(
                            dut.o.xer_ov.data == Repl(expected_ov, 2))

                    # signed, 32-bit operands: nothing asserted yet
                    with m.Else():
                        pass

                # 64-bit operands: nothing asserted yet
                with m.Else():
                    pass

        return m
148
149
class MulTestCase(FHDLTestCase):
    """Runs the formal checks on Driver and dumps it as RTLIL."""

    def test_formal(self):
        """Check Driver in "bmc" mode, then in "cover" mode, at depth 2."""
        harness = Driver()
        for mode in ("bmc", "cover"):
            self.assertFormal(harness, mode=mode, depth=2)

    def test_ilang(self):
        """Convert a Driver to RTLIL and write it to main_stage.il."""
        rtlil_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as out_file:
            out_file.write(rtlil_text)
160
161
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()