nv50/ir: fix use-after-free in ConstantFolding::visit
authorKarol Herbst <kherbst@redhat.com>
Fri, 7 Dec 2018 08:44:55 +0000 (09:44 +0100)
committerKarol Herbst <kherbst@redhat.com>
Sun, 9 Dec 2018 17:19:59 +0000 (18:19 +0100)
opnd() might delete the passed in instruction, but it's used through
i->srcExists() later in visit

v2: use continue instead return
v3: use brackets for the outer if/else chain

Signed-off-by: Karol Herbst <kherbst@redhat.com>
Reviewed-by: Ilia Mirkin <imirkin@alum.mit.edu>
src/gallium/drivers/nouveau/codegen/nv50_ir_peephole.cpp

index 202faf0746a784ea95d2d9f03af2e783506ed921..5d3b4aba9cce0e86c09912f59bfdfa77f69908c7 100644 (file)
@@ -370,7 +370,8 @@ private:
 
    void expr(Instruction *, ImmediateValue&, ImmediateValue&);
    void expr(Instruction *, ImmediateValue&, ImmediateValue&, ImmediateValue&);
-   void opnd(Instruction *, ImmediateValue&, int s);
+   /* true if i was deleted */
+   bool opnd(Instruction *i, ImmediateValue&, int s);
    void opnd3(Instruction *, ImmediateValue&);
 
    void unary(Instruction *, const ImmediateValue&);
@@ -414,18 +415,21 @@ ConstantFolding::visit(BasicBlock *bb)
       if (i->srcExists(2) &&
           i->src(0).getImmediate(src0) &&
           i->src(1).getImmediate(src1) &&
-          i->src(2).getImmediate(src2))
+          i->src(2).getImmediate(src2)) {
          expr(i, src0, src1, src2);
-      else
+      else
       if (i->srcExists(1) &&
-          i->src(0).getImmediate(src0) && i->src(1).getImmediate(src1))
+          i->src(0).getImmediate(src0) && i->src(1).getImmediate(src1)) {
          expr(i, src0, src1);
-      else
-      if (i->srcExists(0) && i->src(0).getImmediate(src0))
-         opnd(i, src0, 0);
-      else
-      if (i->srcExists(1) && i->src(1).getImmediate(src1))
-         opnd(i, src1, 1);
+      } else
+      if (i->srcExists(0) && i->src(0).getImmediate(src0)) {
+         if (opnd(i, src0, 0))
+            continue;
+      } else
+      if (i->srcExists(1) && i->src(1).getImmediate(src1)) {
+         if (opnd(i, src1, 1))
+            continue;
+      }
       if (i->srcExists(2) && i->src(2).getImmediate(src2))
          opnd3(i, src2);
    }
@@ -1011,12 +1015,13 @@ ConstantFolding::createMul(DataType ty, Value *def, Value *a, int64_t b, Value *
    return false;
 }
 
-void
+bool
 ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
 {
    const int t = !s;
    const operation op = i->op;
    Instruction *newi = i;
+   bool deleted = false;
 
    switch (i->op) {
    case OP_SPLIT: {
@@ -1036,6 +1041,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
          val >>= bitsize;
       }
       delete_Instruction(prog, i);
+      deleted = true;
       break;
    }
    case OP_MUL:
@@ -1050,6 +1056,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
             newi = bld.mkCmp(OP_SET, CC_LT, TYPE_S32, i->getDef(0),
                              TYPE_S32, i->getSrc(t), bld.mkImm(0));
             delete_Instruction(prog, i);
+            deleted = true;
          } else if (imm0.isInteger(0) || imm0.isInteger(1)) {
             // The high bits can't be set in this case (either mul by 0 or
             // unsigned by 1)
@@ -1101,8 +1108,10 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       if (!isFloatType(i->dType) && !i->src(t).mod) {
          bld.setPosition(i, false);
          int64_t b = typeSizeof(i->dType) == 8 ? imm0.reg.data.s64 : imm0.reg.data.s32;
-         if (createMul(i->dType, i->getDef(0), i->getSrc(t), b, NULL))
+         if (createMul(i->dType, i->getDef(0), i->getSrc(t), b, NULL)) {
             delete_Instruction(prog, i);
+            deleted = true;
+         }
       } else
       if (i->postFactor && i->sType == TYPE_F32) {
          /* Can't emit a postfactor with an immediate, have to fold it in */
@@ -1139,8 +1148,10 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       if (!isFloatType(i->dType) && !i->subOp && !i->src(t).mod && !i->src(2).mod) {
          bld.setPosition(i, false);
          int64_t b = typeSizeof(i->dType) == 8 ? imm0.reg.data.s64 : imm0.reg.data.s32;
-         if (createMul(i->dType, i->getDef(0), i->getSrc(t), b, i->getSrc(2)))
+         if (createMul(i->dType, i->getDef(0), i->getSrc(t), b, i->getSrc(2))) {
             delete_Instruction(prog, i);
+            deleted = true;
+         }
       }
       break;
    case OP_SUB:
@@ -1210,6 +1221,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
             bld.mkOp2(OP_SHR, TYPE_U32, i->getDef(0), tB, bld.mkImm(s));
 
          delete_Instruction(prog, i);
+         deleted = true;
       } else
       if (imm0.reg.data.s32 == -1) {
          i->op = OP_NEG;
@@ -1242,6 +1254,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
             bld.mkOp1(OP_NEG, TYPE_S32, i->getDef(0), tB);
 
          delete_Instruction(prog, i);
+         deleted = true;
       }
       break;
 
@@ -1273,6 +1286,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
             newi = bld.mkOp2(OP_UNION, TYPE_S32, i->getDef(0), v1, v2);
 
             delete_Instruction(prog, i);
+            deleted = true;
          }
       } else if (s == 1) {
          // In this case, we still want the optimized lowering that we get
@@ -1289,6 +1303,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
          newi->src(1).mod = Modifier(NV50_IR_MOD_NEG);
 
          delete_Instruction(prog, i);
+         deleted = true;
       }
       break;
 
