FSGNJ: expand formal proof to 16 and 64 bits
[ieee754fpu.git] / src / ieee754 / fsgnj / formal / proof_fsgnj_mod.py
1 # Proof of correctness for FSGNJ module
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3
4 from nmigen import Module, Signal, Elaboratable
5 from nmigen.asserts import Assert, Assume
6 from nmigen.cli import rtlil
7
8 from ieee754.fpcommon.fpbase import FPNumDecode, FPNumBaseRecord
9 from ieee754.fsgnj.fsgnj import FSGNJPipeMod
10 from ieee754.pipeline import PipelineSpec
11 import subprocess
12
13
14 # This defines a module to drive the device under test and assert
15 # properties about its outputs
class FSGNJDriver(Elaboratable):
    """Formal-verification harness for ``FSGNJPipeMod``.

    Feeds the device under test from signals exposed as ports and asserts
    that its result obeys the sign-injection rules quoted from the RISC-V
    specification in the comments of ``elaborate``.

    Attributes:
        pspec: pipeline specification; supplies ``width``, ``op_wid`` and
            ``id_wid`` and is handed on to the device under test.
        a, b: operand signals, ``pspec.width`` bits each.
        z: result signal, ``pspec.width`` bits.
        opc: operation selector, ``pspec.op_wid`` bits.
        muxid: mux id forwarded to the device, ``pspec.id_wid`` bits.
    """

    def __init__(self, pspec):
        self.pspec = pspec
        width = pspec.width

        # operands and result
        self.a = Signal(width)
        self.b = Signal(width)
        self.z = Signal(width)

        # operation selector and mux id
        self.opc = Signal(pspec.op_wid)
        self.muxid = Signal(pspec.id_wid)

    def elaborate(self, platform):
        """Build the DUT, the field decoders and the properties to prove."""
        m = Module()
        width = self.pspec.width

        m.submodules.dut = dut = FSGNJPipeMod(self.pspec)

        # One decoder per value, so the properties below can refer to the
        # sign (s), exponent (e) and mantissa (m) fields individually.
        m.submodules.sc_decode_a = dec_a = FPNumDecode(
            None, FPNumBaseRecord(width, False))
        m.submodules.sc_decode_b = dec_b = FPNumDecode(
            None, FPNumBaseRecord(width, False))
        m.submodules.sc_decode_z = dec_z = FPNumDecode(
            None, FPNumBaseRecord(width, False))

        for decoder, raw in ((dec_a, self.a),
                             (dec_b, self.b),
                             (dec_z, self.z)):
            m.d.comb += decoder.v.eq(raw)

        # Hook the harness signals up to the device under test.  These
        # could theoretically be $anyconst/$anysync instead, but it is
        # unclear whether nmigen supports that.
        m.d.comb += [
            dut.i.a.eq(self.a),
            dut.i.b.eq(self.b),
            dut.i.ctx.op.eq(self.opc),
            dut.i.muxid.eq(self.muxid),
            self.z.eq(dut.o.z),
        ]

        # The RISC-V spec does not define what FSGNJ does for a funct3
        # field of 0b011 through 0b111, so such values are ignored; the
        # selector value excluded here is 0b11.
        m.d.comb += Assume(self.opc != 0b11)

        # The RISC-V spec (page 70) says FSGNJ "produces a result that
        # takes all bits except the sign bit from [operand 1]".
        m.d.comb += [
            Assert(dec_z.e == dec_a.e),
            Assert(dec_z.m == dec_a.m),
        ]

        # Expected sign of the result for each opcode, per the RISC-V
        # spec (page 70).
        expected_sign = (
            # FSGNJ: "the result's sign bit is [operand 2's] sign bit"
            (0b00, dec_b.s),
            # FSGNJN: "the result's sign bit is the opposite of
            # [operand 2's] sign bit"
            (0b01, ~dec_b.s),
            # FSGNJX: "the result's sign bit is the XOR of the sign bits
            # of [operand 1] and [operand 2]"
            (0b10, dec_a.s ^ dec_b.s),
        )

        # NOTE: m.d.comb is looked up afresh inside each Case on purpose;
        # do not hoist it into a local variable outside the Switch.
        with m.Switch(self.opc):
            for opcode, sign in expected_sign:
                with m.Case(opcode):
                    m.d.comb += Assert(dec_z.s == sign)

        return m

    def ports(self):
        """Return the harness signals to expose as top-level ports."""
        return [
            self.a,
            self.b,
            self.z,
            self.opc,
            self.muxid,
        ]
84
85
def run_test(bits=32):
    """Generate the proof RTLIL for one width and run SymbiYosys on it.

    Writes the converted design to ``proof.il`` in the current directory,
    then runs ``sby -f proof.sby`` and prints whether the proof passed.
    On failure the captured stdout and stderr of ``sby`` are printed too.

    Args:
        bits: operand width passed as the first ``PipelineSpec`` argument.
            The two remaining arguments are fixed at 2 (presumably the
            mux-id and opcode widths -- confirm against ``PipelineSpec``).
    """
    m = FSGNJDriver(PipelineSpec(bits, 2, 2))

    il = rtlil.convert(m, ports=m.ports())
    with open("proof.il", "w") as f:
        f.write(il)

    # Collect the output while the child runs.  Calling wait() before
    # reading from stdout/stderr pipes can deadlock once the child has
    # filled an OS pipe buffer, so let subprocess drain both pipes.
    result = subprocess.run(['sby', '-f', 'proof.sby'],
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)
    if result.returncode == 0:
        print("Proof successful!")
    else:
        print("Proof failed")
        # errors="replace": a stray non-UTF-8 byte in the tool output must
        # not hide the failure report behind a UnicodeDecodeError.
        print(result.stdout.decode('utf-8', errors='replace'))
        print(result.stderr.decode('utf-8', errors='replace'))
103
104
if __name__ == "__main__":
    # Prove the 32-bit format first, then the 16-bit and 64-bit ones.
    for width in (32, 16, 64):
        run_test(bits=width)