bfad73a80385bffeefc70217a54f790e981ccd95
1 # Proof of correctness for partitioned equals module
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
4 from nmigen
import Module
, Signal
, Elaboratable
, Mux
, Cat
5 from nmigen
.asserts
import Assert
, AnyConst
, Assume
6 from nmigen
.test
.utils
import FHDLTestCase
7 from nmigen
.cli
import rtlil
9 from ieee754
.part_mul_add
.partpoints
import PartitionPoints
10 from ieee754
.part_cmp
.eq_gt_ge
import PartitionedEqGtGe
14 # This defines a module to drive the device under test and assert
15 # properties about its outputs
16 class EqualsDriver(Elaboratable
):
21 def get_intervals(self
, signal
, points
):
24 keys
= list(points
.keys()) + [signal
.width
]
27 interval
.append(signal
[start
:end
])
31 def elaborate(self
, platform
):
37 # setup the inputs and outputs of the DUT as anyconst
40 points
= PartitionPoints()
41 gates
= Signal(mwidth
-1)
43 for i
in range(mwidth
-1):
44 points
[i
*8+8] = gates
[i
]
47 comb
+= [a
.eq(AnyConst(width
)),
48 b
.eq(AnyConst(width
)),
49 opcode
.eq(AnyConst(opcode
.width
)),
50 gates
.eq(AnyConst(mwidth
-1))]
52 m
.submodules
.dut
= dut
= PartitionedEqGtGe(width
, points
)
54 a_intervals
= self
.get_intervals(a
, points
)
55 b_intervals
= self
.get_intervals(b
, points
)
57 with m
.If(opcode
== 0b00):
60 comb
+= Assert(out
[-1] == (a
== b
))
62 comb
+= Assert(out
[2] == ((a_intervals
[1] == \
66 comb
+= Assert(out
[0] == (a_intervals
[0] == b_intervals
[0]))
68 comb
+= Assert(out
[1] == ((a_intervals
[0] == \
72 comb
+= Assert(out
[2] == (a_intervals
[2] == b_intervals
[2]))
74 for i
in range(mwidth
-1):
75 comb
+= Assert(out
[i
] == \
76 (a_intervals
[i
] == b_intervals
[i
]))
77 with m
.If(opcode
== 0b01):
80 comb
+= Assert(out
[-1] == (a
> b
))
82 comb
+= Assert(out
[0] == (a_intervals
[0] > b_intervals
[0]))
84 comb
+= Assert(out
[1] == 0)
85 comb
+= Assert(out
[2] == (Cat(*a_intervals
[1:3]) > \
86 Cat(*b_intervals
[1:3])))
88 comb
+= Assert(out
[0] == 0)
89 comb
+= Assert(out
[1] == (Cat(*a_intervals
[0:2]) > \
90 Cat(*b_intervals
[0:2])))
91 comb
+= Assert(out
[2] == (a_intervals
[2] > b_intervals
[2]))
93 for i
in range(mwidth
-1):
94 comb
+= Assert(out
[i
] == (a_intervals
[i
] > \
96 with m
.If(opcode
== 0b10):
99 comb
+= Assert(out
[-1] == (a
>= b
))
101 comb
+= Assert(out
[0] == (a_intervals
[0] >= b_intervals
[0]))
103 comb
+= Assert(out
[1] == 0)
104 comb
+= Assert(out
[2] == (Cat(*a_intervals
[1:3]) >= \
105 Cat(*b_intervals
[1:3])))
107 comb
+= Assert(out
[0] == 0)
108 comb
+= Assert(out
[1] == (Cat(*a_intervals
[0:2]) >= \
109 Cat(*b_intervals
[0:2])))
110 comb
+= Assert(out
[2] == (a_intervals
[2] >= b_intervals
[2]))
112 for i
in range(mwidth
-1):
113 comb
+= Assert(out
[i
] == \
114 (a_intervals
[i
] >= b_intervals
[i
]))
118 comb
+= [dut
.a
.eq(a
),
120 dut
.opcode
.eq(opcode
),
124 class PartitionedEqTestCase(FHDLTestCase
):
126 module
= EqualsDriver()
127 self
.assertFormal(module
, mode
="bmc", depth
=4)
129 if __name__
== "__main__":