@@ -1301,7 +1316,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       CmpInstruction *si = findOriginForTestWithZero(i->getSrc(t));
       CondCode cc, ccZ;
       if (imm0.reg.data.u32 != 0 || !si)
-         return;
+         return false;
       cc = si->setCond;
       ccZ = (CondCode)((unsigned int)i->asCmp()->setCond & ~CC_U);
       // We do everything assuming var (cmp) 0, reverse the condition if 0 is
@@ -1327,7 +1342,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       case CC_GT: break; // bool > 0 -- bool
       case CC_NE: break; // bool != 0 -- bool
       default:
-         return;
+         return false;
       }
 
       // Update the condition of this SET to be identical to the origin set,
@@ -1362,13 +1377,13 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       } else if (src->asCmp()) {
          CmpInstruction *cmp = src->asCmp();
          if (!cmp || cmp->op == OP_SLCT || cmp->getDef(0)->refCount() > 1)
-            return;
+            return false;
          if (!prog->getTarget()->isOpSupported(cmp->op, TYPE_F32))
-            return;
+            return false;
          if (imm0.reg.data.f32 != 1.0)
-            return;
+            return false;
          if (cmp->dType != TYPE_U32)
-            return;
+            return false;
 
          cmp->dType = TYPE_F32;
          if (i->src(t).mod != Modifier(0)) {
@@ -1435,13 +1450,13 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       case OP_MUL:
          int muls;
          if (isFloatType(si->dType))
-            return;
+            return false;
          if (si->src(1).getImmediate(imm1))
             muls = 1;
          else if (si->src(0).getImmediate(imm1))
             muls = 0;
          else
-            return;
+            return false;
 
          bld.setPosition(i, false);
          i->op = OP_MUL;
@@ -1452,15 +1467,15 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       case OP_ADD:
          int adds;
          if (isFloatType(si->dType))
-            return;
+            return false;
          if (si->op != OP_SUB && si->src(0).getImmediate(imm1))
             adds = 0;
          else if (si->src(1).getImmediate(imm1))
             adds = 1;
          else
-            return;
+            return false;
          if (si->src(!adds).mod != Modifier(0))
-            return;
+            return false;
          // SHL(ADD(x, y), z) = ADD(SHL(x, z), SHL(y, z))
 
          // This is more operations, but if one of x, y is an immediate, then
@@ -1475,7 +1490,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
                                      bld.mkImm(imm0.reg.data.u32)));
          break;
       default:
-         return;
+         return false;
       }
    }
       break;
@@ -1500,7 +1515,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       case TYPE_S32: res = util_last_bit_signed(imm0.reg.data.s32) - 1; break;
       case TYPE_U32: res = util_last_bit(imm0.reg.data.u32) - 1; break;
       default:
-         return;
+         return false;
       }
       if (i->subOp == NV50_IR_SUBOP_BFIND_SAMT && res >= 0)
          res = 31 - res;
@@ -1526,11 +1541,11 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
 
       // TODO: handle 64-bit values properly
       if (typeSizeof(i->dType) == 8 || typeSizeof(i->sType) == 8)
-         return;
+         return false;
 
       // TODO: handle single byte/word extractions
       if (i->subOp)
-         return;
+         return false;
 
       bld.setPosition(i, true); /* make sure bld is init'ed */
 
@@ -1567,7 +1582,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
                         CLAMP(imm0.reg.data.u16, umin, umax) : \
                         imm0.reg.data.u16; \
          break; \
-      default: return; \
+      default: return false; \
       } \
       i->setSrc(0, bld.mkImm(res.data.dst)); \
       break
@@ -1594,7 +1609,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
          case TYPE_S16: res.data.f32 = (float) imm0.reg.data.s16; break;
          case TYPE_S32: res.data.f32 = (float) imm0.reg.data.s32; break;
          default:
-            return;
+            return false;
          }
          i->setSrc(0, bld.mkImm(res.data.f32));
          break;
@@ -1615,12 +1630,12 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
          case TYPE_S16: res.data.f64 = (double) imm0.reg.data.s16; break;
          case TYPE_S32: res.data.f64 = (double) imm0.reg.data.s32; break;
          default:
-            return;
+            return false;
          }
          i->setSrc(0, bld.mkImm(res.data.f64));
          break;
       default:
-         return;
+         return false;
       }
 #undef CASE
 
@@ -1631,7 +1646,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
       break;
    }
    default:
-      return;
+      return false;
    }
 
    // This can get left behind some of the optimizations which simplify
@@ -1646,6 +1661,7 @@ ConstantFolding::opnd(Instruction *i, ImmediateValue &imm0, int s)
 
    if (newi->op != op)
       foldCount++;
+   return deleted;
 }
 
 // =============================================================================