1 # Proof of correctness for multiplier
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3 # Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 # Copyright (C) 2020 Samuel A. Falvo II <kc5tja@arrl.net>
6 """Formal Correctness Proof for POWER9 multiplier
Notes on OV/OV32 detection for 32-bit results.  The same logic applies
to 64-bit results, using m63 = exp_prod[63:128] in place of m31.
10 m31 = exp_prod[31:64]
11 comb += expected_ov.eq(m31.bool() & ~m31.all())
13 If the instruction enables the OV and OV32 flags to be
14 set, then we must set them both to 1 if and only if
15 the resulting product *cannot* be contained within a
16 32-bit quantity.
18 This is detected by checking to see if the resulting
19 upper bits are either all 1s or all 0s. If even *one*
20 bit in this set differs from its peers, then we know
21 the signed value cannot be contained in the destination's
22 field width.
24 m31.bool() is true if *any* high bit is set.
25 m31.all() is true if *all* high bits are set.
m31.bool()  m31.all()  Meaning
    0           x      All upper bits are 0, so product
                       is positive.  Thus, it fits.
    1           0      At least *one* high bit is clear.
                       Implying, not all high bits are
                       clones of the output sign bit.
                       Thus, product extends beyond
                       destination register size.
    1           1      All high bits are set *and* they
                       match the sign bit.  The number
                       is properly negative, and fits
                       in the destination register width.
40 Note that OV/OV32 are set to the *inverse* of m31.all(),
41 hence the expression m31.bool() & ~m31.all().
42 """
45 from nmigen import (Module, Signal, Elaboratable, Mux, Cat, Repl,
46 signed)
47 from nmigen.asserts import Assert, AnyConst, Assume, Cover
48 from nmutil.formaltest import FHDLTestCase
49 from nmutil.stageapi import StageChain
50 from nmigen.cli import rtlil
52 from soc.decoder.power_fields import DecodeFields
53 from soc.decoder.power_fieldsn import SignalBitRange
55 from soc.fu.mul.pipe_data import CompMULOpSubset, MulPipeSpec
56 from soc.fu.mul.pre_stage import MulMainStage1
57 from soc.fu.mul.main_stage import MulMainStage2
58 from soc.fu.mul.post_stage import MulMainStage3
60 from soc.decoder.power_enums import MicrOp
61 import unittest
64 # This defines a module to drive the device under test and assert
65 # properties about its outputs
class Driver(Elaboratable):
    """Proof harness for the three chained multiplier stages.

    ``elaborate`` links MulMainStage1, MulMainStage2 and MulMainStage3
    with StageChain, feeds the chain AnyConst operands and operation
    fields, and asserts that the chain's result and XER overflow outputs
    agree with a reference product computed here for OP_MUL_H32,
    OP_MUL_H64 and OP_MUL_L64.
    """

    def __init__(self):
        # inputs and outputs
        # (none are declared: every signal used by the proof is created
        # inside elaborate())
        pass

    def elaborate(self, platform):
        """Build and return the Module containing the DUT chain and proof.

        ``platform`` is not referenced by this method.
        """
        m = Module()
        comb = m.d.comb

        rec = CompMULOpSubset()

        # Setup random inputs for dut.op
        # Each operation field gets an AnyConst of its own width, so the
        # proof is not tied to one particular instruction encoding.
        comb += rec.insn_type.eq(AnyConst(rec.insn_type.width))
        comb += rec.fn_unit.eq(AnyConst(rec.fn_unit.width))
        comb += rec.is_signed.eq(AnyConst(rec.is_signed.width))
        comb += rec.is_32bit.eq(AnyConst(rec.is_32bit.width))
        comb += rec.imm_data.imm.eq(AnyConst(64))
        comb += rec.imm_data.imm_ok.eq(AnyConst(1))

        # set up the mul stages.  do not add them to m.submodules, this
        # is handled by StageChain.setup().
        pspec = MulPipeSpec(id_wid=2)
        pipe1 = MulMainStage1(pspec)
        pipe2 = MulMainStage2(pspec)
        pipe3 = MulMainStage3(pspec)

        class Dummy: pass
        dut = Dummy() # make a class into which dut.i and dut.o can be dropped
        dut.i = pipe1.ispec()
        chain = [pipe1, pipe2, pipe3] # chain of 3 mul stages

        StageChain(chain).setup(m, dut.i) # input linked here, through chain
        dut.o = chain[-1].o # output is the last thing in the chain...

        # convenience variables
        a = dut.i.ra                     # first operand
        b = dut.i.rb                     # second operand
        o = dut.o.o.data                 # result value
        xer_ov_o = dut.o.xer_ov.data     # overflow flag bits
        xer_ov_ok = dut.o.xer_ov.ok      # overflow "write requested" flag

        # For 32- and 64-bit parameters, work out the absolute values of the
        # input parameters for signed multiplies.  Needed for signed
        # multiplication.

        abs32_a = Signal(32)
        abs32_b = Signal(32)
        abs64_a = Signal(64)
        abs64_b = Signal(64)
        a32_s = Signal(1)
        b32_s = Signal(1)
        a64_s = Signal(1)
        b64_s = Signal(1)

        # sign bits: bit 31 for the 32-bit view, bit 63 for the 64-bit view
        comb += a32_s.eq(a[31])
        comb += b32_s.eq(b[31])
        comb += a64_s.eq(a[63])
        comb += b64_s.eq(b[63])

        # magnitude = negated operand when its sign bit is set, else the
        # operand unchanged
        comb += abs32_a.eq(Mux(a32_s, -a[0:32], a[0:32]))
        comb += abs32_b.eq(Mux(b32_s, -b[0:32], b[0:32]))
        comb += abs64_a.eq(Mux(a64_s, -a[0:64], a[0:64]))
        comb += abs64_b.eq(Mux(b64_s, -b[0:64], b[0:64]))

        # For 32- and 64-bit quantities, break out whether signs differ.
        # (the _sne suffix is read as "signs not equal").
        #
        # This is required because of the rules of signed multiplication:
        #
        # a*b = +(abs(a)*abs(b)) for two positive numbers a and b.
        # a*b = -(abs(a)*abs(b)) for any one positive number and negative
        #                        number.
        # a*b = +(abs(a)*abs(b)) for two negative numbers a and b.

        ab32_sne = Signal()
        ab64_sne = Signal()
        comb += ab32_sne.eq(a32_s ^ b32_s)
        comb += ab64_sne.eq(a64_s ^ b64_s)

        # setup random inputs
        comb += [a.eq(AnyConst(64)),
                 b.eq(AnyConst(64)),
                ]

        # hand the operation record built above to the first stage
        comb += dut.i.ctx.op.eq(rec)

        # check overflow and result flags
        result_ok = Signal()
        # enable_overflow is only ever assigned 1, and only inside the
        # OP_MUL_L64 branches below.
        # NOTE(review): presumably it reads as 0 everywhere else via the
        # Signal's default reset value - confirm.
        enable_overflow = Signal()

        # default to 1, disabled if default case is hit
        comb += result_ok.eq(1)

        # Assert that op gets copied from the input to output
        comb += Assert(dut.o.ctx.op == dut.i.ctx.op)
        comb += Assert(dut.o.ctx.muxid == dut.i.ctx.muxid)

        # Assert that XER_SO propagates through as well.
        comb += Assert(dut.o.xer_so == dut.i.xer_so)

        # main assertion of arithmetic operations
        with m.Switch(rec.insn_type):

            ###### HI-32 #####

            with m.Case(MicrOp.OP_MUL_H32):
                comb += Assume(rec.is_32bit) # OP_MUL_H32 is a 32-bit op

                exp_prod = Signal(64)
                expected_o = Signal.like(exp_prod)

                # unsigned hi32 - mulhwu
                with m.If(~rec.is_signed):
                    comb += exp_prod.eq(a[0:32] * b[0:32])
                    # expected result: upper 32 bits of the product,
                    # duplicated into both halves of the 64-bit value
                    comb += expected_o.eq(Repl(exp_prod[32:64], 2))
                    comb += Assert(o[0:64] == expected_o)

                # signed hi32 - mulhw
                with m.Else():
                    # Per rules of signed multiplication, if input signs
                    # differ, we negate the product.  This implies that
                    # the product is calculated from the absolute values
                    # of the inputs.
                    prod = Signal.like(exp_prod) # intermediate product
                    comb += prod.eq(abs32_a * abs32_b)
                    comb += exp_prod.eq(Mux(ab32_sne, -prod, prod))
                    # same duplication of the upper 32 bits as above
                    comb += expected_o.eq(Repl(exp_prod[32:64], 2))
                    comb += Assert(o[0:64] == expected_o)

            ###### HI-64 #####

            with m.Case(MicrOp.OP_MUL_H64):
                # mirror of the HI-32 case: only the non-32-bit form is
                # considered for OP_MUL_H64
                comb += Assume(~rec.is_32bit)

                exp_prod = Signal(128)

                # unsigned hi64 - mulhdu
                with m.If(~rec.is_signed):
                    comb += exp_prod.eq(a[0:64] * b[0:64])
                    # result must equal the upper 64 bits of the product
                    comb += Assert(o[0:64] == exp_prod[64:128])

                # signed hi64 - mulhd
                with m.Else():
                    # Per rules of signed multiplication, if input signs
                    # differ, we negate the product.  This implies that
                    # the product is calculated from the absolute values
                    # of the inputs.
                    prod = Signal.like(exp_prod) # intermediate product
                    comb += prod.eq(abs64_a * abs64_b)
                    comb += exp_prod.eq(Mux(ab64_sne, -prod, prod))
                    comb += Assert(o[0:64] == exp_prod[64:128])

            ###### LO-64 #####
            # mulli, mullw(o)(u), mulld(o)

            with m.Case(MicrOp.OP_MUL_L64):

                with m.If(rec.is_32bit):            # 32-bit mode
                    expected_ov = Signal()
                    prod = Signal(64)
                    exp_prod = Signal.like(prod)

                    # unsigned lo32 - mullwu
                    with m.If(~rec.is_signed):
                        comb += exp_prod.eq(a[0:32] * b[0:32])
                        comb += Assert(o[0:64] == exp_prod[0:64])

                    # signed lo32 - mullw
                    with m.Else():
                        # Per rules of signed multiplication, if input signs
                        # differ, we negate the product.  This implies that
                        # the product is calculated from the absolute values
                        # of the inputs.
                        # NOTE(review): abs32_a/abs32_b are declared 32 bits
                        # wide above, yet are sliced [0:64] here; presumably
                        # nmigen clamps the slice to the signal width -
                        # confirm.
                        comb += prod.eq(abs32_a[0:64] * abs32_b[0:64])
                        comb += exp_prod.eq(Mux(ab32_sne, -prod, prod))
                        comb += Assert(o[0:64] == exp_prod[0:64])

                    # see notes on overflow detection, above
                    # m31 is bits 31..63 of the expected product; overflow
                    # is expected when those bits are neither all 0 nor
                    # all 1.  This check is applied to both the unsigned
                    # and the signed 32-bit branch.
                    m31 = exp_prod[31:64]
                    comb += expected_ov.eq(m31.bool() & ~m31.all())
                    comb += enable_overflow.eq(1)
                    # both bits of xer_ov must carry the same expected value
                    comb += Assert(xer_ov_o == Repl(expected_ov, 2))

                with m.Else(): # is 64-bit; mulld
                    expected_ov = Signal()
                    prod = Signal(128)
                    exp_prod = Signal.like(prod)

                    # From my reading of the v3.0B ISA spec,
                    # only signed instructions exist.
                    #
                    # Per rules of signed multiplication, if input signs
                    # differ, we negate the product.  This implies that
                    # the product is calculated from the absolute values
                    # of the inputs.
                    # The Assume restricts this branch to is_signed=1, so
                    # an unsigned 64-bit low multiply is not checked here.
                    comb += Assume(rec.is_signed)
                    comb += prod.eq(abs64_a[0:64] * abs64_b[0:64])
                    comb += exp_prod.eq(Mux(ab64_sne, -prod, prod))
                    comb += Assert(o[0:64] == exp_prod[0:64])

                    # see notes on overflow detection, above
                    # m63 is bits 63..127 of the expected product: the
                    # 64-bit analogue of m31 in the 32-bit branch.
                    m63 = exp_prod[63:128]
                    comb += expected_ov.eq(m63.bool() & ~m63.all())
                    comb += enable_overflow.eq(1)
                    comb += Assert(xer_ov_o == Repl(expected_ov, 2))

            # not any of the cases above, disable result checking
            with m.Default():
                comb += result_ok.eq(0)

        # check result "write" is correctly requested
        # o.ok must be set exactly for the three handled operations, and
        # xer_ov.ok exactly when a branch above enabled overflow checking.
        comb += Assert(dut.o.o.ok == result_ok)
        comb += Assert(xer_ov_ok == enable_overflow)

        return m
class MulTestCase(FHDLTestCase):
    """Test entry points for the multiplier proof harness."""

    def test_formal(self):
        """Run the Driver through the "bmc" and then the "cover" mode."""
        proof = Driver()
        # Same two calls as before, in the same order, both at depth 2.
        for mode in ("bmc", "cover"):
            self.assertFormal(proof, mode=mode, depth=2)

    def test_ilang(self):
        """Convert a Driver to RTLIL and write it to main_stage.il."""
        rtlil_text = rtlil.convert(Driver(), ports=[])
        with open("main_stage.il", "w") as il_file:
            il_file.write(rtlil_text)
# Run the unittest test cases in this file when it is executed as a script.
if __name__ == '__main__':
    unittest.main()