nir: Add fisnormal op
[mesa.git] / src / compiler / nir / nir_opcodes.py
index 57be67320c53840714158792bf61a1bdafc5b7ba..d880c51eebce7d0af7b60796f156128e31501303 100644 (file)
@@ -77,7 +77,7 @@ class Opcode(object):
       assert len(input_sizes) == len(input_types)
       assert 0 <= output_size <= 4 or (output_size == 8) or (output_size == 16)
       for size in input_sizes:
-         assert 0 <= size <= 4
+         assert 0 <= size <= 4 or (size == 8) or (size == 16)
          if output_size != 0:
             assert size != 0
       self.name = name
@@ -99,6 +99,8 @@ tbool8 = "bool8"
 tbool16 = "bool16"
 tbool32 = "bool32"
 tuint = "uint"
+tuint8 = "uint8"
+tint16 = "int16"
 tuint16 = "uint16"
 tfloat16 = "float16"
 tfloat32 = "float32"
@@ -199,9 +201,9 @@ unop("fsign", tfloat, ("bit_size == 64 ? " +
 unop("isign", tint, "(src0 == 0) ? 0 : ((src0 > 0) ? 1 : -1)")
 unop("iabs", tint, "(src0 < 0) ? -src0 : src0")
 unop("fabs", tfloat, "fabs(src0)")
-unop("fsat", tfloat, ("bit_size == 64 ? " +
-                      "((src0 > 1.0) ? 1.0 : ((src0 <= 0.0) ? 0.0 : src0)) : " +
-                      "((src0 > 1.0f) ? 1.0f : ((src0 <= 0.0f) ? 0.0f : src0))"))
+unop("fsat", tfloat, ("fmin(fmax(src0, 0.0), 1.0)"))
+unop("fsat_signed", tfloat, ("fmin(fmax(src0, -1.0), 1.0)"))
+unop("fclamp_pos", tfloat, ("fmax(src0, 0.0)"))
 unop("frcp", tfloat, "bit_size == 64 ? 1.0 / src0 : 1.0f / src0")
 unop("frsq", tfloat, "bit_size == 64 ? 1.0 / sqrt(src0) : 1.0f / sqrtf(src0)")
 unop("fsqrt", tfloat, "bit_size == 64 ? sqrt(src0) : sqrtf(src0)")
@@ -266,11 +268,13 @@ for src_t in [tint, tuint, tfloat, tbool]:
                                                        dst_bit_size),
                                    dst_t + str(dst_bit_size), src_t, conv_expr)
 
-# Special opcode that is the same as f2f16 except that it is safe to remove it
-# if the result is immediately converted back to float32 again. This is
-# generated as part of the precision lowering pass. mp stands for medium
+# Special opcode that is the same as f2f16, i2i16, u2u16 except that it is safe
+# to remove it if the result is immediately converted back to 32 bits again.
+# This is generated as part of the precision lowering pass. mp stands for medium
 # precision.
 unop_numeric_convert("f2fmp", tfloat16, tfloat, opcodes["f2f16"].const_expr)
+unop_numeric_convert("i2imp", tint16, tint, opcodes["i2i16"].const_expr)
+unop_numeric_convert("u2ump", tuint16, tuint, opcodes["u2u16"].const_expr)
 
 # Unary floating-point rounding operations.
 
@@ -357,6 +361,9 @@ dst.x = (src0.x <<  0) |
         (src0.w << 24);
 """)
 
+unop_horiz("pack_32_4x8", 1, tuint32, 4, tuint8,
+           "dst.x = src0.x | ((uint32_t)src0.y << 8) | ((uint32_t)src0.z << 16) | ((uint32_t)src0.w << 24);")
+
 unop_horiz("pack_32_2x16", 1, tuint32, 2, tuint16,
            "dst.x = src0.x | ((uint32_t)src0.y << 16);")
 
@@ -375,6 +382,9 @@ unop_horiz("unpack_64_4x16", 4, tuint16, 1, tuint64,
 unop_horiz("unpack_32_2x16", 2, tuint16, 1, tuint32,
            "dst.x = src0.x; dst.y = src0.x >> 16;")
 
+unop_horiz("unpack_32_4x8", 4, tuint8, 1, tuint32,
+           "dst.x = src0.x; dst.y = src0.x >> 8; dst.z = src0.x >> 16; dst.w = src0.x >> 24;")
+
 unop_horiz("unpack_half_2x16_flush_to_zero", 2, tfloat32, 1, tuint32, """
 dst.x = unpack_half_1x16_flush_to_zero((uint16_t)(src0.x & 0xffff));
 dst.y = unpack_half_1x16_flush_to_zero((uint16_t)(src0.x << 16));
