speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / formal / test_plru.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
3
4 import unittest
5 from nmigen.hdl.ast import (AnySeq, Assert, Signal, Value, Array, Value)
6 from nmigen.hdl.dsl import Module
7 from nmigen.sim import Delay, Tick
8 from nmutil.formaltest import FHDLTestCase
9 from nmutil.plru2 import PLRU # , PLRUs
10 from nmutil.sim_util import write_il, do_sim
11 from nmutil.plain_data import plain_data
12
13
@plain_data()
class PrettyPrintState:
    """Cursor state for indentation-aware pretty printing.

    Attributes (kept in ``__slots__``):
        indent: number of spaces written before the first character of
            every line that receives content.
        file: stream handed to ``print``; ``None`` means ``sys.stdout``.
        at_line_start: True when the next non-newline character starts a
            new line and must therefore be preceded by the indentation.
    """
    __slots__ = "indent", "file", "at_line_start"

    def __init__(self, indent=0, file=None, at_line_start=True):
        self.indent = indent
        self.file = file
        self.at_line_start = at_line_start

    def write(self, text):
        # type: (str) -> None
        """Write ``text``, indenting each line that receives content.

        Indentation is emitted lazily, just before the first non-newline
        character of a line, so blank lines stay empty and a change to
        ``indent`` takes effect on the next line started.

        The output for the whole call is assembled first and written with a
        single ``print`` instead of one ``print`` per character.
        """
        pieces = []
        for ch in text:
            if ch == "\n":
                self.at_line_start = True
            elif self.at_line_start:
                self.at_line_start = False
                pieces.append(" " * self.indent)
            pieces.append(ch)
        if pieces:  # nothing to emit for an empty ``text``
            print("".join(pieces), file=self.file, end='')
32
33
@plain_data()
class PLRUNode:
    """One node of the reference tree-PLRU model used to check ``PLRU``.

    Each node owns a 1-bit ``state`` register (a ``Signal`` named
    ``state_<id>``) and up to two children. When ``state`` is 1 the LRU
    lookup descends into ``right_child``, otherwise into ``left_child``.
    """
    # NOTE: the ``id`` field/parameter shadows the builtin; the name is kept
    # because callers and the plain_data field list depend on it.
    __slots__ = "id", "state", "left_child", "right_child"

    def __init__(self, id, left_child=None, right_child=None):
        # type: (int, PLRUNode | None, PLRUNode | None) -> None
        self.id = id
        self.state = Signal(name=f"state_{id}")
        self.left_child, self.right_child = left_child, right_child

    @property
    def depth(self):
        """Number of tree levels below this node; 0 for a childless node."""
        child_depths = [
            child.depth + 1
            for child in (self.left_child, self.right_child)
            if child is not None
        ]
        return max(child_depths, default=0)

    def __pretty_print(self, state):
        # type: (PrettyPrintState) -> None
        """Recursively write this subtree through ``state``, one field per
        line, with the children nested one indent level deeper."""
        state.write("PLRUNode(")
        state.indent += 1
        state.write(f"id={self.id!r},\n")
        state.write(f"state={self.state!r},\n")
        for label, child in (("left_child=", self.left_child),
                             (",\nright_child=", self.right_child)):
            state.write(label)
            if child is None:
                state.write("None")
            else:
                child.__pretty_print(state)
        state.indent -= 1
        state.write("\n)")

    def pretty_print(self, file=None):
        """Write an indented dump of this subtree to ``file`` (stdout when
        ``None``), followed by a newline."""
        printer = PrettyPrintState(file=file)
        self.__pretty_print(printer)
        print(file=file)

    def set_states_from_index(self, m, index, ids):
        # type: (Module, Value, list[Signal]) -> None
        """Emit logic recording an access to way ``index`` in this subtree.

        This node's ``state`` is loaded (``sync`` domain) with the inverse
        of the top bit of ``index``. The walk then continues into the child
        chosen by that bit (right for 1, left for 0) with the remaining
        bits. ``ids[k]`` is driven (``comb``) with the id of the node
        visited at level ``k``.
        """
        go_right = index[-1]
        m.d.sync += self.state.eq(~go_right)
        m.d.comb += ids[0].eq(self.id)
        with m.If(go_right):
            if self.right_child is not None:
                self.right_child.set_states_from_index(
                    m, index[:-1], ids[1:])
        with m.Else():
            if self.left_child is not None:
                self.left_child.set_states_from_index(
                    m, index[:-1], ids[1:])

    def get_lru(self, m, ids):
        # type: (Module, list[Signal]) -> Signal
        """Emit logic computing the least-recently-used way of this subtree.

        Returns a ``Signal`` of ``1 + depth`` bits named ``lru_<id>``. Its
        top bit is this node's ``state``, and its lower bits come from the
        child selected by ``state`` (right for 1, left for 0). ``ids[k]``
        is driven with the id of the node visited at level ``k``.
        """
        lru = Signal(1 + self.depth, name=f"lru_{self.id}", reset=0)
        m.d.comb += lru[-1].eq(self.state)
        m.d.comb += ids[0].eq(self.id)
        with m.If(self.state):
            if self.right_child is not None:
                from_child = self.right_child.get_lru(m, ids[1:])
                m.d.comb += lru[:-1].eq(from_child)
        with m.Else():
            if self.left_child is not None:
                from_child = self.left_child.get_lru(m, ids[1:])
                m.d.comb += lru[:-1].eq(from_child)
        return lru
