Refactor rewriteMinus (#7932)
authorGereon Kremer <gkremer@stanford.edu>
Wed, 12 Jan 2022 22:19:31 +0000 (14:19 -0800)
committerGitHub <noreply@github.com>
Wed, 12 Jan 2022 22:19:31 +0000 (22:19 +0000)
Refactors rewriteMinus to be generally simpler and explicitly rely on rewriting addition.
Note that rewriteMinus would almost never be called in post rewrite anyway, as it would rewrite to addition in pre rewrite.

src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h

index 4e79515bd737de0d642286f5dc8db7993618f4c7..2c3fcdf48b48e49a904dab3d9a0289e993c39a45 100644 (file)
@@ -60,25 +60,21 @@ RewriteResponse ArithRewriter::rewriteVariable(TNode t){
   return RewriteResponse(REWRITE_DONE, t);
 }
 
-RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
+RewriteResponse ArithRewriter::rewriteMinus(TNode t)
+{
   Assert(t.getKind() == kind::MINUS);
+  Assert(t.getNumChildren() == 2);
 
-  if(pre){
-    if(t[0] == t[1]){
-      Rational zero(0);
-      Node zeroNode =
-          NodeManager::currentNM()->mkConstRealOrInt(t.getType(), zero);
-      return RewriteResponse(REWRITE_DONE, zeroNode);
-    }else{
-      Node noMinus = makeSubtractionNode(t[0],t[1]);
-      return RewriteResponse(REWRITE_DONE, noMinus);
-    }
-  }else{
-    Polynomial minuend = Polynomial::parsePolynomial(t[0]);
-    Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
-    Polynomial diff = minuend - subtrahend;
-    return RewriteResponse(REWRITE_DONE, diff.getNode());
+  auto* nm = NodeManager::currentNM();
+
+  if (t[0] == t[1])
+  {
+    return RewriteResponse(REWRITE_DONE,
+                           nm->mkConstRealOrInt(t.getType(), Rational(0)));
   }
+  return RewriteResponse(
+      REWRITE_AGAIN_FULL,
+      nm->mkNode(Kind::PLUS, t[0], makeUnaryMinusNode(t[1])));
 }
 
 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
@@ -106,60 +102,56 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
     return rewriteVariable(t);
   }else{
     switch(Kind k = t.getKind()){
-    case kind::MINUS:
-      return rewriteMinus(t, true);
-    case kind::UMINUS:
-      return rewriteUMinus(t, true);
-    case kind::DIVISION:
-    case kind::DIVISION_TOTAL:
-      return rewriteDiv(t,true);
-    case kind::PLUS:
-      return preRewritePlus(t);
-    case kind::MULT:
-    case kind::NONLINEAR_MULT: return preRewriteMult(t);
-    case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
-    case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
-    case kind::EXPONENTIAL:
-    case kind::SINE:
-    case kind::COSINE:
-    case kind::TANGENT:
-    case kind::COSECANT:
-    case kind::SECANT:
-    case kind::COTANGENT:
-    case kind::ARCSINE:
-    case kind::ARCCOSINE:
-    case kind::ARCTANGENT:
-    case kind::ARCCOSECANT:
-    case kind::ARCSECANT:
-    case kind::ARCCOTANGENT:
-    case kind::SQRT: return preRewriteTranscendental(t);
-    case kind::INTS_DIVISION:
-    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
-    case kind::INTS_DIVISION_TOTAL:
-    case kind::INTS_MODULUS_TOTAL:
-      return rewriteIntsDivModTotal(t,true);
-    case kind::ABS:
-      if(t[0].isConst()) {
-        const Rational& rat = t[0].getConst<Rational>();
-        if(rat >= 0) {
-          return RewriteResponse(REWRITE_DONE, t[0]);
-        } else {
-          return RewriteResponse(
-              REWRITE_DONE,
-              NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
+      case kind::MINUS: return rewriteMinus(t);
+      case kind::UMINUS: return rewriteUMinus(t, true);
+      case kind::DIVISION:
+      case kind::DIVISION_TOTAL: return rewriteDiv(t, true);
+      case kind::PLUS: return preRewritePlus(t);
+      case kind::MULT:
+      case kind::NONLINEAR_MULT: return preRewriteMult(t);
+      case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
+      case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
+      case kind::EXPONENTIAL:
+      case kind::SINE:
+      case kind::COSINE:
+      case kind::TANGENT:
+      case kind::COSECANT:
+      case kind::SECANT:
+      case kind::COTANGENT:
+      case kind::ARCSINE:
+      case kind::ARCCOSINE:
+      case kind::ARCTANGENT:
+      case kind::ARCCOSECANT:
+      case kind::ARCSECANT:
+      case kind::ARCCOTANGENT:
+      case kind::SQRT: return preRewriteTranscendental(t);
+      case kind::INTS_DIVISION:
+      case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
+      case kind::INTS_DIVISION_TOTAL:
+      case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);
+      case kind::ABS:
+        if (t[0].isConst())
+        {
+          const Rational& rat = t[0].getConst<Rational>();
+          if (rat >= 0)
+          {
+            return RewriteResponse(REWRITE_DONE, t[0]);
+          }
+          else
+          {
+            return RewriteResponse(REWRITE_DONE,
+                                   NodeManager::currentNM()->mkConstRealOrInt(
+                                       t[0].getType(), -rat));
+          }
         }
-      }
-      return RewriteResponse(REWRITE_DONE, t);
-    case kind::IS_INTEGER:
-    case kind::TO_INTEGER:
-      return RewriteResponse(REWRITE_DONE, t);
-    case kind::TO_REAL:
-    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
-    case kind::POW:
-      return RewriteResponse(REWRITE_DONE, t);
-    case kind::PI:
-      return RewriteResponse(REWRITE_DONE, t);
-    default: Unhandled() << k;
+        return RewriteResponse(REWRITE_DONE, t);
+      case kind::IS_INTEGER:
+      case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);
+      case kind::TO_REAL:
+      case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
+      case kind::POW: return RewriteResponse(REWRITE_DONE, t);
+      case kind::PI: return RewriteResponse(REWRITE_DONE, t);
+      default: Unhandled() << k;
     }
   }
 }
