nir/lower_int64: Add support for [iu]mul_high
authorJason Ekstrand <jason.ekstrand@intel.com>
Fri, 29 Dec 2017 22:38:55 +0000 (14:38 -0800)
committerJason Ekstrand <jason@jlekstrand.net>
Thu, 13 Dec 2018 17:49:48 +0000 (17:49 +0000)
Reviewed-by: Ian Romanick ian.d.romanick@intel.com
src/compiler/nir/nir.h
src/compiler/nir/nir_lower_int64.c

index 5d9c96fe11e693e623a9ef2b750ed2fb303709c2..a2c68d66aea41d6a4c6ae78d546004defab6be68 100644 (file)
@@ -3106,6 +3106,8 @@ typedef enum {
    nir_lower_isign64 = (1 << 1),
    /** Lower all int64 modulus and division opcodes */
    nir_lower_divmod64 = (1 << 2),
+   /** Lower all 64-bit umul_high and imul_high opcodes */
+   nir_lower_imul_high64 = (1 << 3),
 } nir_lower_int64_options;
 
 bool nir_lower_int64(nir_shader *shader, nir_lower_int64_options options);
index 81669c02cc61dbf6faddc5941c719aeb0e2510c8..2a9ea3e1bddbda62ca0a9437dc833a945ce002d8 100644 (file)
@@ -40,6 +40,64 @@ lower_imul64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
    return nir_pack_64_2x32_split(b, res_lo, res_hi);
 }
 
+static nir_ssa_def *
+lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
+                 bool sign_extend)
+{
+   nir_ssa_def *x32[4], *y32[4];
+   x32[0] = nir_unpack_64_2x32_split_x(b, x);
+   x32[1] = nir_unpack_64_2x32_split_y(b, x);
+   if (sign_extend) {
+      x32[2] = x32[3] = nir_ishr(b, x32[1], nir_imm_int(b, 31));
+   } else {
+      x32[2] = x32[3] = nir_imm_int(b, 0);
+   }
+
+   y32[0] = nir_unpack_64_2x32_split_x(b, y);
+   y32[1] = nir_unpack_64_2x32_split_y(b, y);
+   if (sign_extend) {
+      y32[2] = y32[3] = nir_ishr(b, y32[1], nir_imm_int(b, 31));
+   } else {
+      y32[2] = y32[3] = nir_imm_int(b, 0);
+   }
+
+   nir_ssa_def *res[8] = { NULL, };
+
+   /* Yes, the following generates a pile of code.  However, we throw res[0]
+    * and res[1] away in the end and, if we're in the umul case, four of our
+    * eight dword operands will be constant zero and opt_algebraic will clean
+    * this up nicely.
+    */
+   for (unsigned i = 0; i < 4; i++) {
+      nir_ssa_def *carry = NULL;
+      for (unsigned j = 0; j < 4; j++) {
+         /* The maximum values of x32[i] and y32[i] are UINT32_MAX so the
+          * maximum value of tmp is UINT32_MAX * UINT32_MAX.  The maximum
+          * value that will fit in tmp is
+          *
+          *    UINT64_MAX = UINT32_MAX << 32 + UINT32_MAX
+          *               = UINT32_MAX * (UINT32_MAX + 1) + UINT32_MAX
+          *               = UINT32_MAX * UINT32_MAX + 2 * UINT32_MAX
+          *
+          * so we're guaranteed that we can add in two more 32-bit values
+          * without overflowing tmp.
+          */
+         nir_ssa_def *tmp =
+            nir_pack_64_2x32_split(b, nir_imul(b, x32[i], y32[j]),
+                                      nir_umul_high(b, x32[i], y32[j]));
+         if (res[i + j])
+            tmp = nir_iadd(b, tmp, nir_u2u64(b, res[i + j]));
+         if (carry)
+            tmp = nir_iadd(b, tmp, carry);
+         res[i + j] = nir_u2u32(b, tmp);
+         carry = nir_ushr(b, tmp, nir_imm_int(b, 32));
+      }
+      res[i + 4] = nir_u2u32(b, carry);
+   }
+
+   return nir_pack_64_2x32_split(b, res[2], res[3]);
+}
+
 static nir_ssa_def *
 lower_isign64(nir_builder *b, nir_ssa_def *x)
 {
@@ -209,6 +267,9 @@ opcode_to_options_mask(nir_op opcode)
    switch (opcode) {
    case nir_op_imul:
       return nir_lower_imul64;
+   case nir_op_imul_high:
+   case nir_op_umul_high:
+      return nir_lower_imul_high64;
    case nir_op_isign:
       return nir_lower_isign64;
    case nir_op_udiv:
@@ -232,6 +293,10 @@ lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu)
    switch (alu->op) {
    case nir_op_imul:
       return lower_imul64(b, src[0], src[1]);
+   case nir_op_imul_high:
+      return lower_mul_high64(b, src[0], src[1], true);
+   case nir_op_umul_high:
+      return lower_mul_high64(b, src[0], src[1], false);
    case nir_op_isign:
       return lower_isign64(b, src[0]);
    case nir_op_udiv: