clarify output, use Cat on list
[soc.git] / TLB / src / ariane / tlb.py
1 """
2 # Copyright 2018 ETH Zurich and University of Bologna.
3 # Copyright and related rights are licensed under the Solderpad Hardware
4 # License, Version 0.51 (the "License"); you may not use this file except in
5 # compliance with the License. You may obtain a copy of the License at
6 # http://solderpad.org/licenses/SHL-0.51. Unless required by applicable law
7 # or agreed to in writing, software, hardware and materials distributed under
8 # this License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
9 # CONDITIONS OF ANY KIND, either express or implied. See the License for the
10 # specific language governing permissions and limitations under the License.
11 #
12 # Author: David Schaffenrath, TU Graz
13 # Author: Florian Zaruba, ETH Zurich
14 # Date: 21.4.2017
15 # Description: Translation Lookaside Buffer, SV39
16 # fully set-associative
17 """
18 from math import log2
19 from nmigen import Signal, Module, Cat, Const, Array
20 from nmigen.cli import verilog, rtlil
21 from nmigen.lib.coding import Encoder
22
23 from ptw import TLBUpdate, PTE, ASID_WIDTH
24
25 TLB_ENTRIES = 8
26
27
class TLBEntry:
    """Tag fields of one TLB entry.

    Groups the ASID, the three 9-bit VPN fields, the 2M / 1G page-size
    flags and the valid bit as separate signals, and offers helpers to
    treat them as a single concatenated value.
    """

    def __init__(self):
        self.asid = Signal(ASID_WIDTH)
        # one 9-bit VPN field per SV39 page-table level
        self.vpn0 = Signal(9)
        self.vpn1 = Signal(9)
        self.vpn2 = Signal(9)
        # page-size flags and entry-valid bit
        self.is_2M = Signal()
        self.is_1G = Signal()
        self.valid = Signal()

    def ports(self):
        """Return all fields as a list; flatten() concatenates in this order."""
        vpn_fields = [self.vpn0, self.vpn1, self.vpn2]
        size_flags = [self.is_2M, self.is_1G]
        return [self.asid] + vpn_fields + size_flags + [self.valid]

    def flatten(self):
        """Return every field of the entry concatenated into one value."""
        fields = self.ports()
        return Cat(*fields)

    def eq(self, x):
        """Return an assignment copying all fields of entry x into this one."""
        source = x.flatten()
        return self.flatten().eq(source)
48
49
class TLBContent:
    """A single TLB entry: tag storage, PTE content storage and match logic.

    The entry compares the lookup ASID and VPN fields against its stored
    tags and reports a hit together with the page-size flags.  The stored
    tags and content are rewritten from ``update_i`` when the update is
    valid and ``replace_en_i`` selects this entry, and the entry is
    invalidated on a matching flush.
    """

    def __init__(self, pte_width):
        """Create the interface signals.

        pte_width: width in bits of the stored PTE content and of
                   ``lu_content_o``.
        """
        self.pte_width = pte_width
        self.flush_i = Signal()  # Flush signal
        # Update TLB
        self.update_i = TLBUpdate()
        self.vpn2 = Signal(9)
        self.vpn1 = Signal(9)
        self.vpn0 = Signal(9)
        self.replace_en_i = Signal()  # replace the following entry,
                                      # set by replacement strategy
        # Lookup signals
        self.lu_asid_i = Signal(ASID_WIDTH)
        self.lu_content_o = Signal(self.pte_width)
        self.lu_is_2M_o = Signal()
        self.lu_is_1G_o = Signal()
        self.lu_hit_o = Signal()

    def elaborate(self, platform):
        """Build the match, update and flush logic and return the Module."""
        m = Module()

        tags = TLBEntry()
        content = Signal(self.pte_width)

        # default: no hit, no page-size flag
        m.d.comb += [self.lu_hit_o.eq(0),
                     self.lu_is_2M_o.eq(0),
                     self.lu_is_1G_o.eq(0)]

        # temporaries for 1st level match
        asid_ok = Signal(reset_less=True)
        vpn2_ok = Signal(reset_less=True)
        tags_ok = Signal(reset_less=True)
        vpn2_hit = Signal(reset_less=True)
        m.d.comb += [tags_ok.eq(tags.valid),
                     asid_ok.eq(tags.asid == self.lu_asid_i),
                     vpn2_ok.eq(tags.vpn2 == self.vpn2),
                     vpn2_hit.eq(tags_ok & asid_ok & vpn2_ok)]
        # temporaries for 2nd level match
        vpn1_ok = Signal(reset_less=True)
        tags_2M = Signal(reset_less=True)
        vpn0_ok = Signal(reset_less=True)
        vpn0_or_2M = Signal(reset_less=True)
        m.d.comb += [vpn1_ok.eq(self.vpn1 == tags.vpn1),
                     tags_2M.eq(tags.is_2M),
                     vpn0_ok.eq(self.vpn0 == tags.vpn0),
                     vpn0_or_2M.eq(tags_2M | vpn0_ok)]
        # first level match, this may be a giga page,
        # check the ASID flags as well
        with m.If(vpn2_hit):
            # second level
            with m.If(tags.is_1G):
                # NOTE(review): lu_content_o is registered (sync) while the
                # hit / size flags are combinatorial, so content appears one
                # clock after the hit -- confirm this is intended.
                m.d.sync += self.lu_content_o.eq(content)
                m.d.comb += [self.lu_is_1G_o.eq(1),
                             self.lu_hit_o.eq(1),
                             ]
            # not a giga page hit so check further
            with m.Elif(vpn1_ok):
                # this could be a 2 mega page hit or a 4 kB hit
                # output accordingly
                with m.If(vpn0_or_2M):
                    m.d.sync += self.lu_content_o.eq(content)
                    m.d.comb += [self.lu_is_2M_o.eq(tags.is_2M),
                                 self.lu_hit_o.eq(1),
                                 ]

        # ------------------
        # Update and Flush
        # ------------------

        replace_valid = Signal(reset_less=True)
        m.d.comb += replace_valid.eq(self.update_i.valid & self.replace_en_i)
        with m.If(self.flush_i):
            # invalidate (flush) conditions: all if zero or just this ASID.
            # The two comparisons are built separately because Python's "|"
            # binds tighter than "==": written inline without parentheses
            # the test becomes lu_asid_i == (0 | (lu_asid_i == tags.asid)).
            flush_all = (self.lu_asid_i == Const(0, ASID_WIDTH))
            flush_this_asid = (self.lu_asid_i == tags.asid)
            with m.If(flush_all | flush_this_asid):
                m.d.sync += tags.valid.eq(0)

        # normal replacement
        with m.Elif(replace_valid):
            m.d.sync += [  # update tag array
                          tags.asid.eq(self.update_i.asid),
                          tags.vpn2.eq(self.update_i.vpn[18:27]),
                          tags.vpn1.eq(self.update_i.vpn[9:18]),
                          tags.vpn0.eq(self.update_i.vpn[0:9]),
                          tags.is_1G.eq(self.update_i.is_1G),
                          tags.is_2M.eq(self.update_i.is_2M),
                          tags.valid.eq(1),
                          # and content as well
                          content.eq(self.update_i.content.flatten())
                        ]
        return m

    def ports(self):
        """Return the signals exposed when converting this module alone."""
        return [self.flush_i,
                self.lu_asid_i,
                self.lu_is_2M_o, self.lu_is_1G_o, self.lu_hit_o,
                ] + self.update_i.content.ports() + self.update_i.ports()
