fix some borked imports
[ieee754fpu.git] / src / ieee754 / part_cmp / formal / proof_eq_gt_ge.py
# Proof of correctness for the partitioned equal / greater-than /
# greater-or-equal comparison module (PartitionedEqGtGe)
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3
4 from nmigen import Module, Signal, Elaboratable, Mux, Cat
5 from nmigen.asserts import Assert, AnyConst, Assume
6 from nmutil.formaltest import FHDLTestCase
7 from nmigen.cli import rtlil
8
9 from ieee754.part_mul_add.partpoints import PartitionPoints
10 from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
11 import unittest
12
13
14 # This defines a module to drive the device under test and assert
15 # properties about its outputs
class EqualsDriver(Elaboratable):
    """Formal-verification harness for PartitionedEqGtGe.

    Drives the DUT with unconstrained (AnyConst) operands, opcode and
    partition gates, then asserts that every output bit equals the
    reference comparison (==, > or >=, selected by opcode) evaluated over
    the combined partition that bit belongs to.
    """

    def __init__(self, width=24, mwidth=3):
        # width: total operand width in bits.
        # mwidth: number of equal-sized partitions; the DUT output has one
        # bit per partition and there are mwidth-1 partition gates.
        self.width = width
        self.mwidth = mwidth

    def get_intervals(self, signal, points):
        """Slice *signal* into the sub-ranges delimited by *points*.

        Returns a list of len(points) + 1 slices, least-significant first.
        The partition keys are sorted so the result does not depend on the
        insertion order of *points*.
        """
        bounds = sorted(points.keys()) + [len(signal)]
        start = 0
        intervals = []
        for end in bounds:
            intervals.append(signal[start:end])
            start = end
        return intervals

    @staticmethod
    def _partition_groups(gate_value, n_parts):
        """Group partition indices into runs that are joined together.

        Bit i of *gate_value* set means the boundary between partition i
        and partition i+1 is active (the two are compared separately);
        a cleared bit merges them into one wider lane.
        """
        groups = [[0]]
        for i in range(1, n_parts):
            if (gate_value >> (i - 1)) & 1:
                groups.append([i])
            else:
                groups[-1].append(i)
        return groups

    def _assert_partitioned(self, m, compare, gates, out,
                            a_intervals, b_intervals):
        """Emit asserts for one comparison across every gate setting.

        For each possible value of *gates*, every output bit must equal
        ``compare(a_lane, b_lane)`` where the lanes are the concatenation
        of all partitions in the group containing that bit.
        """
        n_parts = len(a_intervals)
        with m.Switch(gates):
            for gate_value in range(1 << (n_parts - 1)):
                with m.Case(gate_value):
                    for group in self._partition_groups(gate_value, n_parts):
                        # Cat is LSB-first, so this rebuilds the wide lane
                        # in the same bit order as the original operand.
                        a_lane = Cat(*(a_intervals[i] for i in group))
                        b_lane = Cat(*(b_intervals[i] for i in group))
                        expected = compare(a_lane, b_lane)
                        # Check *every* bit in the group, including the last
                        # partition (the original loop skipped out[mwidth-1]
                        # when all gates were set).
                        for i in group:
                            m.d.comb += Assert(out[i] == expected)

    def elaborate(self, platform):
        import operator

        m = Module()
        comb = m.d.comb
        width = self.width
        mwidth = self.mwidth
        part_width = width // mwidth

        # setup the inputs and outputs of the DUT as anyconst
        a = Signal(width)
        b = Signal(width)
        points = PartitionPoints()
        gates = Signal(mwidth-1)
        opcode = Signal(2)
        for i in range(mwidth-1):
            # gates[i] controls the boundary between partitions i and i+1
            points[(i + 1) * part_width] = gates[i]
        out = Signal(mwidth)

        comb += [a.eq(AnyConst(width)),
                 b.eq(AnyConst(width)),
                 opcode.eq(AnyConst(len(opcode))),
                 gates.eq(AnyConst(len(gates)))]

        m.submodules.dut = dut = PartitionedEqGtGe(width, points)

        a_intervals = self.get_intervals(a, points)
        b_intervals = self.get_intervals(b, points)

        # opcode selects the comparison; no assertions are made for 0b11.
        for code, compare in ((0b00, operator.eq),
                              (0b01, operator.gt),
                              (0b10, operator.ge)):
            with m.If(opcode == code):
                self._assert_partitioned(m, compare, gates, out,
                                         a_intervals, b_intervals)

        comb += [dut.a.eq(a),
                 dut.b.eq(b),
                 dut.opcode.eq(opcode),
                 out.eq(dut.output)]
        return m
131
class PartitionedEqTestCase(FHDLTestCase):
    """Bounded model check of the EqualsDriver comparison harness."""

    def test_eq(self):
        # Run BMC to depth 4 over the harness built by EqualsDriver.
        driver = EqualsDriver()
        self.assertFormal(driver, mode="bmc", depth=4)
136
if __name__ == "__main__":
    # Allow running this proof directly as a script.
    unittest.main()
139