nir: rename variables in nir_lower_io_to_temporaries for clarity
[mesa.git] / src / compiler / nir / nir_constant_expressions.py
index 32784f6398d5d184198cc444bfb6977bc5147906..ee92be51dbeb55d84e22f3fb58f65285bee520f0 100644 (file)
@@ -1,4 +1,53 @@
-#! /usr/bin/python2
+
+import re
+
+type_split_re = re.compile(r'(?P<type>[a-z]+)(?P<bits>\d+)')
+
+def type_has_size(type_):
+    return type_[-1:].isdigit()
+
+def type_size(type_):
+    assert type_has_size(type_)
+    return int(type_split_re.match(type_).group('bits'))
+
+def type_sizes(type_):
+    if type_has_size(type_):
+        return [type_size(type_)]
+    elif type_ == 'float':
+        return [16, 32, 64]
+    else:
+        return [8, 16, 32, 64]
+
+def type_add_size(type_, size):
+    if type_has_size(type_):
+        return type_
+    return type_ + str(size)
+
+def op_bit_sizes(op):
+    sizes = None
+    if not type_has_size(op.output_type):
+        sizes = set(type_sizes(op.output_type))
+
+    for input_type in op.input_types:
+        if not type_has_size(input_type):
+            if sizes is None:
+                sizes = set(type_sizes(input_type))
+            else:
+                sizes = sizes.intersection(set(type_sizes(input_type)))
+
+    return sorted(list(sizes)) if sizes is not None else None
+
+def get_const_field(type_):
+    if type_ == "bool32":
+        return "u32"
+    elif type_ == "float16":
+        return "u16"
+    else:
+        m = type_split_re.match(type_)
+        if not m:
+            raise Exception(str(type_))
+        return m.group('type')[0] + m.group('bits')
+
 template = """\
 /*
  * Copyright (C) 2014 Intel Corporation
@@ -205,20 +254,33 @@ unpack_half_1x16(uint16_t u)
 }
 
 /* Some typed vector structures to make things like src0.y work */
-% for type in ["float", "int", "uint", "bool"]:
-struct ${type}_vec {
-   ${type} x;
-   ${type} y;
-   ${type} z;
-   ${type} w;
+typedef float float16_t;
+typedef float float32_t;
+typedef double float64_t;
+typedef bool bool32_t;
+% for type in ["float", "int", "uint"]:
+% for width in type_sizes(type):
+struct ${type}${width}_vec {
+   ${type}${width}_t x;
+   ${type}${width}_t y;
+   ${type}${width}_t z;
+   ${type}${width}_t w;
 };
 % endfor
+% endfor
 
-% for name, op in sorted(opcodes.iteritems()):
-static nir_const_value
-evaluate_${name}(unsigned num_components, nir_const_value *_src)
-{
-   nir_const_value _dst_val = { { {0, 0, 0, 0} } };
+struct bool32_vec {
+    bool x;
+    bool y;
+    bool z;
+    bool w;
+};
+
+<%def name="evaluate_op(op, bit_size)">
+   <%
+   output_type = type_add_size(op.output_type, bit_size)
+   input_types = [type_add_size(type_, bit_size) for type_ in op.input_types]
+   %>
 
    ## For each non-per-component input, create a variable srcN that
    ## contains x, y, z, and w elements which are filled in with the
@@ -231,14 +293,19 @@ evaluate_${name}(unsigned num_components, nir_const_value *_src)
          <% continue %>
       %endif
 
-      struct ${op.input_types[j]}_vec src${j} = {
+      const struct ${input_types[j]}_vec src${j} = {
       % for k in range(op.input_sizes[j]):
-         % if op.input_types[j] == "bool":
-            _src[${j}].u[${k}] != 0,
+         % if input_types[j] == "bool32":
+            _src[${j}].u32[${k}] != 0,
+         % elif input_types[j] == "float16":
+            _mesa_half_to_float(_src[${j}].u16[${k}]),
          % else:
-            _src[${j}].${op.input_types[j][:1]}[${k}],
+            _src[${j}].${get_const_field(input_types[j])}[${k}],
          % endif
       % endfor
+      % for k in range(op.input_sizes[j], 4):
+         0,
+      % endfor
       };
    % endfor
 
@@ -255,10 +322,14 @@ evaluate_${name}(unsigned num_components, nir_const_value *_src)
             % elif "src" + str(j) not in op.const_expr:
                ## Avoid unused variable warnings
                <% continue %>
-            % elif op.input_types[j] == "bool":
-               bool src${j} = _src[${j}].u[_i] != 0;
+            % elif input_types[j] == "bool32":
+               const bool src${j} = _src[${j}].u32[_i] != 0;
+            % elif input_types[j] == "float16":
+               const float src${j} =
+                  _mesa_half_to_float(_src[${j}].u16[_i]);
             % else:
-               ${op.input_types[j]} src${j} = _src[${j}].${op.input_types[j][:1]}[_i];
+               const ${input_types[j]}_t src${j} =
+                  _src[${j}].${get_const_field(input_types[j])}[_i];
             % endif
          % endfor
 
@@ -266,19 +337,22 @@ evaluate_${name}(unsigned num_components, nir_const_value *_src)
          ## result of the const_expr to it.  If const_expr already contains
          ## writes to dst, just include const_expr directly.
          % if "dst" in op.const_expr:
-            ${op.output_type} dst;
+            ${output_type}_t dst;
+
             ${op.const_expr}
          % else:
-            ${op.output_type} dst = ${op.const_expr};
+            ${output_type}_t dst = ${op.const_expr};
          % endif
 
          ## Store the current component of the actual destination to the
          ## value of dst.
-         % if op.output_type == "bool":
+         % if output_type == "bool32":
             ## Sanitize the C value to a proper NIR bool
-            _dst_val.u[_i] = dst ? NIR_TRUE : NIR_FALSE;
+            _dst_val.u32[_i] = dst ? NIR_TRUE : NIR_FALSE;
+         % elif output_type == "float16":
+            _dst_val.u16[_i] = _mesa_float_to_half(dst);
          % else:
-            _dst_val.${op.output_type[:1]}[_i] = dst;
+            _dst_val.${get_const_field(output_type)}[_i] = dst;
          % endif
       }
    % else:
@@ -286,7 +360,7 @@ evaluate_${name}(unsigned num_components, nir_const_value *_src)
       ## appropriately-typed elements x, y, z, and w and assign the result
       ## of the const_expr to all components of dst, or include the
       ## const_expr directly if it writes to dst already.
-      struct ${op.output_type}_vec dst;
+      struct ${output_type}_vec dst;
 
       % if "dst" in op.const_expr:
          ${op.const_expr}
@@ -301,14 +375,41 @@ evaluate_${name}(unsigned num_components, nir_const_value *_src)
       ## For each component in the destination, copy the value of dst to
       ## the actual destination.
       % for k in range(op.output_size):
-         % if op.output_type == "bool":
+         % if output_type == "bool32":
             ## Sanitize the C value to a proper NIR bool
-            _dst_val.u[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE : NIR_FALSE;
+            _dst_val.u32[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE : NIR_FALSE;
+         % elif output_type == "float16":
+            _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzw"[k]});
          % else:
-            _dst_val.${op.output_type[:1]}[${k}] = dst.${"xyzw"[k]};
+            _dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzw"[k]};
          % endif
       % endfor
    % endif
+</%def>
+
+% for name, op in sorted(opcodes.iteritems()):
+static nir_const_value
+evaluate_${name}(MAYBE_UNUSED unsigned num_components,
+                 ${"UNUSED" if op_bit_sizes(op) is None else ""} unsigned bit_size,
+                 MAYBE_UNUSED nir_const_value *_src)
+{
+   nir_const_value _dst_val = { {0, } };
+
+   % if op_bit_sizes(op) is not None:
+      switch (bit_size) {
+      % for bit_size in op_bit_sizes(op):
+      case ${bit_size}: {
+         ${evaluate_op(op, bit_size)}
+         break;
+      }
+      % endfor
+
+      default:
+         unreachable("unknown bit width");
+      }
+   % else:
+      ${evaluate_op(op, 0)}
+   % endif
 
    return _dst_val;
 }
@@ -316,14 +417,12 @@ evaluate_${name}(unsigned num_components, nir_const_value *_src)
 
 nir_const_value
 nir_eval_const_opcode(nir_op op, unsigned num_components,
-                      nir_const_value *src)
+                      unsigned bit_width, nir_const_value *src)
 {
    switch (op) {
 % for name in sorted(opcodes.iterkeys()):
-   case nir_op_${name}: {
-      return evaluate_${name}(num_components, src);
-      break;
-   }
+   case nir_op_${name}:
+      return evaluate_${name}(num_components, bit_width, src);
 % endfor
    default:
       unreachable("shouldn't get here");
@@ -333,4 +432,8 @@ nir_eval_const_opcode(nir_op op, unsigned num_components,
 from nir_opcodes import opcodes
 from mako.template import Template
 
-print Template(template).render(opcodes=opcodes)
+print Template(template).render(opcodes=opcodes, type_sizes=type_sizes,
+                                type_has_size=type_has_size,
+                                type_add_size=type_add_size,
+                                op_bit_sizes=op_bit_sizes,
+                                get_const_field=get_const_field)