anv/entrypoints: Generalize the string map a bit
authorJason Ekstrand <jason.ekstrand@intel.com>
Wed, 20 Sep 2017 15:25:05 +0000 (08:25 -0700)
committerJason Ekstrand <jason.ekstrand@intel.com>
Wed, 7 Mar 2018 20:13:47 +0000 (12:13 -0800)
The original string map assumed that the mapping from strings to
entrypoints was a bijection.  This will not be true the moment we
add entrypoint aliasing.  This reworks things to be an arbitrary map
from strings to non-negative signed integers.  The old one also had a
potential bug if we ever had a hash collision because it didn't do the
strcmp inside the lookup loop.  While we're at it, we break things out
into a helpful class.

Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Samuel Iglesias Gonsálvez <siglesias@igalia.com>
src/intel/vulkan/anv_entrypoints_gen.py

index 34ffedb1165fd8f2f3bed37b4991d4baa1962edd..dc0f0e968cab10962b9b33033ddffdc130ee86c1 100644 (file)
@@ -115,9 +115,10 @@ TEMPLATE_C = Template(u"""\
 
 #include "anv_private.h"
 
-struct anv_entrypoint {
+struct string_map_entry {
    uint32_t name;
    uint32_t hash;
+   uint32_t num;
 };
 
 /* We use a big string constant to avoid lots of reloctions from the entry
@@ -126,17 +127,60 @@ struct anv_entrypoint {
  */
 
 static const char strings[] =
-% for e in entrypoints:
-    "${e.name}\\0"
+% for s in strmap.sorted_strings:
+    "${s.string}\\0"
 % endfor
 ;
 
-static const struct anv_entrypoint entrypoints[] = {
-% for e in entrypoints:
-    [${e.num}] = { ${offsets[e.num]}, ${'{:0=#8x}'.format(e.get_c_hash())} }, /* ${e.name} */
+static const struct string_map_entry string_map_entries[] = {
+% for s in strmap.sorted_strings:
+    { ${s.offset}, ${'{:0=#8x}'.format(s.hash)}, ${s.num} }, /* ${s.string} */
 % endfor
 };
 
+/* Hash table stats:
+ * size ${len(strmap.sorted_strings)} entries
+ * collisions entries:
+% for i in xrange(10):
+ *     ${i}${'+' if i == 9 else ' '}     ${strmap.collisions[i]}
+% endfor
+ */
+
+#define none 0xffff
+static const uint16_t string_map[${strmap.hash_size}] = {
+% for e in strmap.mapping:
+    ${ '{:0=#6x}'.format(e) if e >= 0 else 'none' },
+% endfor
+};
+
+static int
+string_map_lookup(const char *str)
+{
+    static const uint32_t prime_factor = ${strmap.prime_factor};
+    static const uint32_t prime_step = ${strmap.prime_step};
+    const struct string_map_entry *e;
+    uint32_t hash, h;
+    uint16_t i;
+    const char *p;
+
+    hash = 0;
+    for (p = str; *p; p++)
+        hash = hash * prime_factor + *p;
+
+    h = hash;
+    while (1) {
+        i = string_map[h & ${strmap.hash_mask}];
+        if (i == none)
+           return -1;
+        e = &string_map_entries[i];
+        if (e->hash == hash && strcmp(str, strings + e->name) == 0)
+            return e->num;
+        h += prime_step;
+    }
+
+    return -1;
+}
+
 /* Weak aliases for all potential implementations. These will resolve to
  * NULL if they're not defined, which lets the resolve_entrypoint() function
  * either pick the correct entry point.
@@ -275,54 +319,10 @@ anv_resolve_entrypoint(const struct gen_device_info *devinfo, uint32_t index)
       return anv_dispatch_table.entrypoints[index];
 }
 
-/* Hash table stats:
- * size ${hash_size} entries
- * collisions entries:
-% for i in xrange(10):
- *     ${i}${'+' if i == 9 else ''}     ${collisions[i]}
-% endfor
- */
-
-#define none ${'{:#x}'.format(none)}
-static const uint16_t map[] = {
-% for i in xrange(0, hash_size, 8):
-  % for j in xrange(i, i + 8):
-    ## This is 6 because the 0x is counted in the length
-    % if mapping[j] & 0xffff == 0xffff:
-      none,
-    % else:
-      ${'{:0=#6x}'.format(mapping[j] & 0xffff)},
-    % endif
-  % endfor
-% endfor
-};
-
 int
 anv_get_entrypoint_index(const char *name)
 {
-   static const uint32_t prime_factor = ${prime_factor};
-   static const uint32_t prime_step = ${prime_step};
-   const struct anv_entrypoint *e;
-   uint32_t hash, h, i;
-   const char *p;
-
-   hash = 0;
-   for (p = name; *p; p++)
-      hash = hash * prime_factor + *p;
-
-   h = hash;
-   do {
-      i = map[h & ${hash_mask}];
-      if (i == none)
-         return -1;
-      e = &entrypoints[i];
-      h += prime_step;
-   } while (e->hash != hash);
-
-   if (strcmp(name, strings + e->name) != 0)
-      return -1;
-
-   return i;
+   return string_map_lookup(name);
 }
 
 void *
@@ -334,7 +334,6 @@ anv_lookup_entrypoint(const struct gen_device_info *devinfo, const char *name)
    return anv_resolve_entrypoint(devinfo, idx);
 }""", output_encoding='utf-8')
 
