gm107/ir: fix indirect txq emission
[mesa.git] / src / gallium / drivers / nouveau / codegen / nv50_ir_peephole.cpp
index e6f8f69ce91376d20b536bc4fabe922655ef3b9c..ad9bf6f4aa9aed2804c0df76151dbbb3ab1680e0 100644 (file)
@@ -118,6 +118,35 @@ CopyPropagation::visit(BasicBlock *bb)
 
 // =============================================================================
 
+class MergeSplits : public Pass
+{
+private:
+   virtual bool visit(BasicBlock *);
+};
+
+// For SPLIT / MERGE pairs that operate on the same registers, replace the
+// post-merge def with the SPLIT's source.
+bool
+MergeSplits::visit(BasicBlock *bb)
+{
+   Instruction *i, *next, *si;
+
+   for (i = bb->getEntry(); i; i = next) {
+      next = i->next;
+      if (i->op != OP_MERGE || typeSizeof(i->dType) != 8)
+         continue;
+      si = i->getSrc(0)->getInsn();
+      if (si->op != OP_SPLIT || si != i->getSrc(1)->getInsn())
+         continue;
+      i->def(0).replace(si->getSrc(0), false);
+      delete_Instruction(prog, i);
+   }
+
+   return true;
+}
+
+// =============================================================================
+
 class LoadPropagation : public Pass
 {
 private:
@@ -207,6 +236,9 @@ LoadPropagation::visit(BasicBlock *bb)
       if (i->op == OP_CALL) // calls have args as sources, they must be in regs
          continue;
 
+      if (i->op == OP_PFETCH) // pfetch expects arg1 to be a reg
+         continue;
+
       if (i->srcExists(1))
          checkSwapSrc01(i);
 
@@ -249,7 +281,6 @@ private:
 
    void tryCollapseChainedMULs(Instruction *, const int s, ImmediateValue&);
 
-   // TGSI 'true' is converted to -1 by F2I(NEG(SET)), track back to SET
    CmpInstruction *findOriginForTestWithZero(Value *);
 
    unsigned int foldCount;
@@ -308,25 +339,33 @@ ConstantFolding::findOriginForTestWithZero(Value *value)
       return NULL;
    Instruction *insn = value->getInsn();
 
-   while (insn && insn->op != OP_SET) {
-      Instruction *next = NULL;
-      switch (insn->op) {
-      case OP_NEG:
-      case OP_ABS:
-      case OP_CVT:
-         next = insn->getSrc(0)->getInsn();
-         if (insn->sType != next->dType)
+   if (insn->asCmp() && insn->op != OP_SLCT)
+      return insn->asCmp();
+
+   /* Sometimes mov's will sneak in as a result of other folding. This gets
+    * cleaned up later.
+    */
+   if (insn->op == OP_MOV)
+      return findOriginForTestWithZero(insn->getSrc(0));
+
+   /* Deal with AND 1.0 here since nv50 can't fold into boolean float */
+   if (insn->op == OP_AND) {
+      int s = 0;
+      ImmediateValue imm;
+      if (!insn->src(s).getImmediate(imm)) {
+         s = 1;
+         if (!insn->src(s).getImmediate(imm))
             return NULL;
-         break;
-      case OP_MOV:
-         next = insn->getSrc(0)->getInsn();
-         break;
-      default:
-         return NULL;
       }
-      insn = next;
+      if (imm.reg.data.f32 != 1.0f)
+         return NULL;
+      /* TODO: Come up with a way to handle the condition being inverted */
+      if (insn->src(!s).mod != Modifier(0))
+         return NULL;
+      return findOriginForTestWithZero(insn->getSrc(!s));
    }
-   return insn ? insn->asCmp() : NULL;
+
+   return NULL;
 }
 
 void
@@ -422,7 +461,9 @@ ConstantFolding::expr(Instruction *i,
             b->data.f32 = 0.0f;
       }
       switch (i->dType) {
-      case TYPE_F32: res.data.f32 = a->data.f32 * b->data.f32; break;
+      case TYPE_F32:
+         res.data.f32 = a->data.f32 * b->data.f32 * exp2f(i->postFactor);
+         break;
       case TYPE_F64: res.data.f64 = a->data.f64 * b->data.f64; break;
       case TYPE_S32:
          if (i->subOp == NV50_IR_SUBOP_MUL_HIGH) {
@@ -528,14 +569,26 @@ ConstantFolding::expr(Instruction *i,
          rshift = 32 - width;
          lshift = 32 - width - offset;
       }
+      if (i->subOp == NV50_IR_SUBOP_EXTBF_REV)
+         res.data.u32 = util_bitreverse(a->data.u32);
+      else
+         res.data.u32 = a->data.u32;
       switch (i->dType) {
-      case TYPE_S32: res.data.s32 = (a->data.s32 << lshift) >> rshift; break;
-      case TYPE_U32: res.data.u32 = (a->data.u32 << lshift) >> rshift; break;
+      case TYPE_S32: res.data.s32 = (res.data.s32 << lshift) >> rshift; break;
+      case TYPE_U32: res.data.u32 = (res.data.u32 << lshift) >> rshift; break;
       default:
          return;
       }
       break;
    }
+   case OP_POPCNT:
+      res.data.u32 = util_bitcount(a->data.u32 & b->data.u32);
+      break;
+   case OP_PFETCH:
+      // The two arguments to pfetch are logically added together. Normally
+      // the second argument will not be constant, but that can happen.
+      res.data.u32 = a->data.u32 + b->data.u32;
+      break;
    default:
       return;
    }
@@ -543,26 +596,43 @@ ConstantFolding::expr(Instruction *i,
 
    i->src(0).mod = Modifier(0);
    i->src(1).mod = Modifier(0);
+   i->postFactor = 0;
 
    i->setSrc(0, new_ImmediateValue(i->bb->getProgram(), res.data.u32));
    i->setSrc(1, NULL);
 
    i->getSrc(0)->reg.data = res.data;
 
-   if (i->op == OP_MAD || i->op == OP_FMA) {
+   switch (i->op) {
+   case OP_MAD:
+   case OP_FMA: {
       i->op = OP_ADD;
 
+      /* Move the immediate to the second arg, otherwise the ADD operation
+       * won't be emittable
+       */
       i->setSrc(1, i->getSrc(0));
-      i->src(1).mod = i->src(2).mod;
       i->setSrc(0, i->getSrc(2));
+      i->src(0).mod = i->src(2).mod;
       i->setSrc(2, NULL);
 
       ImmediateValue src0;
       if (i->src(0).getImmediate(src0))
          expr(i, src0, *i->getSrc(1)->asImm());
-   } else {
-      i->op = OP_MOV;
+      if (i->saturate && !prog->getTarget()->isSatSupported(i)) {
+         bld.setPosition(i, false);
+         i->setSrc(1, bld.loadImm(NULL, res.data.u32));
+      }
+      break;
+   }
+   case OP_PFETCH:
+      // Leave PFETCH alone... we just folded its 2 args into 1.
+      break;
+   default:
+      i->op = i->saturate ? OP_SAT : OP_MOV; /* SAT handled by unary() */
+      break;
    }
+   i->subOp = 0;
 }
 
 void
@@ -612,6 +682,7 @@ ConstantFolding::unary(Instruction *i, const ImmediateValue &imm)
    switch (i->op) {
    case OP_NEG: res.data.f32 = -imm.reg.data.f32; break;
    case OP_ABS: res.data.f32 = fabsf(imm.reg.data.f32); break;
+   case OP_SAT: res.data.f32 = CLAMP(imm.reg.data.f32, 0.0f, 1.0f); break;
    case OP_RCP: res.data.f32 = 1.0f / imm.reg.data.f32; break;
    case OP_RSQ: res.data.f32 = 1.0f / sqrtf(imm.reg.data.f32); break;
    case OP_LG2: res.data.f32 = log2f(imm.reg.data.f32); break;
@@ -640,7 +711,7 @@ ConstantFolding::tryCollapseChainedMULs(Instruction *mul2,
    Instruction *insn;
    Instruction *mul1 = NULL; // mul1 before mul2
    int e = 0;
-   float f = imm2.reg.data.f32;
+   float f = imm2.reg.data.f32 * exp2f(mul2->postFactor);
    ImmediateValue imm1;
 
    assert(mul2->op == OP_MUL && mul2->dType == TYPE_F32);
@@ -660,6 +731,7 @@ ConstantFolding::tryCollapseChainedMULs(Instruction *mul2,
             mul1->setSrc(s1, bld.loadImm(NULL, f * imm1.reg.data.f32));
             mul1->src(s1).mod = Modifier(0);
             mul2->def(0).replace(mul1->getDef(0), false);
+            mul1->saturate = mul2->saturate;
          } else
          if (prog->getTarget()->isPostMultiplySupported(OP_MUL, f, e)) {
             // c = mul a, b
@@ -668,8 +740,8 @@ ConstantFolding::tryCollapseChainedMULs(Instruction *mul2,
             mul2->def(0).replace(mul1->getDef(0), false);
             if (f < 0)
                mul1->src(0).mod *= Modifier(NV50_IR_MOD_NEG);
+            mul1->saturate = mul2->saturate;
          }
-         mul1->saturate = mul2->saturate;
          return;
       }
    }
@@ -677,7 +749,7 @@ ConstantFolding::tryCollapseChainedMULs(Instruction *mul2,
       // b = mul a, imm
       // d = mul b, c   -> d = mul_x_imm a, c
       int s2, t2;
-      insn = mul2->getDef(0)->uses.front()->getInsn();
+      insn = (*mul2->getDef(0)->uses.begin())->getInsn();
       if (!insn)
          return;
       mul1 = mul2;
@@ -740,9 +812,10 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
          i->op = OP_MOV;
          i->setSrc(0, new_ImmediateValue(prog, 0u));
          i->src(0).mod = Modifier(0);
+         i->postFactor = 0;
          i->setSrc(1, NULL);
       } else
-      if (imm0.isInteger(1) || imm0.isInteger(-1)) {
+      if (!i->postFactor && (imm0.isInteger(1) || imm0.isInteger(-1))) {
          if (imm0.isNegative())
             i->src(t).mod = i->src(t).mod ^ Modifier(NV50_IR_MOD_NEG);
          i->op = i->src(t).mod.getOp();
@@ -755,7 +828,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
             i->src(0).mod = 0;
          i->setSrc(1, NULL);
       } else
-      if (imm0.isInteger(2) || imm0.isInteger(-2)) {
+      if (!i->postFactor && (imm0.isInteger(2) || imm0.isInteger(-2))) {
          if (imm0.isNegative())
             i->src(t).mod = i->src(t).mod ^ Modifier(NV50_IR_MOD_NEG);
          i->op = OP_ADD;
@@ -771,6 +844,29 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
          i->src(1).mod = 0;
       }
       break;
+   case OP_MAD:
+      if (imm0.isInteger(0)) {
+         i->setSrc(0, i->getSrc(2));
+         i->src(0).mod = i->src(2).mod;
+         i->setSrc(1, NULL);
+         i->setSrc(2, NULL);
+         i->op = i->src(0).mod.getOp();
+         if (i->op != OP_CVT)
+            i->src(0).mod = 0;
+      } else
+      if (imm0.isInteger(1) || imm0.isInteger(-1)) {
+         if (imm0.isNegative())
+            i->src(t).mod = i->src(t).mod ^ Modifier(NV50_IR_MOD_NEG);
+         if (s == 0) {
+            i->setSrc(0, i->getSrc(1));
+            i->src(0).mod = i->src(1).mod;
+         }
+         i->setSrc(1, i->getSrc(2));
+         i->src(1).mod = i->src(2).mod;
+         i->setSrc(2, NULL);
+         i->op = OP_ADD;
+      }
+      break;
    case OP_ADD:
       if (i->usesFlags())
          break;
@@ -876,33 +972,82 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
 
    case OP_SET: // TODO: SET_AND,OR,XOR
    {
+      /* This optimizes the case where the output of a set is being compared
+       * to zero. Since the set can only produce 0/-1 (int) or 0/1 (float), we
+       * can be a lot cleverer in our comparison.
+       */
       CmpInstruction *si = findOriginForTestWithZero(i->getSrc(t));
       CondCode cc, ccZ;
-      if (i->src(t).mod != Modifier(0))
-         return;
-      if (imm0.reg.data.u32 != 0 || !si || si->op != OP_SET)
+      if (imm0.reg.data.u32 != 0 || !si)
          return;
       cc = si->setCond;
       ccZ = (CondCode)((unsigned int)i->asCmp()->setCond & ~CC_U);
+      // We do everything assuming var (cmp) 0, reverse the condition if 0 is
+      // first.
       if (s == 0)
          ccZ = reverseCondCode(ccZ);
+      // If there is a negative modifier, we need to undo that, by flipping
+      // the comparison to zero.
+      if (i->src(t).mod.neg())
+         ccZ = reverseCondCode(ccZ);
+      // If this is a signed comparison, we expect the input to be a regular
+      // boolean, i.e. 0/-1. However the rest of the logic assumes that true
+      // is positive, so just flip the sign.
+      if (i->sType == TYPE_S32) {
+         assert(!isFloatType(si->dType));
+         ccZ = reverseCondCode(ccZ);
+      }
       switch (ccZ) {
-      case CC_LT: cc = CC_FL; break;
-      case CC_GE: cc = CC_TR; break;
-      case CC_EQ: cc = inverseCondCode(cc); break;
-      case CC_LE: cc = inverseCondCode(cc); break;
-      case CC_GT: break;
-      case CC_NE: break;
+      case CC_LT: cc = CC_FL; break; // bool < 0 -- this is never true
+      case CC_GE: cc = CC_TR; break; // bool >= 0 -- this is always true
+      case CC_EQ: cc = inverseCondCode(cc); break; // bool == 0 -- !bool
+      case CC_LE: cc = inverseCondCode(cc); break; // bool <= 0 -- !bool
+      case CC_GT: break; // bool > 0 -- bool
+      case CC_NE: break; // bool != 0 -- bool
       default:
          return;
       }
+
+      // Update the condition of this SET to be identical to the origin set,
+      // but with the updated condition code. The original SET should get
+      // DCE'd, ideally.
+      i->op = si->op;
       i->asCmp()->setCond = cc;
       i->setSrc(0, si->src(0));
       i->setSrc(1, si->src(1));
+      if (si->srcExists(2))
+         i->setSrc(2, si->src(2));
       i->sType = si->sType;
    }
       break;
 
+   case OP_AND:
+   {
+      CmpInstruction *cmp = i->getSrc(t)->getInsn()->asCmp();
+      if (!cmp || cmp->op == OP_SLCT || cmp->getDef(0)->refCount() > 1)
+         return;
+      if (!prog->getTarget()->isOpSupported(cmp->op, TYPE_F32))
+         return;
+      if (imm0.reg.data.f32 != 1.0)
+         return;
+      if (i->getSrc(t)->getInsn()->dType != TYPE_U32)
+         return;
+
+      i->getSrc(t)->getInsn()->dType = TYPE_F32;
+      if (i->src(t).mod != Modifier(0)) {
+         assert(i->src(t).mod == Modifier(NV50_IR_MOD_NOT));
+         i->src(t).mod = Modifier(0);
+         cmp->setCond = inverseCondCode(cmp->setCond);
+      }
+      i->op = OP_MOV;
+      i->setSrc(s, NULL);
+      if (t) {
+         i->setSrc(0, i->getSrc(t));
+         i->setSrc(t, NULL);
+      }
+   }
+      break;
+
    case OP_SHL:
    {
       if (s != 1 || i->src(0).mod != Modifier(0))
@@ -922,6 +1067,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
 
    case OP_ABS:
    case OP_NEG:
+   case OP_SAT:
    case OP_LG2:
    case OP_RCP:
    case OP_SQRT:
@@ -933,6 +1079,33 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
    case OP_EX2:
       unary(i, imm0);
       break;
+   case OP_BFIND: {
+      int32_t res;
+      switch (i->dType) {
+      case TYPE_S32: res = util_last_bit_signed(imm0.reg.data.s32) - 1; break;
+      case TYPE_U32: res = util_last_bit(imm0.reg.data.u32) - 1; break;
+      default:
+         return;
+      }
+      if (i->subOp == NV50_IR_SUBOP_BFIND_SAMT && res >= 0)
+         res = 31 - res;
+      bld.setPosition(i, false); /* make sure bld is init'ed */
+      i->setSrc(0, bld.mkImm(res));
+      i->setSrc(1, NULL);
+      i->op = OP_MOV;
+      i->subOp = 0;
+      break;
+   }
+   case OP_POPCNT: {
+      // Only deal with 1-arg POPCNT here
+      if (i->srcExists(1))
+         break;
+      uint32_t res = util_bitcount(imm0.reg.data.u32);
+      i->setSrc(0, new_ImmediateValue(i->bb->getProgram(), res));
+      i->setSrc(1, NULL);
+      i->op = OP_MOV;
+      break;
+   }
    default:
       return;
    }
@@ -2118,7 +2291,7 @@ FlatteningPass::visit(BasicBlock *bb)
              insn->op != OP_LINTERP && // probably just nve4
              insn->op != OP_PINTERP && // probably just nve4
              ((insn->op != OP_LOAD && insn->op != OP_STORE) ||
-              typeSizeof(insn->dType) <= 4) &&
+              (typeSizeof(insn->dType) <= 4 && !insn->src(0).isIndirect(0))) &&
              !insn->isNop()) {
             insn->join = 1;
             bb->remove(bb->getExit());
@@ -2195,6 +2368,57 @@ FlatteningPass::tryPredicateConditional(BasicBlock *bb)
 
 // =============================================================================
 
+// Fold Immediate into MAD; must be done after register allocation due to
+// constraint SDST == SSRC2
+// TODO:
+// Does NVC0+ have other situations where this pass makes sense?
+class NV50PostRaConstantFolding : public Pass
+{
+private:
+   virtual bool visit(BasicBlock *);
+};
+
+bool
+NV50PostRaConstantFolding::visit(BasicBlock *bb)
+{
+   Value *vtmp;
+   Instruction *def;
+
+   for (Instruction *i = bb->getFirst(); i; i = i->next) {
+      switch (i->op) {
+      case OP_MAD:
+         if (i->def(0).getFile() != FILE_GPR ||
+             i->src(0).getFile() != FILE_GPR ||
+             i->src(1).getFile() != FILE_GPR ||
+             i->src(2).getFile() != FILE_GPR ||
+             i->getDef(0)->reg.data.id != i->getSrc(2)->reg.data.id ||
+             !isFloatType(i->dType))
+            break;
+
+         def = i->getSrc(1)->getInsn();
+         if (def->op == OP_MOV && def->src(0).getFile() == FILE_IMMEDIATE) {
+            vtmp = i->getSrc(1);
+            i->setSrc(1, def->getSrc(0));
+
+            /* There's no post-RA dead code elimination, so do it here
+             * XXX: if we add more code-removing post-RA passes, we might
+             *      want to create a post-RA dead-code elim pass */
+            if (vtmp->refCount() == 0)
+               delete_Instruction(bb->getProgram(), def);
+
+            break;
+         }
+         break;
+      default:
+         break;
+      }
+   }
+
+   return true;
+}
+
+// =============================================================================
+
 // Common subexpression elimination. Stupid O^2 implementation.
 class LocalCSE : public Pass
 {
@@ -2548,6 +2772,7 @@ Program::optimizeSSA(int level)
 {
    RUN_PASS(1, DeadCodeElim, buryAll);
    RUN_PASS(1, CopyPropagation, run);
+   RUN_PASS(1, MergeSplits, run);
    RUN_PASS(2, GlobalCSE, run);
    RUN_PASS(1, LocalCSE, run);
    RUN_PASS(2, AlgebraicOpt, run);
@@ -2565,6 +2790,9 @@ bool
 Program::optimizePostRA(int level)
 {
    RUN_PASS(2, FlatteningPass, run);
+   if (getTarget()->getChipset() < 0xc0)
+      RUN_PASS(2, NV50PostRaConstantFolding, run);
+
    return true;
 }