aco: optimize some masked swizzles to DPP
[mesa.git] / src / amd / compiler / aco_instruction_selection.cpp
index 73be000351c116a689d2557567779064e0b0828b..3b82e46e33fe0c197317664ddc98e4ad44c27d70 100644 (file)
@@ -136,8 +136,11 @@ Temp emit_mbcnt(isel_context *ctx, Definition dst,
 
    if (ctx->program->wave_size == 32) {
       return thread_id_lo;
+   } else if (ctx->program->chip_class <= GFX7) {
+      Temp thread_id_hi = bld.vop2(aco_opcode::v_mbcnt_hi_u32_b32, dst, mask_hi, thread_id_lo);
+      return thread_id_hi;
    } else {
-      Temp thread_id_hi = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, dst, mask_hi, thread_id_lo);
+      Temp thread_id_hi = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32_e64, dst, mask_hi, thread_id_lo);
       return thread_id_hi;
    }
 }
@@ -204,6 +207,36 @@ static Temp emit_bpermute(isel_context *ctx, Builder &bld, Temp index, Temp data
    }
 }
 
+static Temp emit_masked_swizzle(isel_context *ctx, Builder &bld, Temp src, unsigned mask)
+{
+   if (ctx->options->chip_class >= GFX8) {
+      unsigned and_mask = mask & 0x1f;
+      unsigned or_mask = (mask >> 5) & 0x1f;
+      unsigned xor_mask = (mask >> 10) & 0x1f;
+
+      uint16_t dpp_ctrl = 0xffff;
+
+      // TODO: we could use DPP8 for some swizzles
+      if (and_mask == 0x1f && or_mask < 4 && xor_mask < 4) {
+         unsigned res[4] = {0, 1, 2, 3};
+         for (unsigned i = 0; i < 4; i++)
+            res[i] = ((res[i] | or_mask) ^ xor_mask) & 0x3;
+         dpp_ctrl = dpp_quad_perm(res[0], res[1], res[2], res[3]);
+      } else if (and_mask == 0x1f && !or_mask && xor_mask == 8) {
+         dpp_ctrl = dpp_row_rr(8);
+      } else if (and_mask == 0x1f && !or_mask && xor_mask == 0xf) {
+         dpp_ctrl = dpp_row_mirror;
+      } else if (and_mask == 0x1f && !or_mask && xor_mask == 0x7) {
+         dpp_ctrl = dpp_row_half_mirror;
+      }
+
+      if (dpp_ctrl != 0xffff)
+         return bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl);
+   }
+
+   return bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, mask, 0, false);
+}
+
 Temp as_vgpr(isel_context *ctx, Temp val)
 {
    if (val.type() == RegType::sgpr) {
@@ -402,7 +435,7 @@ void byte_align_scalar(isel_context *ctx, Temp vec, Operand offset, Temp dst)
       bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), vec);
       hi = bld.pseudo(aco_opcode::p_extract_vector, bld.def(s1), hi, Operand(0u));
       if (select != Temp())
-         hi = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), hi, Operand(0u), select);
+         hi = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), hi, Operand(0u), bld.scc(select));
       lo = bld.sop2(aco_opcode::s_lshr_b64, bld.def(s2), bld.def(s1, scc), lo, shift);
       Temp mid = bld.tmp(s1);
       lo = bld.pseudo(aco_opcode::p_split_vector, bld.def(s1), Definition(mid), lo);
@@ -979,7 +1012,8 @@ Temp emit_floor_f64(isel_context *ctx, Builder& bld, Definition dst, Temp val)
    if (ctx->options->chip_class >= GFX7)
       return bld.vop1(aco_opcode::v_floor_f64, Definition(dst), val);
 
-   /* GFX6 doesn't support V_FLOOR_F64, lower it. */
+   /* GFX6 doesn't support V_FLOOR_F64, lower it (note that it's actually
+    * lowered at NIR level for precision reasons). */
    Temp src0 = as_vgpr(ctx, val);
 
    Temp mask = bld.copy(bld.def(s1), Operand(3u)); /* isnan */
