nir: Add fisnormal op
[mesa.git] / src / compiler / nir / nir_opcodes.py
index 142d7a427d553a358f4500c86974206362235c8d..d880c51eebce7d0af7b60796f156128e31501303 100644 (file)
@@ -77,7 +77,7 @@ class Opcode(object):
       assert len(input_sizes) == len(input_types)
       assert 0 <= output_size <= 4 or (output_size == 8) or (output_size == 16)
       for size in input_sizes:
-         assert 0 <= size <= 4
+         assert 0 <= size <= 4 or (size == 8) or (size == 16)
          if output_size != 0:
             assert size != 0
       self.name = name
@@ -100,6 +100,7 @@ tbool16 = "bool16"
 tbool32 = "bool32"
 tuint = "uint"
 tuint8 = "uint8"
+tint16 = "int16"
 tuint16 = "uint16"
 tfloat16 = "float16"
 tfloat32 = "float32"
@@ -267,11 +268,13 @@ for src_t in [tint, tuint, tfloat, tbool]:
                                                        dst_bit_size),
                                    dst_t + str(dst_bit_size), src_t, conv_expr)
 
-# Special opcode that is the same as f2f16 except that it is safe to remove it
-# if the result is immediately converted back to float32 again. This is
-# generated as part of the precision lowering pass. mp stands for medium
+# Special opcode that is the same as f2f16, i2i16, u2u16 except that it is safe
+# to remove it if the result is immediately converted back to 32 bits again.
+# This is generated as part of the precision lowering pass. mp stands for medium
 # precision.
 unop_numeric_convert("f2fmp", tfloat16, tfloat, opcodes["f2f16"].const_expr)
+unop_numeric_convert("i2imp", tint16, tint, opcodes["i2i16"].const_expr)
+unop_numeric_convert("u2ump", tuint16, tuint, opcodes["u2u16"].const_expr)
 
 # Unary floating-point rounding operations.
 
@@ -484,8 +487,8 @@ if (src0.y < 0 && absY >= absX && absY >= absZ) { dst.x = src0.x; dst.y = -src0.
 if (src0.z >= 0 && absZ >= absX && absZ >= absY) { dst.x = src0.x; dst.y = -src0.y; }
 if (src0.z < 0 && absZ >= absX && absZ >= absY) { dst.x = -src0.x; dst.y = -src0.y; }
 
-dst.x = dst.x / ma + 0.5;
-dst.y = dst.y / ma + 0.5;
+dst.x = dst.x * (1.0f / ma) + 0.5f;
+dst.y = dst.y * (1.0f / ma) + 0.5f;
 """)
 
 unop_horiz("cube_face_index", 1, tfloat32, 3, tfloat32, """
@@ -541,19 +544,18 @@ def binop_reduce(name, output_size, output_type, src_type, prereduce_expr,
       return reduce_expr.format(src0=src0, src1=src1)
    def prereduce(src0, src1):
       return "(" + prereduce_expr.format(src0=src0, src1=src1) + ")"
-   src0 = prereduce("src0.x", "src1.x")
-   src1 = prereduce("src0.y", "src1.y")
-   src2 = prereduce("src0.z", "src1.z")
-   src3 = prereduce("src0.w", "src1.w")
-   opcode(name + "2", output_size, output_type,
-          [2, 2], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(src0, src1)))
+   srcs = [prereduce("src0." + letter, "src1." + letter) for letter in "xyzwefghijklmnop"]
+   def pairwise_reduce(start, size):
+      if (size == 1):
+         return srcs[start]
+      return reduce_(pairwise_reduce(start, size // 2), pairwise_reduce(start + size // 2, size // 2))
+   for size in [2, 4, 8, 16]:
+      opcode(name + str(size), output_size, output_type,
+             [size, size], [src_type, src_type], False, _2src_commutative,
+             final(pairwise_reduce(0, size)))
    opcode(name + "3", output_size, output_type,
           [3, 3], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(reduce_(src0, src1), src2)))
-   opcode(name + "4", output_size, output_type,
-          [4, 4], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(reduce_(src0, src1), reduce_(src2, src3))))
+          final(reduce_(reduce_(srcs[0], srcs[1]), srcs[2])))
 
 def binop_reduce_all_sizes(name, output_size, src_type, prereduce_expr,
                            reduce_expr, final_expr):
@@ -1152,3 +1154,5 @@ triop("umad24", tuint32, _2src_commutative,
 # unsigned 24b multiply into 32b result uint
 binop("umul24", tint32, _2src_commutative + associative,
       "(((uint32_t)src0 << 8) >> 8) * (((uint32_t)src1 << 8) >> 8)")
+
+unop_convert("fisnormal", tbool1, tfloat, "isnormal(src0)")