@@ -172,54 +164,53 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
   }else{
     Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl;
     switch(t.getKind()){
-    case kind::MINUS:
-      return rewriteMinus(t, false);
-    case kind::UMINUS:
-      return rewriteUMinus(t, false);
-    case kind::DIVISION:
-    case kind::DIVISION_TOTAL:
-      return rewriteDiv(t, false);
-    case kind::PLUS:
-      return postRewritePlus(t);
-    case kind::MULT:
-    case kind::NONLINEAR_MULT: return postRewriteMult(t);
-    case kind::IAND: return postRewriteIAnd(t);
-    case kind::POW2: return postRewritePow2(t);
-    case kind::EXPONENTIAL:
-    case kind::SINE:
-    case kind::COSINE:
-    case kind::TANGENT:
-    case kind::COSECANT:
-    case kind::SECANT:
-    case kind::COTANGENT:
-    case kind::ARCSINE:
-    case kind::ARCCOSINE:
-    case kind::ARCTANGENT:
-    case kind::ARCCOSECANT:
-    case kind::ARCSECANT:
-    case kind::ARCCOTANGENT:
-    case kind::SQRT: return postRewriteTranscendental(t);
-    case kind::INTS_DIVISION:
-    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
-    case kind::INTS_DIVISION_TOTAL:
-    case kind::INTS_MODULUS_TOTAL:
-      return rewriteIntsDivModTotal(t, false);
-    case kind::ABS:
-      if(t[0].isConst()) {
-        const Rational& rat = t[0].getConst<Rational>();
-        if(rat >= 0) {
-          return RewriteResponse(REWRITE_DONE, t[0]);
-        } else {
-          return RewriteResponse(
-              REWRITE_DONE,
-              NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
+      case kind::MINUS: return rewriteMinus(t);
+      case kind::UMINUS: return rewriteUMinus(t, false);
+      case kind::DIVISION:
+      case kind::DIVISION_TOTAL: return rewriteDiv(t, false);
+      case kind::PLUS: return postRewritePlus(t);
+      case kind::MULT:
+      case kind::NONLINEAR_MULT: return postRewriteMult(t);
+      case kind::IAND: return postRewriteIAnd(t);
+      case kind::POW2: return postRewritePow2(t);
+      case kind::EXPONENTIAL:
+      case kind::SINE:
+      case kind::COSINE:
+      case kind::TANGENT:
+      case kind::COSECANT:
+      case kind::SECANT:
+      case kind::COTANGENT:
+      case kind::ARCSINE:
+      case kind::ARCCOSINE:
+      case kind::ARCTANGENT:
+      case kind::ARCCOSECANT:
+      case kind::ARCSECANT:
+      case kind::ARCCOTANGENT:
+      case kind::SQRT: return postRewriteTranscendental(t);
+      case kind::INTS_DIVISION:
+      case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
+      case kind::INTS_DIVISION_TOTAL:
+      case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
+      case kind::ABS:
+        if (t[0].isConst())
+        {
+          const Rational& rat = t[0].getConst<Rational>();
+          if (rat >= 0)
+          {
+            return RewriteResponse(REWRITE_DONE, t[0]);
+          }
+          else
+          {
+            return RewriteResponse(REWRITE_DONE,
+                                   NodeManager::currentNM()->mkConstRealOrInt(
+                                       t[0].getType(), -rat));
+          }
         }
-      }
-      return RewriteResponse(REWRITE_DONE, t);
-    case kind::TO_REAL:
-    case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
-    case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
-    case kind::POW:
+        return RewriteResponse(REWRITE_DONE, t);
+      case kind::TO_REAL:
+      case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
+      case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
+      case kind::POW:
       {
         if (t[1].isConst())
         {
@@ -757,13 +748,6 @@ Node ArithRewriter::makeUnaryMinusNode(TNode n){
   return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n);
 }
 
-Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
-  Node negR = makeUnaryMinusNode(r);
-  Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
-
-  return diff;
-}
-
 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
   Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
 
index 1f861b08a8f26e47446552c93a8960e2d795787b..5fcb628b957d22872c5f5610e6dc5bf097dd4941 100644 (file)
@@ -51,7 +51,7 @@ class ArithRewriter : public TheoryRewriter
 
   static RewriteResponse rewriteVariable(TNode t);
   static RewriteResponse rewriteConstant(TNode t);
-  static RewriteResponse rewriteMinus(TNode t, bool pre);
+  static RewriteResponse rewriteMinus(TNode t);
   static RewriteResponse rewriteUMinus(TNode t, bool pre);
   static RewriteResponse rewriteDiv(TNode t, bool pre);
   static RewriteResponse rewriteIntsDivMod(TNode t, bool pre);