nir/opcodes: Pull in the type helpers from constant_expressions
authorJason Ekstrand <jason.ekstrand@intel.com>
Wed, 7 Nov 2018 18:15:22 +0000 (12:15 -0600)
committerJason Ekstrand <jason.ekstrand@intel.com>
Wed, 5 Dec 2018 21:02:06 +0000 (15:02 -0600)
While we're at it, we rework them a bit to all use regular expressions
and assert more.

Reviewed-by: Connor Abbott <cwabbott0@gmail.com>
src/compiler/nir/nir_constant_expressions.py
src/compiler/nir/nir_opcodes.py
src/compiler/nir/nir_opcodes_c.py

index 118af9f7818accdfe26ca43247c9c31df11f1ddc..0cd4ffcf5580cdc8081f5436edd29d4c1781c6e6 100644 (file)
@@ -1,23 +1,8 @@
 from __future__ import print_function
 
 import re
-
-type_split_re = re.compile(r'(?P<type>[a-z]+)(?P<bits>\d+)')
-
-def type_has_size(type_):
-    return type_[-1:].isdigit()
-
-def type_size(type_):
-    assert type_has_size(type_)
-    return int(type_split_re.match(type_).group('bits'))
-
-def type_sizes(type_):
-    if type_has_size(type_):
-        return [type_size(type_)]
-    elif type_ == 'float':
-        return [16, 32, 64]
-    else:
-        return [8, 16, 32, 64]
+from nir_opcodes import opcodes
+from nir_opcodes import type_has_size, type_size, type_sizes, type_base_type
 
 def type_add_size(type_, size):
     if type_has_size(type_):
@@ -44,10 +29,7 @@ def get_const_field(type_):
     elif type_ == "float16":
         return "u16"
     else:
-        m = type_split_re.match(type_)
-        if not m:
-            raise Exception(str(type_))
-        return m.group('type')[0] + m.group('bits')
+        return type_base_type(type_)[0] + str(type_size(type_))
 
 template = """\
 /*
@@ -429,7 +411,6 @@ nir_eval_const_opcode(nir_op op, unsigned num_components,
    }
 }"""
 
-from nir_opcodes import opcodes
 from mako.template import Template
 
 print(Template(template).render(opcodes=opcodes, type_sizes=type_sizes,
index 4ef4ecc6f221ba42bf809050f0801910823b9076..d69d09d30ce655be52d111ed2b909aad58a470c9 100644 (file)
@@ -23,6 +23,7 @@
 # Authors:
 #    Connor Abbott (cwabbott0@gmail.com)
 
+import re
 
 # Class that represents all the information we have about the opcode
 # NOTE: this must be kept in sync with nir_op_info
@@ -99,6 +100,33 @@ tint64 = "int64"
 tuint64 = "uint64"
 tfloat64 = "float64"
 
+_TYPE_SPLIT_RE = re.compile(r'(?P<type>int|uint|float|bool)(?P<bits>\d+)?')
+
+def type_has_size(type_):
+    m = _TYPE_SPLIT_RE.match(type_)
+    assert m is not None, 'Invalid NIR type string: "{}"'.format(type_)
+    return m.group('bits') is not None
+
+def type_size(type_):
+    m = _TYPE_SPLIT_RE.match(type_)
+    assert m is not None, 'Invalid NIR type string: "{}"'.format(type_)
+    assert m.group('bits') is not None, \
+           'NIR type string has no bit size: "{}"'.format(type_)
+    return int(m.group('bits'))
+
+def type_sizes(type_):
+    if type_has_size(type_):
+        return [type_size(type_)]
+    elif type_ == 'float':
+        return [16, 32, 64]
+    else:
+        return [8, 16, 32, 64]
+
+def type_base_type(type_):
+    m = _TYPE_SPLIT_RE.match(type_)
+    assert m is not None, 'Invalid NIR type string: "{}"'.format(type_)
+    return m.group('type')
+
 commutative = "commutative "
 associative = "associative "
 
@@ -175,11 +203,7 @@ for src_t in [tint, tuint, tfloat]:
       dst_types = [tint, tuint, tfloat]
 
    for dst_t in dst_types:
-      if dst_t == tfloat:
-         bit_sizes = [16, 32, 64]
-      else:
-         bit_sizes = [8, 16, 32, 64]
-      for bit_size in bit_sizes:
+      for bit_size in type_sizes(dst_t):
           if bit_size == 16 and dst_t == tfloat and src_t == tfloat:
               rnd_modes = ['_rtne', '_rtz', '']
               for rnd_mode in rnd_modes:
index 8bfcda6d719cf3abd27bdc3e8267c9280677e54c..9e3c06b8634c88d6bc27ebb80bf9d5c4394f5847 100644 (file)
@@ -25,7 +25,7 @@
 
 from __future__ import print_function
 
-from nir_opcodes import opcodes
+from nir_opcodes import opcodes, type_sizes
 from mako.template import Template
 
 template = Template("""
@@ -64,12 +64,7 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, nir_rounding_mode rnd
 %                 endif
 %              endif
                switch (dst_bit_size) {
-%                 if dst_t == 'float':
-<%                    bit_sizes = [16, 32, 64] %>
-%                 else:
-<%                    bit_sizes = [8, 16, 32, 64] %>
-%                 endif
-%                 for dst_bits in bit_sizes:
+%                 for dst_bits in type_sizes(dst_t):
                   case ${dst_bits}:
 %                    if src_t == 'float' and dst_t == 'float' and dst_bits == 16:
                      switch(rnd) {
@@ -137,4 +132,4 @@ const nir_op_info nir_op_infos[nir_num_opcodes] = {
 };
 """)
 
-print(template.render(opcodes=opcodes))
+print(template.render(opcodes=opcodes, type_sizes=type_sizes))