1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay
import unittest

from nmigen.hdl.ast import AnySeq, Assert, Signal, Value, Array
from nmigen.hdl.dsl import Module
from nmigen.sim import Delay, Tick

from nmutil.formaltest import FHDLTestCase
from nmutil.plru2 import PLRU  # , PLRUs
from nmutil.sim_util import write_il, do_sim
from nmutil.plain_data import plain_data
class PrettyPrintState:
    """Cursor state used while pretty-printing nested structures.

    Attributes (all slots):
        indent: current indentation level; each level contributes one
            ``" "`` before the first character of a line.
        file: stream handed to ``print``; ``None`` means ``sys.stdout``.
        at_line_start: True when the next character written begins a line.
    """
    __slots__ = "indent", "file", "at_line_start"

    def __init__(self, indent=0, file=None, at_line_start=True):
        self.indent = indent
        self.file = file
        self.at_line_start = at_line_start

    def write(self, text):
        # type: (str) -> None
        """Write ``text``, indenting every line that receives content.

        Indentation is emitted lazily, just before the first character after
        a newline, so empty lines never get trailing spaces.
        """
        out = []  # type: list[str]
        for ch in text:
            if ch == "\n":
                self.at_line_start = True
            elif self.at_line_start:
                self.at_line_start = False
                out.append(" " * self.indent)
            out.append(ch)
        # One print per call instead of one per character; same output.
        if out:
            print("".join(out), file=self.file, end='')
class PLRUNode:
    """One node of a reference tree model for a pseudo-LRU.

    Each node owns a 1-bit ``state`` signal. ``get_lru`` descends into the
    right child when ``state`` is set and into the left child otherwise.
    ``set_states_from_index`` stores the inverse of the accessed direction,
    so after an access the node points away from the subtree just used.
    """
    __slots__ = "id", "state", "left_child", "right_child"

    def __init__(self, id, left_child=None, right_child=None):
        # type: (int, PLRUNode | None, PLRUNode | None) -> None
        # `id` shadows the builtin but is kept for caller compatibility.
        self.id = id
        self.state = Signal(name=f"state_{id}")
        self.left_child = left_child
        self.right_child = right_child

    @property
    def depth(self):
        # type: () -> int
        """Number of child levels below this node (0 for a leaf)."""
        depth = 0
        if self.left_child is not None:
            depth = max(depth, 1 + self.left_child.depth)
        if self.right_child is not None:
            depth = max(depth, 1 + self.right_child.depth)
        return depth

    def __pretty_print(self, state):
        # type: (PrettyPrintState) -> None
        """Recursively write this subtree through ``state``."""
        state.write("PLRUNode(")
        state.indent += 1
        state.write(f"id={self.id!r},\n")
        state.write(f"state={self.state!r},\n")
        state.write("left_child=")
        if self.left_child is None:
            state.write("None")
        else:
            self.left_child.__pretty_print(state)
        state.write(",\nright_child=")
        if self.right_child is None:
            state.write("None")
        else:
            self.right_child.__pretty_print(state)
        state.indent -= 1
        state.write("\n)")

    def pretty_print(self, file=None):
        """Print this subtree to ``file`` (``sys.stdout`` when ``None``)."""
        self.__pretty_print(PrettyPrintState(file=file))
        print(file=file)  # terminate the final ")" line

    def set_states_from_index(self, m, index, ids):
        # type: (Module, Value, list[Signal]) -> None
        """Add logic updating this subtree after an access to way ``index``.

        ``index[-1]`` (the MSB) selects the child: 1 = right, 0 = left. The
        node's state is set to the opposite direction on the next clock.
        The remaining bits are passed to the selected child. ``ids[0]`` is
        driven with this node's id, and deeper levels use ``ids[1:]``.
        """
        m.d.sync += self.state.eq(~index[-1])
        m.d.comb += ids[0].eq(self.id)
        # Only the child on the accessed path is updated; driving both
        # unconditionally would make them fight over ids[1:].
        with m.If(index[-1]):
            if self.right_child is not None:
                self.right_child.set_states_from_index(m, index[:-1], ids[1:])
        with m.Else():
            if self.left_child is not None:
                self.left_child.set_states_from_index(m, index[:-1], ids[1:])

    def get_lru(self, m, ids):
        # type: (Module, list[Signal]) -> Signal
        """Add combinational logic computing the LRU way of this subtree.

        Returns a ``1 + self.depth``-bit signal. Its MSB is this node's
        state, and its low bits come from the child that the state selects.
        ``ids[0]`` is driven with this node's id, and deeper levels use
        ``ids[1:]``.
        """
        retval = Signal(1 + self.depth, name=f"lru_{self.id}", reset=0)
        m.d.comb += retval[-1].eq(self.state)
        m.d.comb += ids[0].eq(self.id)
        with m.If(self.state):
            if self.right_child is not None:
                right_lru = self.right_child.get_lru(m, ids[1:])
                m.d.comb += retval[:-1].eq(right_lru)
        with m.Else():
            if self.left_child is not None:
                left_lru = self.left_child.get_lru(m, ids[1:])
                m.d.comb += retval[:-1].eq(left_lru)
        return retval
