nv50/ir: optimize signed integer modulo by pow-of-2
[mesa.git] / src / gallium / drivers / nouveau / codegen / nv50_ir_peephole.cpp
index 7e4e193e3d27d08141c4ba7d533d9996c72de277..2448c737e73c1072441c98b306dc19291a525212 100644 (file)
@@ -1054,6 +1054,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
          i->op = OP_ADD;
       } else
       if (s == 1 && !imm0.isNegative() && imm0.isPow2() &&
+          !isFloatType(i->dType) &&
           target->isOpSupported(OP_SHLADD, i->dType)) {
          i->op = OP_SHLADD;
          imm0.applyLog2();
@@ -1163,10 +1164,34 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       break;
 
    case OP_MOD:
-      if (i->sType == TYPE_U32 && imm0.isPow2()) {
+      if (s == 1 && imm0.isPow2()) {
          bld.setPosition(i, false);
-         i->op = OP_AND;
-         i->setSrc(1, bld.loadImm(NULL, imm0.reg.data.u32 - 1));
+         if (i->sType == TYPE_U32) {
+            i->op = OP_AND;
+            i->setSrc(1, bld.loadImm(NULL, imm0.reg.data.u32 - 1));
+         } else if (i->sType == TYPE_S32) {
+            // Do it on the absolute value of the input, and then restore the
+            // sign. The only odd case is MIN_INT, but that should work out
+            // as well, since MIN_INT mod any power of 2 is 0.
+            //
+            // Technically we don't have to do any of this since MOD is
+            // undefined with negative arguments in GLSL, but this seems like
+            // the nice thing to do.
+            Value *abs = bld.mkOp1v(OP_ABS, TYPE_S32, bld.getSSA(), i->getSrc(0));
+            Value *neg, *v1, *v2;
+            bld.mkCmp(OP_SET, CC_LT, TYPE_S32,
+                      (neg = bld.getSSA(1, prog->getTarget()->nativeFile(FILE_PREDICATE))),
+                      TYPE_S32, i->getSrc(0), bld.loadImm(NULL, 0));
+            Value *mod = bld.mkOp2v(OP_AND, TYPE_U32, bld.getSSA(), abs,
+                                    bld.loadImm(NULL, imm0.reg.data.u32 - 1));
+            bld.mkOp1(OP_NEG, TYPE_S32, (v1 = bld.getSSA()), mod)
+               ->setPredicate(CC_P, neg);
+            bld.mkOp1(OP_MOV, TYPE_S32, (v2 = bld.getSSA()), mod)
+               ->setPredicate(CC_NOT_P, neg);
+            newi = bld.mkOp2(OP_UNION, TYPE_S32, i->getDef(0), v1, v2);
+
+            delete_Instruction(prog, i);
+         }
       }
       break;