nv50/ir: optimize imul/imad to xmads
[mesa.git] / src / gallium / drivers / nouveau / codegen / nv50_ir_peephole.cpp
index 5dac2f64f0a1e3547fffa39a4a9a68d837d3d9ca..dc7bf24ba238e38b1eff3d63b032f163b4d63bb0 100644 (file)
@@ -191,9 +191,17 @@ void
 LoadPropagation::checkSwapSrc01(Instruction *insn)
 {
    const Target *targ = prog->getTarget();
-   if (!targ->getOpInfo(insn).commutative)
-      if (insn->op != OP_SET && insn->op != OP_SLCT && insn->op != OP_SUB)
+   if (!targ->getOpInfo(insn).commutative) {
+      if (insn->op != OP_SET && insn->op != OP_SLCT &&
+          insn->op != OP_SUB && insn->op != OP_XMAD)
          return;
+      // XMAD is only commutative if both the CBCC and MRG flags are not set.
+      if (insn->op == OP_XMAD &&
+          (insn->subOp & NV50_IR_SUBOP_XMAD_CMODE_MASK) == NV50_IR_SUBOP_XMAD_CBCC)
+         return;
+      if (insn->op == OP_XMAD && (insn->subOp & NV50_IR_SUBOP_XMAD_MRG))
+         return;
+   }
    if (insn->src(1).getFile() != FILE_GPR)
       return;
    // This is the special OP_SET used for alphatesting, we can't reverse its
@@ -236,6 +244,12 @@ LoadPropagation::checkSwapSrc01(Instruction *insn)
    if (insn->op == OP_SUB) {
       insn->src(0).mod = insn->src(0).mod ^ Modifier(NV50_IR_MOD_NEG);
       insn->src(1).mod = insn->src(1).mod ^ Modifier(NV50_IR_MOD_NEG);
+   } else
+   if (insn->op == OP_XMAD) {
+      // swap h1 flags
+      uint16_t h1 = (insn->subOp >> 1 & NV50_IR_SUBOP_XMAD_H1(0)) |
+                    (insn->subOp << 1 & NV50_IR_SUBOP_XMAD_H1(1));
+      insn->subOp = (insn->subOp & ~NV50_IR_SUBOP_XMAD_H1_MASK) | h1;
    }
 }
 
@@ -283,6 +297,8 @@ class IndirectPropagation : public Pass
 {
 private:
    virtual bool visit(BasicBlock *);
+
+   BuildUtil bld;
 };
 
 bool
@@ -294,6 +310,8 @@ IndirectPropagation::visit(BasicBlock *bb)
    for (Instruction *i = bb->getEntry(); i; i = next) {
       next = i->next;
 
+      bld.setPosition(i, false);
+
       for (int s = 0; i->srcExists(s); ++s) {
          Instruction *insn;
          ImmediateValue imm;
@@ -325,6 +343,14 @@ IndirectPropagation::visit(BasicBlock *bb)
             i->setIndirect(s, 0, NULL);
             i->setSrc(s, cloneShallow(func, i->getSrc(s)));
             i->src(s).get()->reg.data.offset += imm.reg.data.u32;
+         } else if (insn->op == OP_SHLADD) {
+            if (!insn->src(2).getImmediate(imm) ||
+                !targ->insnCanLoadOffset(i, s, imm.reg.data.s32))
+               continue;
+            i->setIndirect(s, 0, bld.mkOp2v(
+               OP_SHL, TYPE_U32, bld.getSSA(), insn->getSrc(0), insn->getSrc(1)));
+            i->setSrc(s, cloneShallow(func, i->getSrc(s)));
+            i->src(s).get()->reg.data.offset += imm.reg.data.u32;
          }
       }
    }
@@ -727,7 +753,9 @@ ConstantFolding::expr(Instruction *i,
       // Leave PFETCH alone... we just folded its 2 args into 1.
       break;
    default:
-      i->op = i->saturate ? OP_SAT : OP_MOV; /* SAT handled by unary() */
+      i->op = i->saturate ? OP_SAT : OP_MOV;
+      if (i->saturate)
+         unary(i, *i->getSrc(0)->asImm());
       break;
    }
    i->subOp = 0;