102
103
class TestPLRU(FHDLTestCase):
    """Compare ``nmutil.plru2.PLRU`` against the ``PLRUNode`` reference tree.

    ``test_bits_N`` runs a formal proof (``assertFormal`` in "prove" mode,
    depth 2) for a PLRU with ``1 << N`` ways. ``test_bits_3_sim`` instead
    replays a fixed access sequence in simulation and checks the same
    properties cycle by cycle.
    """

    def tst(self, log2_num_ways, test_seq=None):
        # type: (int, list[int | None] | None) -> None
        """Build a ``PLRU`` plus the reference model and check they agree.

        Checked properties:
        - every tree bit equals the matching reference node's state;
        - ``lru_o`` equals the reference LRU;
        - the debug node-id outputs (``_upd_lru_nodes`` /
          ``_get_lru_nodes``) match the ids of the nodes that the reference
          model visits.

        Args:
            log2_num_ways: log2 of the number of ways. The reference tree
                has ``(1 << log2_num_ways) - 1`` nodes.
            test_seq: ``None`` for a formal proof with unconstrained
                (``AnySeq``) inputs. Otherwise, the accesses to simulate,
                one per clock cycle: an ``int`` is driven on ``acc_i`` with
                ``acc_en_i`` set, and ``None`` is a cycle with ``acc_en_i``
                cleared.
        """

        # A check recorded for simulation mode: ``test`` is verified only
        # on cycles where ``en`` is high.
        @plain_data()
        class MyAssert:
            __slots__ = "test", "en"

            def __init__(self, test, en):
                # type: (Value, Signal) -> None
                self.test = test
                self.en = en

        asserts = []  # type: list[MyAssert]

        def assert_(test):
            """Return statements asserting ``test`` wherever they are added.

            Formal mode: a real ``Assert``.

            Simulation mode: an enable signal that goes high whenever the
            returned statement is active. The (test, enable) pair is queued
            in ``asserts`` and checked by the simulation process.

            ``src_loc_at=1`` attributes the Assert/signal to the caller's
            line. In simulation mode that location is used as the subtest
            label.
            """
            if test_seq is None:
                return [Assert(test, src_loc_at=1)]
            assert_en = Signal(name="assert_en", src_loc_at=1, reset=False)
            asserts.append(MyAssert(test=test, en=assert_en))
            return [assert_en.eq(True)]

        dut = PLRU(log2_num_ways, debug=True)  # check debug works
        write_il(self, dut, ports=dut.ports())
        # debug clutters up vcd, so disable it for formal proofs
        dut = PLRU(log2_num_ways, debug=test_seq is not None)
        num_ways = 1 << log2_num_ways
        # Interface shape checks.
        self.assertEqual(dut.log2_num_ways, log2_num_ways)
        self.assertEqual(dut.num_ways, num_ways)
        self.assertIsInstance(dut.acc_i, Signal)
        self.assertIsInstance(dut.acc_en_i, Signal)
        self.assertIsInstance(dut.lru_o, Signal)
        self.assertEqual(len(dut.acc_i), log2_num_ways)
        self.assertEqual(len(dut.acc_en_i), 1)
        self.assertEqual(len(dut.lru_o), log2_num_ways)
        write_il(self, dut, ports=dut.ports())
        m = Module()
        nodes = [PLRUNode(i) for i in range(num_ways - 1)]
        self.assertIsInstance(dut._tree, Array)
        self.assertEqual(len(dut._tree), len(nodes))
        # Link the nodes into a heap-ordered tree: node p's left child is
        # 2 * p + 1 and its right child is 2 * p + 2. Each reference state
        # must equal the DUT's corresponding 1-bit tree entry.
        for i in range(len(nodes)):
            if i != 0:
                parent = (i + 1) // 2 - 1
                if i % 2:
                    nodes[parent].left_child = nodes[i]
                else:
                    nodes[parent].right_child = nodes[i]
            self.assertIsInstance(dut._tree[i], Signal)
            self.assertEqual(len(dut._tree[i]), 1)
            m.d.comb += assert_(nodes[i].state == dut._tree[i])

        # Formal mode: the solver may pick arbitrary inputs every cycle.
        # Simulation mode drives them from ``process`` below instead.
        if test_seq is None:
            m.d.comb += [
                dut.acc_i.eq(AnySeq(log2_num_ways)),
                dut.acc_en_i.eq(AnySeq(1)),
            ]

        # Reference tree update on an enabled access. The DUT's debug ids of
        # the updated nodes must match the reference walk.
        l2nwr = range(log2_num_ways)
        upd_ids = [Signal(log2_num_ways, name=f"upd_id_{i}") for i in l2nwr]
        with m.If(dut.acc_en_i):
            nodes[0].set_states_from_index(m, dut.acc_i, upd_ids)

            self.assertEqual(len(dut._upd_lru_nodes), len(upd_ids))
            for l, r in zip(dut._upd_lru_nodes, upd_ids):
                m.d.comb += assert_(l == r)

        # Reference LRU lookup. It must match ``lru_o`` and the DUT's debug
        # ids of the nodes visited.
        get_ids = [Signal(log2_num_ways, name=f"get_id_{i}") for i in l2nwr]
        lru = Signal(log2_num_ways)
        m.d.comb += lru.eq(nodes[0].get_lru(m, get_ids))
        m.d.comb += assert_(dut.lru_o == lru)
        self.assertEqual(len(dut._get_lru_nodes), len(get_ids))
        for l, r in zip(dut._get_lru_nodes, get_ids):
            m.d.comb += assert_(l == r)

        # Dump the reference tree to stdout as a debugging aid.
        nodes[0].pretty_print()

        m.submodules.dut = dut
        if test_seq is None:
            self.assertFormal(m, mode="prove", depth=2)
        else:
            # Signals handed to do_sim, presumably recorded as waveform
            # traces; confirm against nmutil.sim_util.
            traces = [dut.acc_i, dut.acc_en_i, *dut._tree]
            for node in nodes:
                traces.append(node.state)
            traces += [
                dut.lru_o, lru, *dut._get_lru_nodes, *get_ids,
                *dut._upd_lru_nodes, *upd_ids,
            ]

            def subtest(acc_i, acc_en_i):
                """Apply one access, advance one clock edge, then check every
                assertion whose enable is high in the resulting cycle."""
                yield dut.acc_i.eq(acc_i)
                yield dut.acc_en_i.eq(acc_en_i)
                yield Tick()
                # Sample 0.7 us after the edge: combinational outputs have
                # updated, and the next edge (clock period 1 us, set below)
                # has not yet arrived.
                yield Delay(0.7e-6)
                for a in asserts:
                    if (yield a.en):
                        with self.subTest(
                                assert_loc=':'.join(map(str, a.en.src_loc))):
                            self.assertTrue((yield a.test))

            def process():
                """Replay ``test_seq``, one access (or idle cycle) per item."""
                for test_item in test_seq:
                    if test_item is None:
                        with self.subTest(test_item="None"):
                            yield from subtest(acc_i=0, acc_en_i=0)
                    else:
                        with self.subTest(test_item=hex(test_item)):
                            yield from subtest(acc_i=test_item, acc_en_i=1)

            with do_sim(self, m, traces) as sim:
                sim.add_clock(1e-6)
                sim.add_process(process)
                sim.run()

    def test_bits_1(self):
        self.tst(1)

    def test_bits_2(self):
        self.tst(2)

    def test_bits_3(self):
        self.tst(3)

    def test_bits_4(self):
        self.tst(4)

    def test_bits_5(self):
        self.tst(5)

    def test_bits_6(self):
        self.tst(6)

    def test_bits_3_sim(self):
        # Access every way in order and idle one cycle. Then access every
        # way in bit-reversed order and idle again.
        self.tst(3, [
            0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7,
            None,
            0x0, 0x4, 0x2, 0x6, 0x1, 0x5, 0x3, 0x7,
            None,
        ])
242
243
if __name__ == "__main__":
    # Running this file directly runs the test cases defined in this module.
    unittest.main(module="__main__")