# Tests for nmutil.plru2.PLRU: tst() builds a PLRUNode reference tree next to
# the DUT and asserts that the tree node states, the update/get id signals and
# lru_o agree. It then runs assertFormal(mode="prove"), presumably when
# test_seq is None, or simulates the accesses listed in test_seq.
# NOTE(review): this view of the class is missing statements that the visible
# code depends on. There is no assignment to `m`, no definition of `MyAssert`
# or `assert_`, no `process` definition before `sim.add_process(process)`, and
# no bodies for the test_bits_* methods. TODO confirm against upstream.
104 class TestPLRU(FHDLTestCase
):
105 def tst(self
, log2_num_ways
, test_seq
=None):
106 # type: (int, list[int | None] | None) -> None
110 __slots__
= "test", "en"
112 def __init__(self
, test
, en
):
113 # type: (Value, Signal) -> None
117 asserts
= [] # type: list[MyAssert]
121 return [Assert(test
, src_loc_at
=1)]
122 assert_en
= Signal(name
="assert_en", src_loc_at
=1, reset
=False)
123 asserts
.append(MyAssert(test
=test
, en
=assert_en
))
124 return [assert_en
.eq(True)]
126 dut
= PLRU(log2_num_ways
, debug
=True) # check debug works
127 write_il(self
, dut
, ports
=dut
.ports())
128 # debug clutters up vcd, so disable it for formal proofs
129 dut
= PLRU(log2_num_ways
, debug
=test_seq
is not None)
130 num_ways
= 1 << log2_num_ways
131 self
.assertEqual(dut
.log2_num_ways
, log2_num_ways
)
132 self
.assertEqual(dut
.num_ways
, num_ways
)
133 self
.assertIsInstance(dut
.acc_i
, Signal
)
134 self
.assertIsInstance(dut
.acc_en_i
, Signal
)
135 self
.assertIsInstance(dut
.lru_o
, Signal
)
136 self
.assertEqual(len(dut
.acc_i
), log2_num_ways
)
137 self
.assertEqual(len(dut
.acc_en_i
), 1)
138 self
.assertEqual(len(dut
.lru_o
), log2_num_ways
)
139 write_il(self
, dut
, ports
=dut
.ports())
# Reference model: num_ways - 1 nodes, each linked under the parent index
# (i + 1) // 2 - 1. The guard that picks left vs. right (and skips i == 0)
# is not visible here.
141 nodes
= [PLRUNode(i
) for i
in range(num_ways
- 1)]
142 self
.assertIsInstance(dut
._tree
, Array
)
143 self
.assertEqual(len(dut
._tree
), len(nodes
))
144 for i
in range(len(nodes
)):
146 parent
= (i
+ 1) // 2 - 1
148 nodes
[parent
].left_child
= nodes
[i
]
150 nodes
[parent
].right_child
= nodes
[i
]
151 self
.assertIsInstance(dut
._tree
[i
], Signal
)
152 self
.assertEqual(len(dut
._tree
[i
]), 1)
153 m
.d
.comb
+= assert_(nodes
[i
].state
== dut
._tree
[i
])
157 dut
.acc_i
.eq(AnySeq(log2_num_ways
)),
158 dut
.acc_en_i
.eq(AnySeq(1)),
161 l2nwr
= range(log2_num_ways
)
162 upd_ids
= [Signal(log2_num_ways
, name
=f
"upd_id_{i}") for i
in l2nwr
]
163 with m
.If(dut
.acc_en_i
):
164 nodes
[0].set_states_from_index(m
, dut
.acc_i
, upd_ids
)
166 self
.assertEqual(len(dut
._upd
_lru
_nodes
), len(upd_ids
))
167 for l
, r
in zip(dut
._upd
_lru
_nodes
, upd_ids
):
168 m
.d
.comb
+= assert_(l
== r
)
170 get_ids
= [Signal(log2_num_ways
, name
=f
"get_id_{i}") for i
in l2nwr
]
171 lru
= Signal(log2_num_ways
)
172 m
.d
.comb
+= lru
.eq(nodes
[0].get_lru(m
, get_ids
))
173 m
.d
.comb
+= assert_(dut
.lru_o
== lru
)
174 self
.assertEqual(len(dut
._get
_lru
_nodes
), len(get_ids
))
175 for l
, r
in zip(dut
._get
_lru
_nodes
, get_ids
):
176 m
.d
.comb
+= assert_(l
== r
)
178 nodes
[0].pretty_print()
180 m
.submodules
.dut
= dut
182 self
.assertFormal(m
, mode
="prove", depth
=2)
184 traces
= [dut
.acc_i
, dut
.acc_en_i
, *dut
._tree
]
186 traces
.append(node
.state
)
188 dut
.lru_o
, lru
, *dut
._get
_lru
_nodes
, *get_ids
,
189 *dut
._upd
_lru
_nodes
, *upd_ids
,
# Simulation helper: drives dut.acc_i / dut.acc_en_i, then checks each
# recorded MyAssert test inside a subTest labelled with its source location.
192 def subtest(acc_i
, acc_en_i
):
193 yield dut
.acc_i
.eq(acc_i
)
194 yield dut
.acc_en_i
.eq(acc_en_i
)
200 assert_loc
=':'.join(map(str, a
.en
.src_loc
))):
201 self
.assertTrue((yield a
.test
))
# A None entry in test_seq is an idle cycle (acc_en_i=0). Any other entry is
# an access to that way index (acc_en_i=1).
204 for test_item
in test_seq
:
205 if test_item
is None:
206 with self
.subTest(test_item
="None"):
207 yield from subtest(acc_i
=0, acc_en_i
=0)
209 with self
.subTest(test_item
=hex(test_item
)):
210 yield from subtest(acc_i
=test_item
, acc_en_i
=1)
212 with
do_sim(self
, m
, traces
) as sim
:
214 sim
.add_process(process
)
# One test per tree size. The bodies are not visible here; presumably each
# calls self.tst(N), and test_bits_3_sim passes the access sequence below.
# TODO confirm.
217 def test_bits_1(self
):
220 def test_bits_2(self
):
223 def test_bits_3(self
):
226 def test_bits_4(self
):
229 def test_bits_5(self
):
232 def test_bits_6(self
):
235 def test_bits_3_sim(self
):
237 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7,
239 0x0, 0x4, 0x2, 0x6, 0x1, 0x5, 0x3, 0x7,
if __name__ == "__main__":
    # The guard had no body (a SyntaxError). Run the TestPLRU cases when the
    # module is executed directly.
    unittest.main()