@@ -1052,7 +1080,9 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
          i->op = OP_ADD;
       } else
       if (s == 1 && !imm0.isNegative() && imm0.isPow2() &&
-          target->isOpSupported(OP_SHLADD, i->dType)) {
+          !isFloatType(i->dType) &&
+          target->isOpSupported(OP_SHLADD, i->dType) &&
+          !i->subOp) {
          i->op = OP_SHLADD;
          imm0.applyLog2();
          i->setSrc(1, new_ImmediateValue(prog, imm0.reg.data.u32));
@@ -1161,10 +1191,49 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       break;
 
    case OP_MOD:
-      if (i->sType == TYPE_U32 && imm0.isPow2()) {
+      if (s == 1 && imm0.isPow2()) {
          bld.setPosition(i, false);
-         i->op = OP_AND;
-         i->setSrc(1, bld.loadImm(NULL, imm0.reg.data.u32 - 1));
+         if (i->sType == TYPE_U32) {
+            i->op = OP_AND;
+            i->setSrc(1, bld.loadImm(NULL, imm0.reg.data.u32 - 1));
+         } else if (i->sType == TYPE_S32) {
+            // Do it on the absolute value of the input, and then restore the
+            // sign. The only odd case is MIN_INT, but that should work out
+            // as well, since MIN_INT mod any power of 2 is 0.
+            //
+            // Technically we don't have to do any of this since MOD is
+            // undefined with negative arguments in GLSL, but this seems like
+            // the nice thing to do.
+            Value *abs = bld.mkOp1v(OP_ABS, TYPE_S32, bld.getSSA(), i->getSrc(0));
+            Value *neg, *v1, *v2;
+            bld.mkCmp(OP_SET, CC_LT, TYPE_S32,
+                      (neg = bld.getSSA(1, prog->getTarget()->nativeFile(FILE_PREDICATE))),
+                      TYPE_S32, i->getSrc(0), bld.loadImm(NULL, 0));
+            Value *mod = bld.mkOp2v(OP_AND, TYPE_U32, bld.getSSA(), abs,
+                                    bld.loadImm(NULL, imm0.reg.data.u32 - 1));
+            bld.mkOp1(OP_NEG, TYPE_S32, (v1 = bld.getSSA()), mod)
+               ->setPredicate(CC_P, neg);
+            bld.mkOp1(OP_MOV, TYPE_S32, (v2 = bld.getSSA()), mod)
+               ->setPredicate(CC_NOT_P, neg);
+            newi = bld.mkOp2(OP_UNION, TYPE_S32, i->getDef(0), v1, v2);
+
+            delete_Instruction(prog, i);
+         }
+      } else if (s == 1) {
+         // In this case, we still want the optimized lowering that we get
+         // from having division by an immediate.
+         //
+         // a % b == a - (a/b) * b
+         bld.setPosition(i, false);
+         Value *div = bld.mkOp2v(OP_DIV, i->sType, bld.getSSA(),
+                                 i->getSrc(0), i->getSrc(1));
+         newi = bld.mkOp2(OP_ADD, i->sType, i->getDef(0), i->getSrc(0),
+                          bld.mkOp2v(OP_MUL, i->sType, bld.getSSA(), div, i->getSrc(1)));
+         // TODO: Check that target supports this. In this case, we know that
+         // all backends do.
+         newi->src(1).mod = Modifier(NV50_IR_MOD_NEG);
+
+         delete_Instruction(prog, i);
       }
       break;
 
@@ -1262,7 +1331,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
                  src->op == OP_SHR &&
                  src->src(1).getImmediate(imm1) &&
                  i->src(t).mod == Modifier(0) &&
-                 util_is_power_of_two(imm0.reg.data.u32 + 1)) {
+                 util_is_power_of_two_or_zero(imm0.reg.data.u32 + 1)) {
          // low byte = offset, high byte = width
          uint32_t ext = (util_last_bit(imm0.reg.data.u32) << 8) | imm1.reg.data.u32;
          i->op = OP_EXTBF;
@@ -1271,7 +1340,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       } else if (src->op == OP_SHL &&
                  src->src(1).getImmediate(imm1) &&
                  i->src(t).mod == Modifier(0) &&
-                 util_is_power_of_two(~imm0.reg.data.u32 + 1) &&
+                 util_is_power_of_two_or_zero(~imm0.reg.data.u32 + 1) &&
                  util_last_bit(~imm0.reg.data.u32) <= imm1.reg.data.u32) {
          i->op = OP_MOV;
          i->setSrc(s, NULL);
@@ -1509,6 +1578,17 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
    default:
       return;
    }
+
+   // This can get left behind some of the optimizations which simplify
+   // saturatable values.
+   if (newi->op == OP_MOV && newi->saturate) {
+      ImmediateValue tmp;
+      newi->saturate = 0;
+      newi->op = OP_SAT;
+      if (newi->src(0).getImmediate(tmp))
+         unary(newi, tmp);
+   }
+
    if (newi->op != op)
       foldCount++;
 }
