nir: rename variables in nir_lower_io_to_temporaries for clarity
[mesa.git] / src / compiler / nir / nir_constant_expressions.py
index ad841e3d31118e2d535fda587c5f3d8b672a0982..ee92be51dbeb55d84e22f3fb58f65285bee520f0 100644 (file)
@@ -13,8 +13,10 @@ def type_size(type_):
 def type_sizes(type_):
     if type_has_size(type_):
         return [type_size(type_)]
+    elif type_ == 'float':
+        return [16, 32, 64]
     else:
-        return [32, 64]
+        return [8, 16, 32, 64]
 
 def type_add_size(type_, size):
     if type_has_size(type_):
@@ -22,17 +24,24 @@ def type_add_size(type_, size):
     return type_ + str(size)
 
 def op_bit_sizes(op):
-    sizes = set([8, 16, 32, 64])
+    sizes = None
     if not type_has_size(op.output_type):
-        sizes = sizes.intersection(set(type_sizes(op.output_type)))
+        sizes = set(type_sizes(op.output_type))
+
     for input_type in op.input_types:
         if not type_has_size(input_type):
-            sizes = sizes.intersection(set(type_sizes(input_type)))
-    return sorted(list(sizes))
+            if sizes is None:
+                sizes = set(type_sizes(input_type))
+            else:
+                sizes = sizes.intersection(set(type_sizes(input_type)))
+
+    return sorted(list(sizes)) if sizes is not None else None
 
 def get_const_field(type_):
     if type_ == "bool32":
         return "u32"
+    elif type_ == "float16":
+        return "u16"
     else:
         m = type_split_re.match(type_)
         if not m:
@@ -245,6 +254,7 @@ unpack_half_1x16(uint16_t u)
 }
 
 /* Some typed vector structures to make things like src0.y work */
+typedef float float16_t;
 typedef float float32_t;
 typedef double float64_t;
 typedef bool bool32_t;
@@ -287,6 +297,8 @@ struct bool32_vec {
       % for k in range(op.input_sizes[j]):
          % if input_types[j] == "bool32":
             _src[${j}].u32[${k}] != 0,
+         % elif input_types[j] == "float16":
+            _mesa_half_to_float(_src[${j}].u16[${k}]),
          % else:
             _src[${j}].${get_const_field(input_types[j])}[${k}],
          % endif
@@ -312,6 +324,9 @@ struct bool32_vec {
                <% continue %>
             % elif input_types[j] == "bool32":
                const bool src${j} = _src[${j}].u32[_i] != 0;
+            % elif input_types[j] == "float16":
+               const float src${j} =
+                  _mesa_half_to_float(_src[${j}].u16[_i]);
             % else:
                const ${input_types[j]}_t src${j} =
                   _src[${j}].${get_const_field(input_types[j])}[_i];
@@ -334,6 +349,8 @@ struct bool32_vec {
          % if output_type == "bool32":
             ## Sanitize the C value to a proper NIR bool
             _dst_val.u32[_i] = dst ? NIR_TRUE : NIR_FALSE;
+         % elif output_type == "float16":
+            _dst_val.u16[_i] = _mesa_float_to_half(dst);
          % else:
             _dst_val.${get_const_field(output_type)}[_i] = dst;
          % endif
@@ -361,6 +378,8 @@ struct bool32_vec {
          % if output_type == "bool32":
             ## Sanitize the C value to a proper NIR bool
             _dst_val.u32[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE : NIR_FALSE;
+         % elif output_type == "float16":
+            _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzw"[k]});
          % else:
             _dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzw"[k]};
          % endif
@@ -370,22 +389,27 @@ struct bool32_vec {
 
 % for name, op in sorted(opcodes.iteritems()):
 static nir_const_value
-evaluate_${name}(MAYBE_UNUSED unsigned num_components, unsigned bit_size,
+evaluate_${name}(MAYBE_UNUSED unsigned num_components,
+                 ${"UNUSED" if op_bit_sizes(op) is None else ""} unsigned bit_size,
                  MAYBE_UNUSED nir_const_value *_src)
 {
    nir_const_value _dst_val = { {0, } };
 
-   switch (bit_size) {
-   % for bit_size in op_bit_sizes(op):
-   case ${bit_size}: {
-      ${evaluate_op(op, bit_size)}
-      break;
-   }
-   % endfor
+   % if op_bit_sizes(op) is not None:
+      switch (bit_size) {
+      % for bit_size in op_bit_sizes(op):
+      case ${bit_size}: {
+         ${evaluate_op(op, bit_size)}
+         break;
+      }
+      % endfor
 
-   default:
-      unreachable("unknown bit width");
-   }
+      default:
+         unreachable("unknown bit width");
+      }
+   % else:
+      ${evaluate_op(op, 0)}
+   % endif
 
    return _dst_val;
 }