Allow the formal engine to perform a same-cycle result in the ALU
[soc.git] / src / soc / fu / mul / formal / proof_main_stage.py
1 # Proof of correctness for multiplier
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 # Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 # Copyright (C) 2020 Samuel A. Falvo II <kc5tja@arrl.net>
5
6 """Formal Correctness Proof for POWER9 multiplier
7
8 notes for ov/32. similar logic applies for 64-bit quantities (m63)
9
10 m31 = exp_prod[31:64]
11 comb += expected_ov.eq(m31.bool() & ~m31.all())
12
13 If the instruction enables the OV and OV32 flags to be
14 set, then we must set them both to 1 if and only if
15 the resulting product *cannot* be contained within a
16 32-bit quantity.
17
18 This is detected by checking to see if the resulting
19 upper bits are either all 1s or all 0s. If even *one*
20 bit in this set differs from its peers, then we know
21 the signed value cannot be contained in the destination's
22 field width.
23
24 m31.bool() is true if *any* high bit is set.
25 m31.all() is true if *all* high bits are set.
26
m31.bool()  m31.all()  Meaning
    0           x      All upper bits are 0, so the product is
                       non-negative.  Thus, it fits.
    1           0      At least one high bit is set *and* at least
                       one is clear, so the high bits are not all
                       copies of the result's sign bit.  Thus, the
                       product extends beyond the destination
                       register size.
    1           1      All high bits are set *and* they match the
                       sign bit.  The number is properly negative,
                       and fits in the destination register width.
39
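OV/OV32 are therefore set exactly when the high bits are neither all 0s
nor all 1s: at least one of them is set (m31.bool()) but not all of them
are set (~m31.all()), hence the expression m31.bool() & ~m31.all().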
42 """
43
44
45 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
46 signed)
47 from nmigen.asserts import Assert, AnyConst, Assume, Cover
48 from nmutil.formaltest import FHDLTestCase
49 from nmutil.stageapi import StageChain
50 from nmigen.cli import rtlil
51
52 from soc.decoder.power_fields import DecodeFields
53 from soc.decoder.power_fieldsn import SignalBitRange
54
55 from soc.fu.mul.pipe_data import CompMULOpSubset, MulPipeSpec
56 from soc.fu.mul.pre_stage import MulMainStage1
57 from soc.fu.mul.main_stage import MulMainStage2
58 from soc.fu.mul.post_stage import MulMainStage3
59
60 from soc.decoder.power_enums import MicrOp
61 import unittest
62
63
64 # This defines a module to drive the device under test and assert
65 # properties about its outputs
66 class Driver(Elaboratable):
67 def __init__(self):
68 # inputs and outputs
69 pass
70
71 def elaborate(self, platform):
72 m = Module()
73 comb = m.d.comb
74
75 rec = CompMULOpSubset()
76
77 # Setup random inputs for dut.op
78 comb += rec.insn_type.eq(AnyConst(rec.insn_type.width))
79 comb += rec.fn_unit.eq(AnyConst(rec.fn_unit.width))
80 comb += rec.is_signed.eq(AnyConst(rec.is_signed.width))
81 comb += rec.is_32bit.eq(AnyConst(rec.is_32bit.width))
82 comb += rec.imm_data.imm.eq(AnyConst(64))
83 comb += rec.imm_data.imm_ok.eq(AnyConst(1))
84
85 # set up the mul stages. do not add them to m.submodules, this
86 # is handled by StageChain.setup().
87 pspec = MulPipeSpec(id_wid=2)
88 pipe1 = MulMainStage1(pspec)
89 pipe2 = MulMainStage2(pspec)
90 pipe3 = MulMainStage3(pspec)
91
92 class Dummy: pass
93 dut = Dummy() # make a class into which dut.i and dut.o can be dropped
94 dut.i = pipe1.ispec()
95 chain = [pipe1, pipe2, pipe3] # chain of 3 mul stages
96
97 StageChain(chain).setup(m, dut.i) # input linked here, through chain
98 dut.o = chain[-1].o # output is the last thing in the chain...
99
100 # convenience variables
101 a = dut.i.ra
102 b = dut.i.rb
103 o = dut.o.o.data
104 xer_ov_o = dut.o.xer_ov.data
105 xer_ov_ok = dut.o.xer_ov.ok
106
107 # For 32- and 64-bit parameters, work out the absolute values of the
108 # input parameters for signed multiplies. Needed for signed
109 # multiplication.
110
111 abs32_a = Signal(32)
112 abs32_b = Signal(32)
113 abs64_a = Signal(64)
114 abs64_b = Signal(64)
115 a32_s = Signal(1)
116 b32_s = Signal(1)
117 a64_s = Signal(1)
118 b64_s = Signal(1)
119
120 comb += a32_s.eq(a[31])
121 comb += b32_s.eq(b[31])
122 comb += a64_s.eq(a[63])
123 comb += b64_s.eq(b[63])
124
125 comb += abs32_a.eq(Mux(a32_s, -a[0:32], a[0:32]))
126 comb += abs32_b.eq(Mux(b32_s, -b[0:32], b[0:32]))
127 comb += abs64_a.eq(Mux(a64_s, -a[0:64], a[0:64]))
128 comb += abs64_b.eq(Mux(b64_s, -b[0:64], b[0:64]))
129
130 # For 32- and 64-bit quantities, break out whether signs differ.
131 # (the _sne suffix is read as "signs not equal").
132 #
133 # This is required because of the rules of signed multiplication:
134 #
135 # a*b = +(abs(a)*abs(b)) for two positive numbers a and b.
136 # a*b = -(abs(a)*abs(b)) for any one positive number and negative
137 # number.
138 # a*b = +(abs(a)*abs(b)) for two negative numbers a and b.
139
140 ab32_sne = Signal()
141 ab64_sne = Signal()
142 comb += ab32_sne.eq(a32_s ^ b32_s)
143 comb += ab64_sne.eq(a64_s ^ b64_s)
144
145 # setup random inputs
146 comb += [a.eq(AnyConst(64)),
147 b.eq(AnyConst(64)),
148 ]
149
150 comb += dut.i.ctx.op.eq(rec)
151
152 # check overflow and result flags
153 result_ok = Signal()
154 enable_overflow = Signal()
155
156 # default to 1, disabled if default case is hit
157 comb += result_ok.eq(1)
158
159 # Assert that op gets copied from the input to output
160 comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
161 comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)
162
163 # Assert that XER_SO propagates through as well.
164 comb += Assert(dut.o.xer_so == dut.i.xer_so)
165
166 # main assertion of arithmetic operations
167 with m.Switch(rec.insn_type):
168
169 ###### HI-32 #####
170
171 with m.Case(MicrOp.OP_MUL_H32):
172 comb += Assume(rec.is_32bit) # OP_MUL_H32 is a 32-bit op
173
174 exp_prod = Signal(64)
175 expected_o = Signal.like(exp_prod)
176
177 # unsigned hi32 - mulhwu
178 with m.If(~rec.is_signed):
179 comb += exp_prod.eq(a[0:32] * b[0:32])
180 comb += expected_o.eq(Repl(exp_prod[32:64], 2))
181 comb += Assert(o[0:64] == expected_o)
182
183 # signed hi32 - mulhw
184 with m.Else():
185 # Per rules of signed multiplication, if input signs
186 # differ, we negate the product. This implies that
187 # the product is calculated from the absolute values
188 # of the inputs.
189 prod = Signal.like(exp_prod) # intermediate product
190 comb += prod.eq(abs32_a * abs32_b)
191 comb += exp_prod.eq(Mux(ab32_sne, -prod, prod))
192 comb += expected_o.eq(Repl(exp_prod[32:64], 2))
193 comb += Assert(o[0:64] == expected_o)
194
195 ###### HI-64 #####
196
197 with m.Case(MicrOp.OP_MUL_H64):
198 comb += Assume(~rec.is_32bit)
199
200 exp_prod = Signal(128)
201
202 # unsigned hi64 - mulhdu
203 with m.If(~rec.is_signed):
204 comb += exp_prod.eq(a[0:64] * b[0:64])
205 comb += Assert(o[0:64] == exp_prod[64:128])
206
207 # signed hi64 - mulhd
208 with m.Else():
209 # Per rules of signed multiplication, if input signs
210 # differ, we negate the product. This implies that
211 # the product is calculated from the absolute values
212 # of the inputs.
213 prod = Signal.like(exp_prod) # intermediate product
214 comb += prod.eq(abs64_a * abs64_b)
215 comb += exp_prod.eq(Mux(ab64_sne, -prod, prod))
216 comb += Assert(o[0:64] == exp_prod[64:128])
217
218 ###### LO-64 #####
219 # mulli, mullw(o)(u), mulld(o)
220
221 with m.Case(MicrOp.OP_MUL_L64):
222
223 with m.If(rec.is_32bit): # 32-bit mode
224 expected_ov = Signal()
225 prod = Signal(64)
226 exp_prod = Signal.like(prod)
227
228 # unsigned lo32 - mullwu
229 with m.If(~rec.is_signed):
230 comb += exp_prod.eq(a[0:32] * b[0:32])
231 comb += Assert(o[0:64] == exp_prod[0:64])
232
233 # signed lo32 - mullw
234 with m.Else():
235 # Per rules of signed multiplication, if input signs
236 # differ, we negate the product. This implies that
237 # the product is calculated from the absolute values
238 # of the inputs.
239 comb += prod.eq(abs32_a[0:64] * abs32_b[0:64])
240 comb += exp_prod.eq(Mux(ab32_sne, -prod, prod))
241 comb += Assert(o[0:64] == exp_prod[0:64])
242
243 # see notes on overflow detection, above
244 m31 = exp_prod[31:64]
245 comb += expected_ov.eq(m31.bool() & ~m31.all())
246 comb += enable_overflow.eq(1)
247 comb += Assert(xer_ov_o == Repl(expected_ov, 2))
248
249 with m.Else(): # is 64-bit; mulld
250 expected_ov = Signal()
251 prod = Signal(128)
252 exp_prod = Signal.like(prod)
253
254 # From my reading of the v3.0B ISA spec,
255 # only signed instructions exist.
256 #
257 # Per rules of signed multiplication, if input signs
258 # differ, we negate the product. This implies that
259 # the product is calculated from the absolute values
260 # of the inputs.
261 comb += Assume(rec.is_signed)
262 comb += prod.eq(abs64_a[0:64] * abs64_b[0:64])
263 comb += exp_prod.eq(Mux(ab64_sne, -prod, prod))
264 comb += Assert(o[0:64] == exp_prod[0:64])
265
266 # see notes on overflow detection, above
267 m63 = exp_prod[63:128]
268 comb += expected_ov.eq(m63.bool() & ~m63.all())
269 comb += enable_overflow.eq(1)
270 comb += Assert(xer_ov_o == Repl(expected_ov, 2))
271
272 # not any of the cases above, disable result checking
273 with m.Default():
274 comb += result_ok.eq(0)
275
276 # check result "write" is correctly requested
277 comb += Assert(dut.o.o.ok == result_ok)
278 comb += Assert(xer_ov_ok == enable_overflow)
279
280 return m
281
282
283 class MulTestCase(FHDLTestCase):
284 def test_formal(self):
285 module = Driver()
286 self.assertFormal(module, mode="bmc", depth=2)
287 self.assertFormal(module, mode="cover", depth=2)
288 def test_ilang(self):
289 dut = Driver()
290 vl = rtlil.convert(dut, ports=[])
291 with open("main_stage.il", "w") as f:
292 f.write(vl)
293
294
295 if __name__ == '__main__':
296 unittest.main()