@@ -1600,6 +1680,7 @@ ModifierFolding::visit(BasicBlock *bb)
 // SLCT(a, b, const) -> cc(const) ? a : b
 // RCP(RCP(a)) -> a
 // MUL(MUL(a, b), const) -> MUL_Xconst(a, b)
+// EXTBF(RDSV(COMBINED_TID)) -> RDSV(TID)
 class AlgebraicOpt : public Pass
 {
 private:
@@ -1617,6 +1698,7 @@ private:
    void handleCVT_EXTBF(Instruction *);
    void handleSUCLAMP(Instruction *);
    void handleNEG(Instruction *);
+   void handleEXTBF_RDSV(Instruction *);
 
    BuildUtil bld;
 };
@@ -1677,7 +1759,8 @@ AlgebraicOpt::handleADD(Instruction *add)
       return false;
 
    bool changed = false;
-   if (!changed && prog->getTarget()->isOpSupported(OP_MAD, add->dType))
+   // we can't optimize to MAD if the add is precise
+   if (!add->precise && prog->getTarget()->isOpSupported(OP_MAD, add->dType))
       changed = tryADDToMADOrSAD(add, OP_MAD);
    if (!changed && prog->getTarget()->isOpSupported(OP_SAD, add->dType))
       changed = tryADDToMADOrSAD(add, OP_SAD);
@@ -1713,7 +1796,7 @@ AlgebraicOpt::tryADDToMADOrSAD(Instruction *add, operation toOp)
       return false;
 
    if (src->getInsn()->saturate || src->getInsn()->postFactor ||
-       src->getInsn()->dnz)
+       src->getInsn()->dnz || src->getInsn()->precise)
       return false;
 
    if (toOp == OP_SAD) {
@@ -1779,15 +1862,24 @@ AlgebraicOpt::handleMINMAX(Instruction *minmax)
    }
 }
 
+// rcp(rcp(a)) = a
+// rcp(sqrt(a)) = rsq(a)
 void
 AlgebraicOpt::handleRCP(Instruction *rcp)
 {
    Instruction *si = rcp->getSrc(0)->getUniqueInsn();
 
-   if (si && si->op == OP_RCP) {
+   if (!si)
+      return;
+
+   if (si->op == OP_RCP) {
       Modifier mod = rcp->src(0).mod * si->src(0).mod;
       rcp->op = mod.getOp();
       rcp->setSrc(0, si->getSrc(0));
+   } else if (si->op == OP_SQRT) {
+      rcp->op = OP_RSQ;
+      rcp->setSrc(0, si->getSrc(0));
+      rcp->src(0).mod = rcp->src(0).mod * si->src(0).mod;
    }
 }
 
@@ -2120,6 +2212,41 @@ AlgebraicOpt::handleNEG(Instruction *i) {
    }
 }
 
