nir/constant_expressions: Refactor helper functions
authorJason Ekstrand <jason.ekstrand@intel.com>
Thu, 9 Mar 2017 04:23:05 +0000 (20:23 -0800)
committerJason Ekstrand <jason.ekstrand@intel.com>
Tue, 14 Mar 2017 14:36:40 +0000 (07:36 -0700)
Apart from avoiding some unneeded size cases, this shouldn't have any
actual functional impact.

Reviewed-by: Dylan Baker <dylan@pnwbakers.com>
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
src/compiler/nir/nir_constant_expressions.py

index 3da20fd503b6e71e94916e3969036e07f0afa33d..c6745f1e9340e954dba984f7d8fa0ec7359fcb07 100644 (file)
@@ -1,16 +1,18 @@
 
+import re
+
+type_split_re = re.compile(r'(?P<type>[a-z]+)(?P<bits>\d+)')
+
 def type_has_size(type_):
     return type_[-1:].isdigit()
 
+def type_size(type_):
+    assert type_has_size(type_)
+    return int(type_split_re.match(type_).group('bits'))
+
 def type_sizes(type_):
-    if type_.endswith("8"):
-        return [8]
-    elif type_.endswith("16"):
-        return [16]
-    elif type_.endswith("32"):
-        return [32]
-    elif type_.endswith("64"):
-        return [64]
+    if type_has_size(type_):
+        return [type_size(type_)]
     else:
         return [32, 64]
 
@@ -19,23 +21,23 @@ def type_add_size(type_, size):
         return type_
     return type_ + str(size)
 
+def op_bit_sizes(op):
+    sizes = set([8, 16, 32, 64])
+    if not type_has_size(op.output_type):
+        sizes = sizes.intersection(set(type_sizes(op.output_type)))
+    for input_type in op.input_types:
+        if not type_has_size(input_type):
+            sizes = sizes.intersection(set(type_sizes(input_type)))
+    return sorted(list(sizes))
+
 def get_const_field(type_):
-    if type_ == "int32":
-        return "i32"
-    if type_ == "uint32":
-        return "u32"
-    if type_ == "int64":
-        return "i64"
-    if type_ == "uint64":
-        return "u64"
     if type_ == "bool32":
         return "u32"
-    if type_ == "float32":
-        return "f32"
-    if type_ == "float64":
-        return "f64"
-    raise Exception(str(type_))
-    assert(0)
+    else:
+        m = type_split_re.match(type_)
+        if not m:
+            raise Exception(str(type_))
+        return m.group('type')[0] + m.group('bits')
 
 template = """\
 /*
@@ -247,7 +249,7 @@ typedef float float32_t;
 typedef double float64_t;
 typedef bool bool32_t;
 % for type in ["float", "int", "uint"]:
-% for width in [32, 64]:
+% for width in type_sizes(type):
 struct ${type}${width}_vec {
    ${type}${width}_t x;
    ${type}${width}_t y;
@@ -272,7 +274,7 @@ evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size,
    nir_const_value _dst_val = { {0, } };
 
    switch (bit_size) {
-   % for bit_size in [32, 64]:
+   % for bit_size in op_bit_sizes(op):
    case ${bit_size}: {
       <%
       output_type = type_add_size(op.output_type, bit_size)
@@ -406,4 +408,5 @@ from mako.template import Template
 print Template(template).render(opcodes=opcodes, type_sizes=type_sizes,
                                 type_has_size=type_has_size,
                                 type_add_size=type_add_size,
+                                op_bit_sizes=op_bit_sizes,
                                 get_const_field=get_const_field)