aco: create better code for boolean phis with constant operands
[mesa.git] / src / amd / compiler / aco_lower_phis.cpp
index b90d99ee42456e899754b73793eb14d7e80436ea..a7d2b6dce724e181394665cf6013ac5bb1e6a61e 100644 (file)
@@ -96,9 +96,56 @@ void insert_before_logical_end(Block *block, aco_ptr<Instruction> instr)
    if (it == block->instructions.crend()) {
       assert(block->instructions.back()->format == Format::PSEUDO_BRANCH);
       block->instructions.insert(std::prev(block->instructions.end()), std::move(instr));
-   }
-   else
+   } else {
       block->instructions.insert(std::prev(it.base()), std::move(instr));
+   }
+}
+
+void build_merge_code(Program *program, Block *block, Definition dst, Operand prev, Operand cur)
+{ // Emit into `block`, at its p_logical_end, the code that writes `dst` from `prev` and `cur`.
+   Builder bld(program);
+
+   auto IsLogicalEnd = [] (const aco_ptr<Instruction>& instr) -> bool {
+      return instr->opcode == aco_opcode::p_logical_end;
+   };
+   auto it = std::find_if(block->instructions.rbegin(), block->instructions.rend(), IsLogicalEnd); // reverse search: last p_logical_end
+   assert(it != block->instructions.rend()); // the block is required to contain a p_logical_end
+   bld.reset(&block->instructions, std::prev(it.base())); // std::prev(it.base()) is the forward iterator to the matched instruction
+
+   if (prev.isUndefined()) { // nothing to merge with: dst is just a copy of cur
+      bld.sop1(Builder::s_mov, dst, cur);
+      return;
+   }
+
+   bool prev_is_constant = prev.isConstant() && prev.constantValue64(true) + 1u < 2u; // only 0 or all-ones pass (all-ones + 1 wraps to 0); assumes an unsigned 64-bit result - TODO confirm
+   bool cur_is_constant = cur.isConstant() && cur.constantValue64(true) + 1u < 2u; // same 0 / all-ones test for cur
+
+   if (!prev_is_constant) {
+      if (!cur_is_constant) { // general case: dst = (prev & ~exec) | (cur & exec)
+         Temp tmp1 = bld.tmp(bld.lm), tmp2 = bld.tmp(bld.lm);
+         bld.sop2(Builder::s_andn2, Definition(tmp1), bld.def(s1, scc), prev, Operand(exec, bld.lm)); // tmp1 = prev & ~exec
+         bld.sop2(Builder::s_and, Definition(tmp2), bld.def(s1, scc), cur, Operand(exec, bld.lm)); // tmp2 = cur & exec
+         bld.sop2(Builder::s_or, dst, bld.def(s1, scc), tmp1, tmp2); // dst = tmp1 | tmp2
+      } else if (cur.constantValue64(true)) { // cur is all-ones (non-zero and passed the test above)
+         bld.sop2(Builder::s_or, dst, bld.def(s1, scc), prev, Operand(exec, bld.lm)); // dst = prev | exec
+      } else { // cur is zero
+         bld.sop2(Builder::s_andn2, dst, bld.def(s1, scc), prev, Operand(exec, bld.lm)); // dst = prev & ~exec
+      }
+   } else if (prev.constantValue64(true)) { // prev is all-ones
+      if (!cur_is_constant)
+         bld.sop2(Builder::s_orn2, dst, bld.def(s1, scc), cur, Operand(exec, bld.lm)); // dst = cur | ~exec
+      else if (cur.constantValue64(true))
+         bld.sop1(Builder::s_mov, dst, program->wave_size == 64 ? Operand(UINT64_MAX) : Operand(UINT32_MAX)); // both all-ones: constant width chosen by wave_size
+      else
+         bld.sop1(Builder::s_not, dst, bld.def(s1, scc), Operand(exec, bld.lm)); // cur is zero: dst = ~exec
+   } else { // prev is zero
+      if (!cur_is_constant)
+         bld.sop2(Builder::s_and, dst, bld.def(s1, scc), cur, Operand(exec, bld.lm)); // dst = cur & exec
+      else if (cur.constantValue64(true))
+         bld.sop1(Builder::s_mov, dst, Operand(exec, bld.lm)); // cur is all-ones: dst = exec
+      else
+         bld.sop1(Builder::s_mov, dst, program->wave_size == 64 ? Operand((uint64_t)0u) : Operand(0u)); // both zero: dst = 0, width chosen by wave_size
+   }
 }
 
 void lower_divergent_bool_phi(Program *program, ssa_state *state, Block *block, aco_ptr<Instruction>& phi)
@@ -144,20 +191,9 @@ void lower_divergent_bool_phi(Program *program, ssa_state *state, Block *block,
       Temp new_cur = {state->writes.at(pred->index), program->lane_mask};
       assert(new_cur.regClass() == bld.lm);
 
-      if (cur.isUndefined()) {
-         insert_before_logical_end(pred, bld.sop1(aco_opcode::s_mov_b64, Definition(new_cur), phi->operands[i]).get_ptr());
-      } else {
-         Temp tmp1 = bld.tmp(bld.lm), tmp2 = bld.tmp(bld.lm);
-         insert_before_logical_end(pred,
-            bld.sop2(Builder::s_andn2, Definition(tmp1), bld.def(s1, scc),
-                     cur, Operand(exec, bld.lm)).get_ptr());
-         insert_before_logical_end(pred,
-            bld.sop2(Builder::s_and, Definition(tmp2), bld.def(s1, scc),
-                     phi->operands[i].getTemp(), Operand(exec, bld.lm)).get_ptr());
-         insert_before_logical_end(pred,
-            bld.sop2(Builder::s_or, Definition(new_cur), bld.def(s1, scc),
-                     tmp1, tmp2).get_ptr());
-      }
+      if (i == 1 && (block->kind & block_kind_merge) && phi->operands[0].isConstant())
+         cur = phi->operands[0];
+      build_merge_code(program, pred, Definition(new_cur), cur, phi->operands[i]);
    }
 
    unsigned num_preds = block->linear_preds.size();