+// EXTBF(RDSV(COMBINED_TID)) -> RDSV(TID)
+void
+AlgebraicOpt::handleEXTBF_RDSV(Instruction *i)
+{
+   Instruction *rdsv = i->getSrc(0)->getUniqueInsn();
+   if (rdsv->op != OP_RDSV ||
+       rdsv->getSrc(0)->asSym()->reg.data.sv.sv != SV_COMBINED_TID)
+      return;
+   // Avoid creating more RDSV instructions
+   if (rdsv->getDef(0)->refCount() > 1)
+      return;
+
+   ImmediateValue imm;
+   if (!i->src(1).getImmediate(imm))
+      return;
+
+   int index;
+   if (imm.isInteger(0x1000))
+      index = 0;
+   else
+   if (imm.isInteger(0x0a10))
+      index = 1;
+   else
+   if (imm.isInteger(0x061a))
+      index = 2;
+   else
+      return;
+
+   bld.setPosition(i, false);
+
+   i->op = OP_RDSV;
+   i->setSrc(0, bld.mkSysVal(SV_TID, index));
+   i->setSrc(1, NULL);
+}
+
 bool
 AlgebraicOpt::visit(BasicBlock *bb)
 {
@@ -2160,6 +2287,9 @@ AlgebraicOpt::visit(BasicBlock *bb)
       case OP_NEG:
          handleNEG(i);
          break;
+      case OP_EXTBF:
+         handleEXTBF_RDSV(i);
+         break;
       default:
          break;
       }
@@ -2171,13 +2301,18 @@ AlgebraicOpt::visit(BasicBlock *bb)
 // =============================================================================
 
 // ADD(SHL(a, b), c) -> SHLADD(a, b, c)
+// MUL(a, b) -> a few XMADs
+// MAD/FMA(a, b, c) -> a few XMADs
 class LateAlgebraicOpt : public Pass
 {
 private:
    virtual bool visit(Instruction *);
 
    void handleADD(Instruction *);
+   void handleMULMAD(Instruction *);
    bool tryADDToSHLADD(Instruction *);
+
+   BuildUtil bld;
 };
 
 void
@@ -2238,6 +2373,52 @@ LateAlgebraicOpt::tryADDToSHLADD(Instruction *add)
    return true;
 }
 
+// MUL(a, b) -> a few XMADs
+// MAD/FMA(a, b, c) -> a few XMADs
+void
+LateAlgebraicOpt::handleMULMAD(Instruction *i)
+{
+   // TODO: handle NV50_IR_SUBOP_MUL_HIGH
+   if (!prog->getTarget()->isOpSupported(OP_XMAD, TYPE_U32))
+      return;
+   if (isFloatType(i->dType) || typeSizeof(i->dType) != 4)
+      return;
+   if (i->subOp || i->usesFlags() || i->flagsDef >= 0)
+      return;
+
+   assert(!i->src(0).mod);
+   assert(!i->src(1).mod);
+   assert(i->op == OP_MUL ? 1 : !i->src(2).mod);
+
+   bld.setPosition(i, false);
+
+   Value *a = i->getSrc(0);
+   Value *b = i->getSrc(1);
+   Value *c = i->op == OP_MUL ? bld.mkImm(0) : i->getSrc(2);
+
+   Value *tmp0 = bld.getSSA();
+   Value *tmp1 = bld.getSSA();
+
+   Instruction *insn = bld.mkOp3(OP_XMAD, TYPE_U32, tmp0, b, a, c);
+   insn->setPredicate(i->cc, i->getPredicate());
+
+   insn = bld.mkOp3(OP_XMAD, TYPE_U32, tmp1, b, a, bld.mkImm(0));
+   insn->setPredicate(i->cc, i->getPredicate());
+   insn->subOp = NV50_IR_SUBOP_XMAD_MRG | NV50_IR_SUBOP_XMAD_H1(1);
+
+   Value *pred = i->getPredicate();
+   i->setPredicate(i->cc, NULL);
+
+   i->op = OP_XMAD;
+   i->setSrc(0, b);
+   i->setSrc(1, tmp1);
+   i->setSrc(2, tmp0);
+   i->subOp = NV50_IR_SUBOP_XMAD_PSL | NV50_IR_SUBOP_XMAD_CBCC;
+   i->subOp |= NV50_IR_SUBOP_XMAD_H1(0) | NV50_IR_SUBOP_XMAD_H1(1);
+
+   i->setPredicate(i->cc, pred);
+}
+
 bool
 LateAlgebraicOpt::visit(Instruction *i)
 {
@@ -2245,6 +2426,11 @@ LateAlgebraicOpt::visit(Instruction *i)
    case OP_ADD:
       handleADD(i);
       break;
+   case OP_MUL:
+   case OP_MAD:
+   case OP_FMA:
+      handleMULMAD(i);
+      break;
    default:
       break;
    }