@@ -1907,6 +1941,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       } else if (dst.regClass() == v1) {
          emit_rsq(ctx, bld, Definition(dst), src);
       } else if (dst.regClass() == v2) {
+         /* Lowered at NIR level for precision reasons. */
          emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f64, dst);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -1998,6 +2033,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       } else if (dst.regClass() == v1) {
          emit_rcp(ctx, bld, Definition(dst), src);
       } else if (dst.regClass() == v2) {
+         /* Lowered at NIR level for precision reasons. */
          emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f64, dst);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -2025,6 +2061,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       } else if (dst.regClass() == v1) {
          emit_sqrt(ctx, bld, Definition(dst), src);
       } else if (dst.regClass() == v2) {
+         /* Lowered at NIR level for precision reasons. */
          emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f64, dst);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -3111,7 +3148,9 @@ void emit_load(isel_context *ctx, Builder& bld, const LoadEmitInfo *info)
       int byte_align = align_mul % 4 == 0 ? align_offset % 4 : -1;
 
       if (byte_align) {
-         if ((bytes_needed > 2 || !supports_8bit_16bit_loads) && byte_align_loads) {
+         if ((bytes_needed > 2 ||
+              (bytes_needed == 2 && (align_mul % 2 || align_offset % 2)) ||
+              !supports_8bit_16bit_loads) && byte_align_loads) {
             if (info->component_stride) {
                assert(supports_8bit_16bit_loads && "unimplemented");
                bytes_needed = 2;
@@ -7914,13 +7953,13 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       uint32_t mask = nir_intrinsic_swizzle_mask(instr);
       if (dst.regClass() == v1) {
          emit_wqm(ctx,
-                  bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, mask, 0, false),
+                  emit_masked_swizzle(ctx, bld, src, mask),
                   dst);
       } else if (dst.regClass() == v2) {
          Temp lo = bld.tmp(v1), hi = bld.tmp(v1);
          bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), src);
-         lo = emit_wqm(ctx, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), lo, mask, 0, false));
-         hi = emit_wqm(ctx, bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), hi, mask, 0, false));
+         lo = emit_wqm(ctx, emit_masked_swizzle(ctx, bld, lo, mask));
+         hi = emit_wqm(ctx, emit_masked_swizzle(ctx, bld, hi, mask));
          bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi);
          emit_split_vector(ctx, dst, 2);
       } else {
@@ -8944,13 +8983,19 @@ void visit_tex(isel_context *ctx, nir_tex_instr *instr)
 }
 
 
-Operand get_phi_operand(isel_context *ctx, nir_ssa_def *ssa, RegClass rc)
+Operand get_phi_operand(isel_context *ctx, nir_ssa_def *ssa, RegClass rc, bool logical)
 {
    Temp tmp = get_ssa_temp(ctx, ssa);
-   if (ssa->parent_instr->type == nir_instr_type_ssa_undef)
+   if (ssa->parent_instr->type == nir_instr_type_ssa_undef) {
       return Operand(rc);
-   else
+   } else if (logical && ssa->bit_size == 1 && ssa->parent_instr->type == nir_instr_type_load_const) {
+      if (ctx->program->wave_size == 64)
+         return Operand(nir_instr_as_load_const(ssa->parent_instr)->value[0].b ? UINT64_MAX : 0u);
+      else
+         return Operand(nir_instr_as_load_const(ssa->parent_instr)->value[0].b ? UINT32_MAX : 0u);
+   } else {
       return Operand(tmp);
+   }
 }
 
 void visit_phi(isel_context *ctx, nir_phi_instr *instr)
@@ -8993,7 +9038,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr)
       if (!(ctx->block->kind & block_kind_loop_header) && cur_pred_idx >= preds.size())
          continue;
       cur_pred_idx++;
-      Operand op = get_phi_operand(ctx, src.second, dst.regClass());
+      Operand op = get_phi_operand(ctx, src.second, dst.regClass(), logical);
       operands[num_operands++] = op;
       num_defined += !op.isUndefined();
    }