intel/fs: add 64 bit integer multiplication lowering
[mesa.git] / src / intel / compiler / brw_fs.cpp
index ccf6c95520201f9e519206fca7b4fd6a12f2dc09..a4ebae336f5dea29ecd4e35bcff4656e4e7fe72a 100644 (file)
@@ -3990,6 +3990,62 @@ fs_visitor::lower_mul_dword_inst(fs_inst *inst, bblock_t *block)
    }
 }
 
+void
+fs_visitor::lower_mul_qword_inst(fs_inst *inst, bblock_t *block)
+{
+   const fs_builder ibld(this, block, inst);
+
+   /* Considering two 64-bit integers ab and cd where each letter        ab
+    * corresponds to 32 bits, we get a 128-bit result WXYZ. We         * cd
+    * only need to provide the YZ part of the result.               -------
+    *                                                                    BD
+    *  Only BD needs to be 64 bits. For AD and BC we only care       +  AD
+    *  about the lower 32 bits (since they are part of the upper     +  BC
+    *  32 bits of our result). AC is not needed since it starts      + AC
+    *  on the 65th bit of the result.                               -------
+    *                                                                  WXYZ
+    */
+   unsigned int q_regs = regs_written(inst);
+   unsigned int d_regs = (q_regs + 1) / 2;
+
+   fs_reg bd(VGRF, alloc.allocate(q_regs), BRW_REGISTER_TYPE_UQ);
+   fs_reg ad(VGRF, alloc.allocate(d_regs), BRW_REGISTER_TYPE_UD);
+   fs_reg bc(VGRF, alloc.allocate(d_regs), BRW_REGISTER_TYPE_UD);
+
+   /* Here we need the full 64 bit result for 32b * 32b. */
+   if (devinfo->has_integer_dword_mul) {
+      ibld.MUL(bd, subscript(inst->src[0], BRW_REGISTER_TYPE_UD, 0),
+               subscript(inst->src[1], BRW_REGISTER_TYPE_UD, 0));
+   } else {
+      fs_reg bd_high(VGRF, alloc.allocate(d_regs), BRW_REGISTER_TYPE_UD);
+      fs_reg bd_low(VGRF, alloc.allocate(d_regs), BRW_REGISTER_TYPE_UD);
+      fs_reg acc = retype(brw_acc_reg(inst->exec_size), BRW_REGISTER_TYPE_UD);
+
+      fs_inst *mul = ibld.MUL(acc,
+                            subscript(inst->src[0], BRW_REGISTER_TYPE_UD, 0),
+                            subscript(inst->src[1], BRW_REGISTER_TYPE_UW, 0));
+      mul->writes_accumulator = true;
+
+      ibld.MACH(bd_high, subscript(inst->src[0], BRW_REGISTER_TYPE_UD, 0),
+                subscript(inst->src[1], BRW_REGISTER_TYPE_UD, 0));
+      ibld.MOV(bd_low, acc);
+
+      ibld.MOV(subscript(bd, BRW_REGISTER_TYPE_UD, 0), bd_low);
+      ibld.MOV(subscript(bd, BRW_REGISTER_TYPE_UD, 1), bd_high);
+   }
+
+   ibld.MUL(ad, subscript(inst->src[0], BRW_REGISTER_TYPE_UD, 1),
+            subscript(inst->src[1], BRW_REGISTER_TYPE_UD, 0));
+   ibld.MUL(bc, subscript(inst->src[0], BRW_REGISTER_TYPE_UD, 0),
+            subscript(inst->src[1], BRW_REGISTER_TYPE_UD, 1));
+
+   ibld.ADD(ad, ad, bc);
+   ibld.ADD(subscript(bd, BRW_REGISTER_TYPE_UD, 1),
+            subscript(bd, BRW_REGISTER_TYPE_UD, 1), ad);
+
+   ibld.MOV(inst->dst, bd);
+}
+
 void
 fs_visitor::lower_mulh_inst(fs_inst *inst, bblock_t *block)
 {
@@ -4062,10 +4118,19 @@ fs_visitor::lower_integer_multiplication()
 
    foreach_block_and_inst_safe(block, fs_inst, inst, cfg) {
       if (inst->opcode == BRW_OPCODE_MUL) {
-         if (!inst->dst.is_accumulator() &&
-             (inst->dst.type == BRW_REGISTER_TYPE_D ||
-              inst->dst.type == BRW_REGISTER_TYPE_UD) &&
-             !devinfo->has_integer_dword_mul) {
+         if ((inst->dst.type == BRW_REGISTER_TYPE_Q ||
+              inst->dst.type == BRW_REGISTER_TYPE_UQ) &&
+             (inst->src[0].type == BRW_REGISTER_TYPE_Q ||
+              inst->src[0].type == BRW_REGISTER_TYPE_UQ) &&
+             (inst->src[1].type == BRW_REGISTER_TYPE_Q ||
+              inst->src[1].type == BRW_REGISTER_TYPE_UQ)) {
+            lower_mul_qword_inst(inst, block);
+            inst->remove(block);
+            progress = true;
+         } else if (!inst->dst.is_accumulator() &&
+                    (inst->dst.type == BRW_REGISTER_TYPE_D ||
+                     inst->dst.type == BRW_REGISTER_TYPE_UD) &&
+                    !devinfo->has_integer_dword_mul) {
             lower_mul_dword_inst(inst, block);
             inst->remove(block);
             progress = true;