147
148
class PLRU:
    """Pseudo Least Recently Used replacement selector.

    ``lu_hit`` carries one hit bit per TLB entry and ``lu_access_i`` flags a
    lookup access; together they update a binary tree of direction bits.
    ``replace_en_o`` carries one replace-enable bit per entry, decoded from
    that tree.
    """

    def __init__(self):
        self.lu_hit = Signal(TLB_ENTRIES)
        self.replace_en_o = Signal(TLB_ENTRIES)
        self.lu_access_i = Signal()

    @staticmethod
    def _tree_path(entry, depth):
        """Yield (lvl, idx_base, shift, node, bit) for each tree level.

        node is the plru_tree index visited at that level on the way to
        ``entry``; bit is the bit of ``entry`` that picks the branch there
        (lvl0 <=> MSB, lvl1 <=> MSB-1, ...).
        """
        for lvl in range(depth):
            idx_base = (1 << lvl) - 1
            shift = depth - lvl
            node = idx_base + (entry >> shift)
            bit = (entry >> (shift - 1)) & 1
            yield lvl, idx_base, shift, node, bit

    def elaborate(self, platform):
        """Build the tree-update and replace-enable logic; return the Module."""
        m = Module()

        # -----------------------------------------------
        # PLRU - Pseudo Least Recently Used Replacement
        # -----------------------------------------------

        TLBSZ = 2*(TLB_ENTRIES-1)
        plru_tree = Signal(TLBSZ)
        LOG_TLB = int(log2(TLB_ENTRIES))

        # The PLRU-tree indexing:
        # lvl0        0
        #            / \
        #           /   \
        # lvl1     1     2
        #         / \   / \
        # lvl2   3   4 5   6
        #       / \ /\/\  /\
        #      ... ... ... ...
        #
        # Tree update: when entry i is hit during an access, every node on
        # the path from the root to i is written with the INVERSE of the
        # matching bit of i, so the tree points away from the entry that
        # was just used.
        for i in range(TLB_ENTRIES):
            hit = Signal(reset_less=True)
            m.d.comb += hit.eq(self.lu_hit[i] & self.lu_access_i)
            with m.If(hit):
                for lvl, base, shift, node, bit in self._tree_path(i, LOG_TLB):
                    new_idx = Const(~bit, 1)
                    print("plru", i, lvl, hex(base), node, shift, new_idx)
                    m.d.sync += plru_tree[node].eq(new_idx)

        # Decode tree to write enable signals.
        # Entry i is the next one to replace when every node on its path
        # equals the matching bit of i, e.g. for an 8 entry TLB:
        #   replace_en[7] <=> plru_tree[0, 2, 6] == {1, 1, 1}
        #   replace_en[0] <=> plru_tree[0, 1, 3] == {0, 0, 0}
        for i in range(TLB_ENTRIES):
            en = []
            for _lvl, _base, _shift, node, bit in self._tree_path(i, LOG_TLB):
                plru = Signal(reset_less=True)
                m.d.comb += plru.eq(plru_tree[node])
                # collect the *mismatch* term for this level; the terms are
                # OR-ed and inverted below
                en.append(~plru if bit else plru)
            print("plru", i, en)
            # boolean logic manipulation:
            # plru0 & plru1 & plru2 == ~(~plru0 | ~plru1 | ~plru2)
            no_mismatch = ~Cat(*en).bool()
            m.d.comb += self.replace_en_o[i].eq(no_mismatch)

        return m
