# Proof of correctness for the partitioned comparison module
# (PartitionedEqGtGe: equal, greater-than, greater-or-equal)
# Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
4 from nmigen
import Module
, Signal
, Elaboratable
, Mux
, Cat
5 from nmigen
.asserts
import Assert
, AnyConst
, Assume
6 from nmutil
.formaltest
import FHDLTestCase
7 from nmigen
.cli
import rtlil
9 from ieee754
.part_mul_add
.partpoints
import PartitionPoints
10 from ieee754
.part_cmp
.eq_gt_ge
import PartitionedEqGtGe
# This defines a module to drive the device under test and assert
# properties about its outputs
class EqualsDriver(Elaboratable):
    """Formal-verification harness for ``PartitionedEqGtGe``.

    ``elaborate`` feeds the DUT unconstrained (``AnyConst``) operands,
    opcode and partition gates, then asserts that every output bit equals a
    reference comparison computed over the partition that bit belongs to.
    """

    def get_intervals(self, signal, points):
        """Slice ``signal`` into consecutive intervals split at ``points``.

        :param signal: value to slice; must expose ``width`` and support
            ``signal[start:end]``.
        :param points: mapping whose keys are the bit positions of the
            partition points.
        :returns: list of ``len(points) + 1`` slices covering
            ``[0, signal.width)``, lowest bits first.
        """
        intervals = []
        start = 0
        # sorted() keeps the slices contiguous even if ``points`` was not
        # populated in ascending order.
        for end in sorted(points.keys()) + [signal.width]:
            intervals.append(signal[start:end])
            start = end
        return intervals

    @staticmethod
    def _partition_groups(gate_bits, lanes):
        """Group lane indices into partitions for one gate setting.

        Bit ``k`` of ``gate_bits`` set means there is a partition boundary
        between lane ``k`` and lane ``k + 1`` (see how ``elaborate`` maps
        ``gates[i]`` onto the partition points).  Returns ``(lo, hi)`` pairs,
        ``hi`` exclusive, in ascending lane order.
        """
        groups = []
        lo = 0
        for k in range(lanes - 1):
            if (gate_bits >> k) & 1:
                groups.append((lo, k + 1))
                lo = k + 1
        groups.append((lo, lanes))
        return groups

    def _assert_partitioned(self, m, op, out, gates, a_intervals,
                            b_intervals):
        """Assert that ``out`` equals ``op`` evaluated per partition.

        For every possible value of ``gates``, the lowest output bit of each
        partition must equal ``op`` applied to the concatenated lanes of
        that partition, and each remaining bit of the partition must equal
        the bit below it.  All ``len(a_intervals)`` output bits are checked
        for every gate setting, including the fully-partitioned one.
        """
        comb = m.d.comb
        lanes = len(a_intervals)
        with m.Switch(gates):
            for gate_bits in range(1 << (lanes - 1)):
                with m.Case(gate_bits):
                    for lo, hi in self._partition_groups(gate_bits, lanes):
                        expected = op(Cat(*a_intervals[lo:hi]),
                                      Cat(*b_intervals[lo:hi]))
                        comb += Assert(out[lo] == expected)
                        # the other bits of a merged partition repeat the
                        # partition's result
                        for j in range(lo + 1, hi):
                            comb += Assert(out[j] == out[j - 1])

    def elaborate(self, platform):
        """Build the harness: free inputs, the DUT and per-opcode asserts.

        Opcode 0b00 is checked against ``==``, 0b01 against ``>`` and 0b10
        against ``>=``; no assertions are placed on opcode 0b11.
        """
        m = Module()
        comb = m.d.comb

        lane_width = 8   # a partition point sits on every 8-bit boundary
        mwidth = 3       # number of lanes == number of output bits
        width = mwidth * lane_width

        # setup the inputs and outputs of the DUT as anyconst
        a = Signal(width)
        b = Signal(width)
        opcode = Signal(2)
        out = Signal(mwidth)

        points = PartitionPoints()
        gates = Signal(mwidth-1)
        for i in range(mwidth-1):
            points[i*lane_width + lane_width] = gates[i]

        comb += [a.eq(AnyConst(width)),
                 b.eq(AnyConst(width)),
                 opcode.eq(AnyConst(opcode.width)),
                 gates.eq(AnyConst(mwidth-1))]

        m.submodules.dut = dut = PartitionedEqGtGe(width, points)

        a_intervals = self.get_intervals(a, points)
        b_intervals = self.get_intervals(b, points)

        checks = ((0b00, lambda x, y: x == y),
                  (0b01, lambda x, y: x > y),
                  (0b10, lambda x, y: x >= y))
        for code, op in checks:
            with m.If(opcode == code):
                self._assert_partitioned(m, op, out, gates,
                                         a_intervals, b_intervals)

        # NOTE(review): the ``b`` input and ``output`` port names are assumed
        # to mirror ``a``/``opcode`` on PartitionedEqGtGe — confirm.
        comb += [dut.a.eq(a),
                 dut.b.eq(b),
                 dut.opcode.eq(opcode),
                 out.eq(dut.output)]
        return m
class PartitionedEqTestCase(FHDLTestCase):
    """Runs the ``EqualsDriver`` harness through a bounded model check."""

    def test_eq(self):
        # Bounded model check (depth 4) of every assertion in the harness.
        self.assertFormal(EqualsDriver(), mode="bmc", depth=4)
137 if __name__
== "__main__":