@@ -2486,6 +2672,10 @@ MemoryOpt::combineLd(Record *rec, Instruction *ld)
 
    assert(sizeRc + sizeLd <= 16 && offRc != offLd);
 
+   // lock any stores that overlap with the load being merged into the
+   // existing record.
+   lockStores(ld);
+
    for (j = 0; sizeRc; sizeRc -= rec->insn->getDef(j)->reg.size, ++j);
 
    if (offLd < offRc) {
@@ -2542,6 +2732,10 @@ MemoryOpt::combineSt(Record *rec, Instruction *st)
    if (prog->getType() == Program::TYPE_COMPUTE && rec->rel[0])
       return false;
 
+   // remove any existing load/store records for the store being merged into
+   // the existing record.
+   purgeRecords(st, DATA_FILE_COUNT);
+
    st->takeExtraSources(0, extra); // save predicate and indirect address
 
    if (offRc < offSt) {
@@ -2641,7 +2835,7 @@ MemoryOpt::findRecord(const Instruction *insn, bool load, bool& isAdj) const
    Record *it = load ? loads[sym->reg.file] : stores[sym->reg.file];
 
    for (; it; it = it->next) {
-      if (it->locked && insn->op != OP_LOAD)
+      if (it->locked && insn->op != OP_LOAD && insn->op != OP_VFETCH)
          continue;
       if ((it->offset >> 4) != (sym->reg.data.offset >> 4) ||
           it->rel[0] != insn->getIndirect(0, 0) ||
@@ -2779,11 +2973,15 @@ MemoryOpt::Record::overlaps(const Instruction *ldst) const
    Record that;
    that.set(ldst);
 
-   if (this->fileIndex != that.fileIndex)
+   // This assumes that images/buffers can't overlap. They can.
+   // TODO: Plumb the restrict logic through, and only skip when it's a
+   // restrict situation, or there can implicitly be no writes.
+   if (this->fileIndex != that.fileIndex && this->rel[1] == that.rel[1])
       return false;
 
    if (this->rel[0] || that.rel[0])
       return this->base == that.base;
+
    return
       (this->offset < that.offset + that.size) &&
       (this->offset + this->size > that.offset);
@@ -3242,7 +3440,9 @@ PostRaLoadPropagation::handleMADforNV50(Instruction *i)
          i->setSrc(1, def->getSrc(0));
       } else {
          ImmediateValue val;
-         bool ret = def->src(0).getImmediate(val);
+         // getImmediate() has side-effects on the argument so this *shouldn't*
+         // be folded into the assert()
+         MAYBE_UNUSED bool ret = def->src(0).getImmediate(val);
          assert(ret);
          if (i->getSrc(1)->reg.data.id & 1)
             val.reg.data.u32 >>= 16;
@@ -3363,6 +3563,11 @@ Instruction::isActionEqual(const Instruction *that) const
    } else
    if (this->asFlow()) {
       return false;
+   } else
+   if (this->op == OP_PHI && this->bb != that->bb) {
+      /* TODO: we could probably be a bit smarter here by following the
+       * control flow, but honestly, it is quite painful to check */
+      return false;
    } else {
       if (this->ipa != that->ipa ||
           this->lanes != that->lanes ||
@@ -3459,6 +3664,7 @@ GlobalCSE::visit(BasicBlock *bb)
             break;
       }
       if (!phi->srcExists(s)) {
+         assert(ik->op != OP_PHI);
          Instruction *entry = bb->getEntry();
          ik->bb->remove(ik);
          if (!entry || entry->op != OP_JOIN)
@@ -3728,8 +3934,8 @@ Program::optimizeSSA(int level)
    RUN_PASS(2, AlgebraicOpt, run);
    RUN_PASS(2, ModifierFolding, run); // before load propagation -> less checks
    RUN_PASS(1, ConstantFolding, foldAll);
-   RUN_PASS(2, LateAlgebraicOpt, run);
    RUN_PASS(1, Split64BitOpPreRA, run);
+   RUN_PASS(2, LateAlgebraicOpt, run);
    RUN_PASS(1, LoadPropagation, run);
    RUN_PASS(1, IndirectPropagation, run);
    RUN_PASS(2, MemoryOpt, run);