-NONE = 0xffff
 HASH_SIZE = 256
 U32_MASK = 2**32 - 1
 HASH_MASK = HASH_SIZE - 1
@@ -342,11 +341,54 @@ HASH_MASK = HASH_SIZE - 1
 PRIME_FACTOR = 5024183
 PRIME_STEP = 19
 
-
-def cal_hash(name):
-    """Calculate the same hash value that Mesa will calculate in C."""
-    return functools.reduce(
-        lambda h, c: (h * PRIME_FACTOR + ord(c)) & U32_MASK, name, 0)
+class StringIntMapEntry(object):
+    def __init__(self, string, num):
+        self.string = string
+        self.num = num
+
+        # Calculate the same hash value that we will calculate in C.
+        h = 0
+        for c in string:
+            h = ((h * PRIME_FACTOR) + ord(c)) & U32_MASK
+        self.hash = h
+
+        self.offset = None
+
+class StringIntMap(object):
+    def __init__(self):
+        self.baked = False
+        self.strings = dict()
+
+    def add_string(self, string, num):
+        assert not self.baked
+        assert string not in self.strings
+        assert num >= 0 and num < 2**31
+        self.strings[string] = StringIntMapEntry(string, num)
+
+    def bake(self):
+        self.sorted_strings = \
+            sorted(self.strings.values(), key=lambda x: x.string)
+        offset = 0
+        for entry in self.sorted_strings:
+            entry.offset = offset
+            offset += len(entry.string) + 1
+
+        # Save off some values that we'll need in C
+        self.hash_size = HASH_SIZE
+        self.hash_mask = HASH_SIZE - 1
+        self.prime_factor = PRIME_FACTOR
+        self.prime_step = PRIME_STEP
+
+        self.mapping = [-1] * self.hash_size
+        self.collisions = [0] * 10
+        for idx, s in enumerate(self.sorted_strings):
+            level = 0
+            h = s.hash
+            while self.mapping[h & self.hash_mask] >= 0:
+                h = h + PRIME_STEP
+                level = level + 1
+            self.collisions[min(level, 9)] += 1
+            self.mapping[h & self.hash_mask] = idx
 
 EntrypointParam = namedtuple('EntrypointParam', 'type name decl')
 
@@ -372,9 +414,6 @@ class Entrypoint(object):
     def call_params(self):
         return ', '.join(p.name for p in self.params)
 
-    def get_c_hash(self):
-        return cal_hash(self.name)
-
 def get_entrypoints(doc, entrypoints_to_defines, start_index):
     """Extract the entry points from the registry."""
     entrypoints = OrderedDict()
@@ -443,36 +482,15 @@ def get_entrypoints_defines(doc):
 
 def gen_code(entrypoints):
     """Generate the C code."""
-    i = 0
-    offsets = []
-    for e in entrypoints:
-        offsets.append(i)
-        i += len(e.name) + 1
 
-    mapping = [NONE] * HASH_SIZE
-    collisions = [0] * 10
+    strmap = StringIntMap()
     for e in entrypoints:
-        level = 0
-        h = e.get_c_hash()
-        while mapping[h & HASH_MASK] != NONE:
-            h = h + PRIME_STEP
-            level = level + 1
-        if level > 9:
-            collisions[9] += 1
-        else:
-            collisions[level] += 1
-        mapping[h & HASH_MASK] = e.num
+        strmap.add_string(e.name, e.num)
+    strmap.bake()
 
     return TEMPLATE_C.render(entrypoints=entrypoints,
                              LAYERS=LAYERS,
-                             offsets=offsets,
-                             collisions=collisions,
-                             mapping=mapping,
-                             hash_mask=HASH_MASK,
-                             prime_step=PRIME_STEP,
-                             prime_factor=PRIME_FACTOR,
-                             none=NONE,
-                             hash_size=HASH_SIZE,
+                             strmap=strmap,
                              filename=os.path.basename(__file__))