Refactor rewriting of arithmetic negation and subtraction (#8170)
authorGereon Kremer <gkremer@cs.stanford.edu>
Fri, 25 Feb 2022 20:56:21 +0000 (21:56 +0100)
committerGitHub <noreply@github.com>
Fri, 25 Feb 2022 20:56:21 +0000 (20:56 +0000)
Slightly refactor negation and subtraction, get rid of utility functions.

src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h

index f8d8594ebab585b03cf5a6b2f757ca5f0be33c11..aae34abb62fcf18f95e6862a9fe697999814e638 100644 (file)
@@ -270,28 +270,13 @@ RewriteResponse ArithRewriter::rewriteRAN(TNode t)
   return RewriteResponse(REWRITE_DONE, t);
 }
 
-RewriteResponse ArithRewriter::rewriteVariable(TNode t){
+RewriteResponse ArithRewriter::rewriteVariable(TNode t)
+{
   Assert(t.isVar());
 
   return RewriteResponse(REWRITE_DONE, t);
 }
 
-RewriteResponse ArithRewriter::rewriteSub(TNode t)
-{
-  Assert(t.getKind() == kind::SUB);
-  Assert(t.getNumChildren() == 2);
-
-  auto* nm = NodeManager::currentNM();
-
-  if (t[0] == t[1])
-  {
-    return RewriteResponse(REWRITE_DONE,
-                           nm->mkConstRealOrInt(t.getType(), Rational(0)));
-  }
-  return RewriteResponse(REWRITE_AGAIN_FULL,
-                         nm->mkNode(Kind::ADD, t[0], makeUnaryMinusNode(t[1])));
-}
-
 RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
 {
   Assert(t.getKind() == kind::NEG);
@@ -299,25 +284,39 @@ RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
   if (t[0].isConst())
   {
     Rational neg = -(t[0].getConst<Rational>());
-    NodeManager* nm = NodeManager::currentNM();
-    return RewriteResponse(REWRITE_DONE,
-                           nm->mkConstRealOrInt(t[0].getType(), neg));
+    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(neg));
   }
-  if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+  if (rewriter::isRAN(t[0]))
   {
-    const RealAlgebraicNumber& r =
-        t[0].getOperator().getConst<RealAlgebraicNumber>();
-    NodeManager* nm = NodeManager::currentNM();
-    return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(-r));
+    return RewriteResponse(REWRITE_DONE,
+                           rewriter::mkConst(-rewriter::getRAN(t[0])));
   }
 
-  Node noUminus = makeUnaryMinusNode(t[0]);
-  if(pre)
+  auto* nm = NodeManager::currentNM();
+  Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[0]);
+  if (pre)
     return RewriteResponse(REWRITE_DONE, noUminus);
   else
     return RewriteResponse(REWRITE_AGAIN, noUminus);
 }
 
+RewriteResponse ArithRewriter::rewriteSub(TNode t)
+{
+  Assert(t.getKind() == kind::SUB);
+  Assert(t.getNumChildren() == 2);
+
+  if (t[0] == t[1])
+  {
+    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+  }
+  auto* nm = NodeManager::currentNM();
+  return RewriteResponse(
+      REWRITE_AGAIN_FULL,
+      nm->mkNode(Kind::ADD,
+                 t[0],
+                 nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[1])));
+}
+
 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
   if(t.isConst()){
     return rewriteConstant(t);
@@ -602,13 +601,6 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){
                          rewriter::mkMultTerm(ran, std::move(leafs)));
 }
 
-Node ArithRewriter::makeUnaryMinusNode(TNode n)
-{
-  NodeManager* nm = NodeManager::currentNM();
-  Rational qNegOne(-1);
-  return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n);
-}
-
 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
 {
   Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
index 9e2a15c77f9d2b2284af39758e1b8946460ab4cd..a1079e60f5c3a911b82c92f815071697fea07067 100644 (file)
@@ -46,9 +46,6 @@ class ArithRewriter : public TheoryRewriter
   static RewriteResponse preRewriteAtom(TNode t);
   static RewriteResponse postRewriteAtom(TNode t);
 
-  static Node makeSubtractionNode(TNode l, TNode r);
-  static Node makeUnaryMinusNode(TNode n);
-
   static RewriteResponse preRewriteTerm(TNode t);
   static RewriteResponse postRewriteTerm(TNode t);
 
@@ -56,8 +53,10 @@ class ArithRewriter : public TheoryRewriter
   static RewriteResponse rewriteRAN(TNode t);
   static RewriteResponse rewriteVariable(TNode t);
 
-  static RewriteResponse rewriteSub(TNode t);
+  /** rewrite unary minus */
   static RewriteResponse rewriteNeg(TNode t, bool pre);
+  /** rewrite binary minus */
+  static RewriteResponse rewriteSub(TNode t);
   static RewriteResponse rewriteDiv(TNode t, bool pre);
   static RewriteResponse rewriteAbs(TNode t);
   static RewriteResponse rewriteIntsDivMod(TNode t, bool pre);