assert len(input_sizes) == len(input_types)
assert 0 <= output_size <= 4 or (output_size == 8) or (output_size == 16)
for size in input_sizes:
- assert 0 <= size <= 4
+ assert 0 <= size <= 4 or (size == 8) or (size == 16)
if output_size != 0:
assert size != 0
self.name = name
tbool32 = "bool32"
tuint = "uint"
tuint8 = "uint8"
+tint16 = "int16"
tuint16 = "uint16"
tfloat16 = "float16"
tfloat32 = "float32"
dst_bit_size),
dst_t + str(dst_bit_size), src_t, conv_expr)
-# Special opcode that is the same as f2f16 except that it is safe to remove it
-# if the result is immediately converted back to float32 again. This is
-# generated as part of the precision lowering pass. mp stands for medium
+# Special opcodes that are the same as f2f16, i2i16 and u2u16 except that they
+# are safe to remove if the result is immediately converted back to 32 bits.
+# These are generated by the precision lowering pass. mp stands for medium
# precision.
unop_numeric_convert("f2fmp", tfloat16, tfloat, opcodes["f2f16"].const_expr)
+unop_numeric_convert("i2imp", tint16, tint, opcodes["i2i16"].const_expr)
+unop_numeric_convert("u2ump", tuint16, tuint, opcodes["u2u16"].const_expr)
# Unary floating-point rounding operations.
if (src0.z >= 0 && absZ >= absX && absZ >= absY) { dst.x = src0.x; dst.y = -src0.y; }
if (src0.z < 0 && absZ >= absX && absZ >= absY) { dst.x = -src0.x; dst.y = -src0.y; }
-dst.x = dst.x / ma + 0.5;
-dst.y = dst.y / ma + 0.5;
+dst.x = dst.x * (1.0f / ma) + 0.5f;
+dst.y = dst.y * (1.0f / ma) + 0.5f;
""")
unop_horiz("cube_face_index", 1, tfloat32, 3, tfloat32, """
return reduce_expr.format(src0=src0, src1=src1)
def prereduce(src0, src1):
return "(" + prereduce_expr.format(src0=src0, src1=src1) + ")"
- src0 = prereduce("src0.x", "src1.x")
- src1 = prereduce("src0.y", "src1.y")
- src2 = prereduce("src0.z", "src1.z")
- src3 = prereduce("src0.w", "src1.w")
- opcode(name + "2", output_size, output_type,
- [2, 2], [src_type, src_type], False, _2src_commutative,
- final(reduce_(src0, src1)))
+ srcs = [prereduce("src0." + letter, "src1." + letter) for letter in "xyzwefghijklmnop"]
+ def pairwise_reduce(start, size):
+ if (size == 1):
+ return srcs[start]
+ return reduce_(pairwise_reduce(start, size // 2), pairwise_reduce(start + size // 2, size // 2))
+ for size in [2, 4, 8, 16]:
+ opcode(name + str(size), output_size, output_type,
+ [size, size], [src_type, src_type], False, _2src_commutative,
+ final(pairwise_reduce(0, size)))
opcode(name + "3", output_size, output_type,
[3, 3], [src_type, src_type], False, _2src_commutative,
- final(reduce_(reduce_(src0, src1), src2)))
- opcode(name + "4", output_size, output_type,
- [4, 4], [src_type, src_type], False, _2src_commutative,
- final(reduce_(reduce_(src0, src1), reduce_(src2, src3))))
+ final(reduce_(reduce_(srcs[0], srcs[1]), srcs[2])))
def binop_reduce_all_sizes(name, output_size, src_type, prereduce_expr,
reduce_expr, final_expr):
# unsigned 24b multiply into 32b result uint
binop("umul24", tint32, _2src_commutative + associative,
"(((uint32_t)src0 << 8) >> 8) * (((uint32_t)src1 << 8) >> 8)")
+
+unop_convert("fisnormal", tbool1, tfloat, "isnormal(src0)")