238
239
class TLB:
    """Translation Lookaside Buffer built from TLB_ENTRIES TLBContent entries.

    A lookup presents ``lu_asid_i`` and ``lu_vaddr_i`` to every entry; an
    Encoder picks the index of the hitting entry and that entry's content
    and page-size flags are forwarded to the outputs.  A PLRU module chooses
    which entry is replaced on an update.
    """

    def __init__(self):
        self.flush_i = Signal()  # Flush signal
        # Lookup signals
        self.lu_access_i = Signal()
        self.lu_asid_i = Signal(ASID_WIDTH)
        self.lu_vaddr_i = Signal(64)
        self.lu_content_o = PTE()
        self.lu_is_2M_o = Signal()
        self.lu_is_1G_o = Signal()
        self.lu_hit_o = Signal()
        # Update TLB
        self.pte_width = len(self.lu_content_o.flatten())
        self.update_i = TLBUpdate()

    def elaborate(self, platform):
        """Build the lookup, hit-select and replacement logic.

        Raises AssertionError if TLB_ENTRIES is not a power of two greater
        than 1, or if ASID_WIDTH is less than 1.
        """
        # --------------
        # Sanity checks
        # --------------
        # The PLRU tree uses int(log2(TLB_ENTRIES)) levels, which only
        # addresses every entry correctly for power-of-two sizes; an
        # "even number" check would let e.g. 6 entries through.
        assert (TLB_ENTRIES > 1) and (TLB_ENTRIES & (TLB_ENTRIES - 1)) == 0, \
            "TLB size must be a power of 2 and greater than 1"
        assert (ASID_WIDTH >= 1), \
            "ASID width must be at least 1"

        m = Module()

        vpn2 = Signal(9)
        vpn1 = Signal(9)
        vpn0 = Signal(9)

        # -------------
        # Translation
        # -------------
        # split the virtual address into its three 9-bit VPN fields
        m.d.comb += [vpn0.eq(self.lu_vaddr_i[12:21]),
                     vpn1.eq(self.lu_vaddr_i[21:30]),
                     vpn2.eq(self.lu_vaddr_i[30:39]),
                     ]

        # SV39 defines three levels of page tables
        entries = []
        for i in range(TLB_ENTRIES):
            tlc = TLBContent(self.pte_width)
            setattr(m.submodules, "tc%d" % i, tlc)
            entries.append(tlc)
            # connect inputs
            # share the update record object directly instead of wiring a
            # copy per entry (saves a lot of graphviz links)
            tlc.update_i = self.update_i
            m.d.comb += [tlc.vpn0.eq(vpn0),
                         tlc.vpn1.eq(vpn1),
                         tlc.vpn2.eq(vpn2),
                         tlc.flush_i.eq(self.flush_i),
                         tlc.lu_asid_i.eq(self.lu_asid_i)]
        # Array allows indexing the entries with the (signal) hit index
        tc = Array(entries)

        # --------------
        # Select hit
        # --------------

        # use Encoder to select hit index
        # XXX TODO: assert that there's only one valid entry (one lu_hit)
        hitsel = Encoder(TLB_ENTRIES)
        m.submodules.hitsel = hitsel

        hits = [entry.lu_hit_o for entry in entries]
        m.d.comb += hitsel.i.eq(Cat(*hits))
        idx = hitsel.o

        active = Signal(reset_less=True)
        m.d.comb += active.eq(~hitsel.n)
        with m.If(active):
            # active hit, send selected as output
            m.d.comb += [self.lu_is_1G_o.eq(tc[idx].lu_is_1G_o),
                         self.lu_is_2M_o.eq(tc[idx].lu_is_2M_o),
                         self.lu_hit_o.eq(1),
                         self.lu_content_o.flatten().eq(tc[idx].lu_content_o),
                         ]

        # --------------
        # PLRU.
        # --------------

        p = PLRU()
        m.submodules.plru = p

        # connect PLRU inputs/outputs
        # XXX TODO: assert that there's only one valid entry (one replace_en)
        en = [entry.replace_en_i for entry in entries]
        m.d.comb += [Cat(*en).eq(p.replace_en_o),  # output from PLRU into tags
                     p.lu_hit.eq(hitsel.i),
                     p.lu_access_i.eq(self.lu_access_i)]

        return m

    # Not yet ported: SystemVerilog checks from the reference design.
    #
    #   # Just for checking
    #   function int countSetBits(logic[TLB_ENTRIES-1:0] vector);
    #     automatic int count = 0;
    #     foreach (vector[idx]) begin
    #       count += vector[idx];
    #     end
    #     return count;
    #   endfunction
    #
    #   assert property (@(posedge clk_i)(countSetBits(lu_hit) <= 1))
    #     else $error("More then one hit in TLB!"); $stop(); end
    #   assert property (@(posedge clk_i)(countSetBits(replace_en) <= 1))
    #     else $error("More then one TLB entry selected for next replace!");

    def ports(self):
        """Return the top-level signals used when converting the design."""
        return [self.flush_i, self.lu_access_i,
                self.lu_asid_i, self.lu_vaddr_i,
                self.lu_is_2M_o, self.lu_is_1G_o, self.lu_hit_o,
                ] + self.lu_content_o.ports() + self.update_i.ports()
359
if __name__ == '__main__':
    # Elaborate the TLB, convert it to RTLIL and dump the text to a file.
    dut = TLB()
    rtlil_text = rtlil.convert(dut, ports=dut.ports())
    with open("test_tlb.il", "w") as out_file:
        out_file.write(rtlil_text)
365