@@ -458,12 +468,6 @@ for (unsigned bit = 0; bit < bit_size; bit++) {
 }
 """)
 
-
-for i in range(1, 5):
-   for j in range(1, 5):
-      unop_horiz("fnoise{0}_{1}".format(i, j), i, tfloat, j, tfloat, "0.0f")
-
-
 # AMD_gcn_shader extended instructions
 unop_horiz("cube_face_coord", 2, tfloat32, 3, tfloat32, """
 dst.x = dst.y = 0.0;
@@ -483,8 +487,8 @@ if (src0.y < 0 && absY >= absX && absY >= absZ) { dst.x = src0.x; dst.y = -src0.
 if (src0.z >= 0 && absZ >= absX && absZ >= absY) { dst.x = src0.x; dst.y = -src0.y; }
 if (src0.z < 0 && absZ >= absX && absZ >= absY) { dst.x = -src0.x; dst.y = -src0.y; }
 
-dst.x = dst.x / ma + 0.5;
-dst.y = dst.y / ma + 0.5;
+dst.x = dst.x * (1.0f / ma) + 0.5f;
+dst.y = dst.y * (1.0f / ma) + 0.5f;
 """)
 
 unop_horiz("cube_face_index", 1, tfloat32, 3, tfloat32, """
@@ -540,19 +544,18 @@ def binop_reduce(name, output_size, output_type, src_type, prereduce_expr,
       return reduce_expr.format(src0=src0, src1=src1)
    def prereduce(src0, src1):
       return "(" + prereduce_expr.format(src0=src0, src1=src1) + ")"
-   src0 = prereduce("src0.x", "src1.x")
-   src1 = prereduce("src0.y", "src1.y")
-   src2 = prereduce("src0.z", "src1.z")
-   src3 = prereduce("src0.w", "src1.w")
-   opcode(name + "2", output_size, output_type,
-          [2, 2], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(src0, src1)))
+   srcs = [prereduce("src0." + letter, "src1." + letter) for letter in "xyzwefghijklmnop"]
+   def pairwise_reduce(start, size):
+      if (size == 1):
+         return srcs[start]
+      return reduce_(pairwise_reduce(start, size // 2), pairwise_reduce(start + size // 2, size // 2))
+   for size in [2, 4, 8, 16]:
+      opcode(name + str(size), output_size, output_type,
+             [size, size], [src_type, src_type], False, _2src_commutative,
+             final(pairwise_reduce(0, size)))
    opcode(name + "3", output_size, output_type,
           [3, 3], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(reduce_(src0, src1), src2)))
-   opcode(name + "4", output_size, output_type,
-          [4, 4], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(reduce_(src0, src1), reduce_(src2, src3))))
+          final(reduce_(reduce_(srcs[0], srcs[1]), srcs[2])))
 
 def binop_reduce_all_sizes(name, output_size, src_type, prereduce_expr,
                            reduce_expr, final_expr):
@@ -1143,3 +1146,13 @@ triop("imad24_ir3", tint32, _2src_commutative,
 # 24b multiply into 32b result (with sign extension)
 binop("imul24", tint32, _2src_commutative + associative,
       "(((int32_t)src0 << 8) >> 8) * (((int32_t)src1 << 8) >> 8)")
+
+# unsigned 24b multiply into 32b result plus 32b int
+triop("umad24", tuint32, _2src_commutative,
+      "(((uint32_t)src0 << 8) >> 8) * (((uint32_t)src1 << 8) >> 8) + src2")
+
+# unsigned 24b multiply into 32b result uint
+binop("umul24", tint32, _2src_commutative + associative,
+      "(((uint32_t)src0 << 8) >> 8) * (((uint32_t)src1 << 8) >> 8)")
+
+unop_convert("fisnormal", tbool1, tfloat, "isnormal(src0)")