switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / part_cmp / experiments / formal / proof_equal.py
1 # Proof of correctness for partitioned equals module
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3
4 from nmigen import Module, Signal, Elaboratable, Mux
5 from nmigen.asserts import Assert, AnyConst, Assume
6 from nmutil.formaltest import FHDLTestCase
7 from nmigen.cli import rtlil
8
9 from ieee754.part_mul_add.partpoints import PartitionPoints
10 from ieee754.part_cmp.experiments.equal_ortree import PartitionedEq
11 import unittest
12
13
# This defines a module to drive the device under test and assert
# properties about its outputs
class EqualsDriver(Elaboratable):
    """Formal-verification harness for ``PartitionedEq``.

    Instantiates a 32-bit ``PartitionedEq`` with partition points at bits
    4, 8 and 12, each enabled by one bit of an unconstrained ``gates``
    signal. The operands ``a`` and ``b`` are also unconstrained
    (``AnyConst``). For every gate combination it asserts that the DUT's
    output bits match a reference equality computed directly from the
    operand slices.
    """

    def __init__(self):
        # No ports: all stimulus is generated internally with AnyConst
        # in elaborate().
        pass

    def get_intervals(self, signal, points):
        """Split ``signal`` into the slices delimited by ``points``.

        :param signal: the nmigen ``Signal`` to slice.
        :param points: a ``PartitionPoints`` mapping whose keys are bit
            positions. Every key is used as a boundary, regardless of
            whether its gate is enabled.
        :returns: a list of ``len(points) + 1`` slices, ordered from the
            least-significant bits upward. The last slice ends at
            ``signal.width``.
        """
        # Sort the boundaries so the slices are well-formed even if the
        # points were not inserted in ascending order.
        bounds = [0] + sorted(points) + [signal.width]
        return [signal[start:end] for start, end in zip(bounds, bounds[1:])]

    def elaborate(self, platform):
        """Build the proof: unconstrained inputs, the DUT and assertions.

        :param platform: unused; required by the ``Elaboratable`` API.
        :returns: the ``Module`` holding the DUT and the assertions.
        """
        m = Module()
        comb = m.d.comb
        width = 32
        mwidth = 4  # number of intervals / output bits (points + 1)

        # Set up the inputs and outputs of the DUT as AnyConst.
        a = Signal(width)
        b = Signal(width)
        points = PartitionPoints()
        gates = Signal(mwidth-1)
        # Partition points at bits 4, 8 and 12, gated by gates[0..2].
        # The resulting intervals are [0:4], [4:8], [8:12] and [12:32].
        for i in range(mwidth-1):
            points[i*4+4] = gates[i]
        out = Signal(mwidth)

        comb += [a.eq(AnyConst(width)),
                 b.eq(AnyConst(width)),
                 gates.eq(AnyConst(mwidth-1))]

        m.submodules.dut = dut = PartitionedEq(width, points)

        a_intervals = self.get_intervals(a, points)
        b_intervals = self.get_intervals(b, points)

        # eqs[i] is true when interval i of a equals interval i of b.
        # A partition's expected result is the AND over its intervals.
        eqs = [ai == bi for ai, bi in zip(a_intervals, b_intervals)]

        # Each gate combination groups the intervals into partitions. The
        # result of a partition that starts at interval k is checked on
        # out[k].
        with m.Switch(gates):
            with m.Case(0b000):
                # Single partition. The 4-bit out is compared against the
                # 1-bit a == b, so bits 1-3 are also required to be 0.
                comb += Assert(out == (a == b))
            with m.Case(0b001):
                comb += Assert(out[1] == (eqs[1] & eqs[2] & eqs[3]))
                comb += Assert(out[0] == eqs[0])
            with m.Case(0b010):
                comb += Assert(out[2] == (eqs[2] & eqs[3]))
                comb += Assert(out[0] == (eqs[0] & eqs[1]))
            with m.Case(0b011):
                comb += Assert(out[2] == (eqs[2] & eqs[3]))
                comb += Assert(out[0] == eqs[0])
                comb += Assert(out[1] == eqs[1])
            with m.Case(0b100):
                comb += Assert(out[0] == (eqs[0] & eqs[1] & eqs[2]))
                comb += Assert(out[3] == eqs[3])
            with m.Case(0b101):
                comb += Assert(out[1] == (eqs[1] & eqs[2]))
                comb += Assert(out[3] == eqs[3])
                comb += Assert(out[0] == eqs[0])
            with m.Case(0b110):
                comb += Assert(out[0] == (eqs[0] & eqs[1]))
                comb += Assert(out[3] == eqs[3])
                comb += Assert(out[2] == eqs[2])
            with m.Case(0b111):
                # Every point is enabled, so each interval is its own
                # partition. Check all mwidth output bits, including
                # out[3].
                for i in range(mwidth):
                    comb += Assert(out[i] == eqs[i])

        comb += [dut.a.eq(a),
                 dut.b.eq(b),
                 out.eq(dut.output)]
        return m
98
class PartitionedEqTestCase(FHDLTestCase):
    """Formal test case running the ``EqualsDriver`` proof."""

    def test_eq(self):
        """Bounded-model-check the EqualsDriver assertions to depth 4."""
        self.assertFormal(EqualsDriver(), mode="bmc", depth=4)
103
if __name__ == "__main__":
    # Allow the proof to be run directly as a script.
    unittest.main()
106