base: Tag API methods and variables in trie.hh
[gem5.git] / src / base / trie.hh
1 /*
2 * Copyright (c) 2012 Google
3 * All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are
7 * met: redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer;
9 * redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution;
12 * neither the name of the copyright holders nor the names of its
13 * contributors may be used to endorse or promote products derived from
14 * this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
20 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
21 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
22 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
23 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
24 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 */
28
29 #ifndef __BASE_TRIE_HH__
30 #define __BASE_TRIE_HH__
31
32 #include <cassert>
33 #include <iostream>
34 #include <type_traits>
35
36 #include "base/cprintf.hh"
37 #include "base/logging.hh"
38 #include "base/types.hh"
39
40 /**
41 * A trie is a tree-based data structure used for data retrieval. It uses
42 * bits masked from the msb of the key to to determine a value's location,
43 * so its lookups have their worst case time dictated by the key's size.
44 *
45 * @tparam Key Type of the key of the tree nodes. Must be an integral type.
46 * @tparam Value Type of the values associated to the keys.
47 *
48 * @ingroup api_base_utils
49 */
50 template <class Key, class Value>
51 class Trie
52 {
53 protected:
54 static_assert(std::is_integral<Key>::value,
55 "Key has to be an integral type");
56
57 struct Node
58 {
59 Key key;
60 Key mask;
61
62 bool
63 matches(Key test)
64 {
65 return (test & mask) == key;
66 }
67
68 Value *value;
69
70 Node *parent;
71 Node *kids[2];
72
73 Node(Key _key, Key _mask, Value *_val) :
74 key(_key & _mask), mask(_mask), value(_val),
75 parent(NULL)
76 {
77 kids[0] = NULL;
78 kids[1] = NULL;
79 }
80
81 void
82 clear()
83 {
84 if (kids[1]) {
85 kids[1]->clear();
86 delete kids[1];
87 kids[1] = NULL;
88 }
89 if (kids[0]) {
90 kids[0]->clear();
91 delete kids[0];
92 kids[0] = NULL;
93 }
94 }
95
96 void
97 dump(std::ostream &os, int level)
98 {
99 for (int i = 1; i < level; i++) {
100 ccprintf(os, "|");
101 }
102 if (level == 0)
103 ccprintf(os, "Root ");
104 else
105 ccprintf(os, "+ ");
106 ccprintf(os, "(%p, %p, %#X, %#X, %p)\n",
107 parent, this, key, mask, value);
108 if (kids[0])
109 kids[0]->dump(os, level + 1);
110 if (kids[1])
111 kids[1]->dump(os, level + 1);
112 }
113 };
114
115 protected:
116 Node head;
117
118 public:
119 /**
120 * @ingroup api_base_utils
121 */
122 typedef Node *Handle;
123
124 /**
125 * @ingroup api_base_utils
126 */
127 Trie() : head(0, 0, NULL)
128 {}
129
130 /**
131 * @ingroup api_base_utils
132 */
133 static const unsigned MaxBits = sizeof(Key) * 8;
134
135 private:
136 /**
137 * A utility method which checks whether the key being looked up lies
138 * beyond the Node being examined. If so, it returns true and advances the
139 * node being examined.
140 * @param parent The node we're currently "at", which can be updated.
141 * @param kid The node we may want to move to.
142 * @param key The key we're looking for.
143 * @param new_mask The mask to use when matching against the key.
144 * @return Whether the current Node was advanced.
145 */
146 bool
147 goesAfter(Node **parent, Node *kid, Key key, Key new_mask)
148 {
149 if (kid && kid->matches(key) && (kid->mask & new_mask) == kid->mask) {
150 *parent = kid;
151 return true;
152 } else {
153 return false;
154 }
155 }
156
157 /**
158 * A utility method which extends a mask value one more bit towards the
159 * lsb. This is almost just a signed right shift, except that the shifted
160 * in bits are technically undefined. This is also slightly complicated by
161 * the zero case.
162 * @param orig The original mask to extend.
163 * @return The extended mask.
164 */
165 Key
166 extendMask(Key orig)
167 {
168 // Just in case orig was 0.
169 const Key msb = ULL(1) << (MaxBits - 1);
170 return orig | (orig >> 1) | msb;
171 }
172
173 /**
174 * Method which looks up the Handle corresponding to a particular key. This
175 * is useful if you want to delete the Handle corresponding to a key since
176 * the "remove" function takes a Handle as its argument.
177 * @param key The key to look up.
178 * @return The first Handle matching this key, or NULL if none was found.
179 */
180 Handle
181 lookupHandle(Key key)
182 {
183 Node *node = &head;
184 while (node) {
185 if (node->value)
186 return node;
187
188 if (node->kids[0] && node->kids[0]->matches(key))
189 node = node->kids[0];
190 else if (node->kids[1] && node->kids[1]->matches(key))
191 node = node->kids[1];
192 else
193 node = NULL;
194 }
195
196 return NULL;
197 }
198
199 public:
200 /**
201 * Method which inserts a key/value pair into the trie.
202 * @param key The key which can later be used to look up this value.
203 * @param width How many bits of the key (from msb) should be used.
204 * @param val A pointer to the value to store in the trie.
205 * @return A Handle corresponding to this value.
206 *
207 * @ingroup api_base_utils
208 */
209 Handle
210 insert(Key key, unsigned width, Value *val)
211 {
212 // We use NULL value pointers to mark internal nodes of the trie, so
213 // we don't allow inserting them as real values.
214 assert(val);
215
216 // Build a mask which masks off all the bits we don't care about.
217 Key new_mask = ~(Key)0;
218 if (width < MaxBits)
219 new_mask <<= (MaxBits - width);
220 // Use it to tidy up the key.
221 key &= new_mask;
222
223 // Walk past all the nodes this new node will be inserted after. They
224 // can be ignored for the purposes of this function.
225 Node *node = &head;
226 while (goesAfter(&node, node->kids[0], key, new_mask) ||
227 goesAfter(&node, node->kids[1], key, new_mask))
228 {}
229 assert(node);
230
231 Key cur_mask = node->mask;
232 // If we're already where the value needs to be...
233 if (cur_mask == new_mask) {
234 assert(!node->value);
235 node->value = val;
236 return node;
237 }
238
239 for (unsigned int i = 0; i < 2; i++) {
240 Node *&kid = node->kids[i];
241 Node *new_node;
242 if (!kid) {
243 // No kid. Add a new one.
244 new_node = new Node(key, new_mask, val);
245 new_node->parent = node;
246 kid = new_node;
247 return new_node;
248 }
249
250 // Walk down the leg until something doesn't match or we run out
251 // of bits.
252 Key last_mask;
253 bool done;
254 do {
255 last_mask = cur_mask;
256 cur_mask = extendMask(cur_mask);
257 done = ((key & cur_mask) != (kid->key & cur_mask)) ||
258 last_mask == new_mask;
259 } while (!done);
260 cur_mask = last_mask;
261
262 // If this isn't the right leg to go down at all, skip it.
263 if (cur_mask == node->mask)
264 continue;
265
266 // At the point we walked to above, add a new node.
267 new_node = new Node(key, cur_mask, NULL);
268 new_node->parent = node;
269 kid->parent = new_node;
270 new_node->kids[0] = kid;
271 kid = new_node;
272
273 // If we ran out of bits, the value goes right here.
274 if (cur_mask == new_mask) {
275 new_node->value = val;
276 return new_node;
277 }
278
279 // Still more bits to deal with, so add a new node for that path.
280 new_node = new Node(key, new_mask, val);
281 new_node->parent = kid;
282 kid->kids[1] = new_node;
283 return new_node;
284 }
285
286 panic("Reached the end of the Trie insert function!\n");
287 return NULL;
288 }
289
290 /**
291 * Method which looks up the Value corresponding to a particular key.
292 * @param key The key to look up.
293 * @return The first Value matching this key, or NULL if none was found.
294 *
295 * @ingroup api_base_utils
296 */
297 Value *
298 lookup(Key key)
299 {
300 Node *node = lookupHandle(key);
301 if (node)
302 return node->value;
303 else
304 return NULL;
305 }
306
307 /**
308 * Method to delete a value from the trie.
309 * @param node A Handle to remove.
310 * @return The Value pointer from the removed entry.
311 *
312 * @ingroup api_base_utils
313 */
314 Value *
315 remove(Handle handle)
316 {
317 Node *node = handle;
318 Value *val = node->value;
319 if (node->kids[1]) {
320 assert(node->value);
321 node->value = NULL;
322 return val;
323 }
324 if (!node->parent)
325 panic("Trie: Can't remove root node.\n");
326
327 Node *parent = node->parent;
328
329 // If there's a kid, fix up it's parent pointer.
330 if (node->kids[0])
331 node->kids[0]->parent = parent;
332 // Figure out which kid we are, and update our parent's pointers.
333 if (parent->kids[0] == node)
334 parent->kids[0] = node->kids[0];
335 else if (parent->kids[1] == node)
336 parent->kids[1] = node->kids[0];
337 else
338 panic("Trie: Inconsistent parent/kid relationship.\n");
339 // Make sure if the parent only has one kid, it's kid[0].
340 if (parent->kids[1] && !parent->kids[0]) {
341 parent->kids[0] = parent->kids[1];
342 parent->kids[1] = NULL;
343 }
344
345 // If the parent has less than two kids and no cargo and isn't the
346 // root, delete it too.
347 if (!parent->kids[1] && !parent->value && parent->parent)
348 remove(parent);
349 delete node;
350 return val;
351 }
352
353 /**
354 * Method to lookup a value from the trie and then delete it.
355 * @param key The key to look up and then remove.
356 * @return The Value pointer from the removed entry, if any.
357 *
358 * @ingroup api_base_utils
359 */
360 Value *
361 remove(Key key)
362 {
363 Handle handle = lookupHandle(key);
364 if (!handle)
365 return NULL;
366 return remove(handle);
367 }
368
369 /**
370 * A method which removes all key/value pairs from the trie. This is more
371 * efficient than trying to remove elements individually.
372 *
373 * @ingroup api_base_utils
374 */
375 void
376 clear()
377 {
378 head.clear();
379 }
380
381 /**
382 * A debugging method which prints the contents of this trie.
383 * @param title An identifying title to put in the dump header.
384 */
385 void
386 dump(const char *title, std::ostream &os=std::cout)
387 {
388 ccprintf(os, "**************************************************\n");
389 ccprintf(os, "*** Start of Trie: %s\n", title);
390 ccprintf(os, "*** (parent, me, key, mask, value pointer)\n");
391 ccprintf(os, "**************************************************\n");
392 head.dump(os, 0);
393 }
394 };
395
396 #endif