nir: Allow uniform in nir_lower_vars_to_explicit_types
[mesa.git] / src / compiler / nir / nir_opcodes.py
index f18668493b22509bd6267b4b26b7702b5a8086e0..e19d7b00a7d3e956a57334ad32b674dfbc9f28ad 100644 (file)
@@ -77,7 +77,7 @@ class Opcode(object):
       assert len(input_sizes) == len(input_types)
       assert 0 <= output_size <= 4 or (output_size == 8) or (output_size == 16)
       for size in input_sizes:
-         assert 0 <= size <= 4
+         assert 0 <= size <= 4 or (size == 8) or (size == 16)
          if output_size != 0:
             assert size != 0
       self.name = name
@@ -544,19 +544,18 @@ def binop_reduce(name, output_size, output_type, src_type, prereduce_expr,
       return reduce_expr.format(src0=src0, src1=src1)
    def prereduce(src0, src1):
       return "(" + prereduce_expr.format(src0=src0, src1=src1) + ")"
-   src0 = prereduce("src0.x", "src1.x")
-   src1 = prereduce("src0.y", "src1.y")
-   src2 = prereduce("src0.z", "src1.z")
-   src3 = prereduce("src0.w", "src1.w")
-   opcode(name + "2", output_size, output_type,
-          [2, 2], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(src0, src1)))
+   srcs = [prereduce("src0." + letter, "src1." + letter) for letter in "xyzwefghijklmnop"]
+   def pairwise_reduce(start, size):
+      if (size == 1):
+         return srcs[start]
+      return reduce_(pairwise_reduce(start, size // 2), pairwise_reduce(start + size // 2, size // 2))
+   for size in [2, 4, 8, 16]:
+      opcode(name + str(size), output_size, output_type,
+             [size, size], [src_type, src_type], False, _2src_commutative,
+             final(pairwise_reduce(0, size)))
    opcode(name + "3", output_size, output_type,
           [3, 3], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(reduce_(src0, src1), src2)))
-   opcode(name + "4", output_size, output_type,
-          [4, 4], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(reduce_(src0, src1), reduce_(src2, src3))))
+          final(reduce_(reduce_(srcs[0], srcs[1]), srcs[2])))
 
 def binop_reduce_all_sizes(name, output_size, src_type, prereduce_expr,
                            reduce_expr, final_expr):
@@ -748,7 +747,7 @@ binop("frem", tfloat, "", "src0 - src1 * truncf(src0 / src1)")
 binop_compare_all_sizes("flt", tfloat, "", "src0 < src1")
 binop_compare_all_sizes("fge", tfloat, "", "src0 >= src1")
 binop_compare_all_sizes("feq", tfloat, _2src_commutative, "src0 == src1")
-binop_compare_all_sizes("fne", tfloat, _2src_commutative, "src0 != src1")
+binop_compare_all_sizes("fneu", tfloat, _2src_commutative, "src0 != src1")
 binop_compare_all_sizes("ilt", tint, "", "src0 < src1")
 binop_compare_all_sizes("ige", tint, "", "src0 >= src1")
 binop_compare_all_sizes("ieq", tint, _2src_commutative, "src0 == src1")
@@ -786,7 +785,7 @@ binop("sne", tfloat32, _2src_commutative, "(src0 != src1) ? 1.0f : 0.0f") # Set
 # but SM5 shifts are defined to use the least significant bits, only
 # The NIR definition is according to the SM5 specification.
 opcode("ishl", 0, tint, [0, 0], [tint, tuint32], False, "",
-       "src0 << (src1 & (sizeof(src0) * 8 - 1))")
+       "(uint64_t)src0 << (src1 & (sizeof(src0) * 8 - 1))")
 opcode("ishr", 0, tint, [0, 0], [tint, tuint32], False, "",
        "src0 >> (src1 & (sizeof(src0) * 8 - 1))")
 opcode("ushr", 0, tuint, [0, 0], [tuint, tuint32], False, "",
@@ -951,22 +950,8 @@ triop("flrp", tfloat, "", "src0 * (1 - src2) + src1 * src2")
 # component on vectors). There are two versions, one for floating point
 # bools (0.0 vs 1.0) and one for integer bools (0 vs ~0).
 
-
 triop("fcsel", tfloat32, "", "(src0 != 0.0f) ? src1 : src2")
 
-# 3 way min/max/med
-triop("fmin3", tfloat, "", "fminf(src0, fminf(src1, src2))")
-triop("imin3", tint, "", "MIN2(src0, MIN2(src1, src2))")
-triop("umin3", tuint, "", "MIN2(src0, MIN2(src1, src2))")
-
-triop("fmax3", tfloat, "", "fmaxf(src0, fmaxf(src1, src2))")
-triop("imax3", tint, "", "MAX2(src0, MAX2(src1, src2))")
-triop("umax3", tuint, "", "MAX2(src0, MAX2(src1, src2))")
-
-triop("fmed3", tfloat, "", "fmaxf(fminf(fmaxf(src0, src1), src2), fminf(src0, src1))")
-triop("imed3", tint, "", "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))")
-triop("umed3", tuint, "", "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))")
-
 opcode("bcsel", 0, tuint, [0, 0, 0],
        [tbool1, tuint, tuint], False, "", "src0 ? src1 : src2")
 opcode("b8csel", 0, tuint, [0, 0, 0],
@@ -1155,3 +1140,6 @@ triop("umad24", tuint32, _2src_commutative,
 # unsigned 24b multiply into 32b result uint
 binop("umul24", tint32, _2src_commutative + associative,
       "(((uint32_t)src0 << 8) >> 8) * (((uint32_t)src1 << 8) >> 8)")
+
+unop_convert("fisnormal", tbool1, tfloat, "isnormal(src0)")
+unop_convert("fisfinite", tbool1, tfloat, "isfinite(src0)")