Refactor rewriting of arithmetic division (#8195)
authorGereon Kremer <gkremer@cs.stanford.edu>
Wed, 2 Mar 2022 16:37:03 +0000 (17:37 +0100)
committerGitHub <noreply@github.com>
Wed, 2 Mar 2022 16:37:03 +0000 (16:37 +0000)
This PR does minor refactoring of how we rewrite division. Also, it moves some functions around to reconcile the order in header and source file.

src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h

index 863eaee1d8794b86606aa8a71349d12d77ead221..91652b7ce7ce27b9f5561fc4ddd574e3828c5592 100644 (file)
@@ -46,27 +46,6 @@ namespace cvc5 {
 namespace theory {
 namespace arith {
 
-namespace {
-
-/**
- * Check whether the parent has a child that is a constant zero.
- * If so, return this child. Otherwise, return std::nullopt.
- */
-template <typename Iterable>
-std::optional<TNode> getZeroChild(const Iterable& parent)
-{
-  for (const auto& node : parent)
-  {
-    if (node.isConst() && node.template getConst<Rational>().isZero())
-    {
-      return node;
-    }
-  }
-  return std::nullopt;
-}
-
-}  // namespace
-
 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
 
 RewriteResponse ArithRewriter::preRewrite(TNode t)
@@ -225,46 +204,6 @@ RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
   }
 }
 
-RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
-{
-  Assert(t.getKind() == kind::NEG);
-
-  if (t[0].isConst())
-  {
-    Rational neg = -(t[0].getConst<Rational>());
-    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(neg));
-  }
-  if (rewriter::isRAN(t[0]))
-  {
-    return RewriteResponse(REWRITE_DONE,
-                           rewriter::mkConst(-rewriter::getRAN(t[0])));
-  }
-
-  auto* nm = NodeManager::currentNM();
-  Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[0]);
-  if (pre)
-    return RewriteResponse(REWRITE_DONE, noUminus);
-  else
-    return RewriteResponse(REWRITE_AGAIN, noUminus);
-}
-
-RewriteResponse ArithRewriter::rewriteSub(TNode t)
-{
-  Assert(t.getKind() == kind::SUB);
-  Assert(t.getNumChildren() == 2);
-
-  if (t[0] == t[1])
-  {
-    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
-  }
-  auto* nm = NodeManager::currentNM();
-  return RewriteResponse(
-      REWRITE_AGAIN_FULL,
-      nm->mkNode(Kind::ADD,
-                 t[0],
-                 nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[1])));
-}
-
 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
   if(t.isConst()){
     return RewriteResponse(REWRITE_DONE, t);
@@ -317,7 +256,9 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
     return RewriteResponse(REWRITE_DONE, t);
   }else if(t.isVar()){
     return rewriteVariable(t);
-  }else{
+  }
+  else
+  {
     Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl;
     switch(t.getKind()){
       case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
@@ -424,6 +365,50 @@ RewriteResponse ArithRewriter::rewriteVariable(TNode t)
   return RewriteResponse(REWRITE_DONE, t);
 }
 
+RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
+{
+  Assert(t.getKind() == kind::NEG);
+
+  if (t[0].isConst())
+  {
+    Rational neg = -(t[0].getConst<Rational>());
+    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(neg));
+  }
+  if (rewriter::isRAN(t[0]))
+  {
+    return RewriteResponse(REWRITE_DONE,
+                           rewriter::mkConst(-rewriter::getRAN(t[0])));
+  }
+
+  auto* nm = NodeManager::currentNM();
+  Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[0]);
+  if (pre)
+  {
+    return RewriteResponse(REWRITE_DONE, noUminus);
+  }
+  else
+  {
+    return RewriteResponse(REWRITE_AGAIN, noUminus);
+  }
+}
+
+RewriteResponse ArithRewriter::rewriteSub(TNode t)
+{
+  Assert(t.getKind() == kind::SUB);
+  Assert(t.getNumChildren() == 2);
+
+  if (t[0] == t[1])
+  {
+    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+  }
+  auto* nm = NodeManager::currentNM();
+  return RewriteResponse(
+      REWRITE_AGAIN_FULL,
+      nm->mkNode(Kind::ADD,
+                 t[0],
+                 nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[1])));
+}
+
 RewriteResponse ArithRewriter::preRewritePlus(TNode t)
 {
   Assert(t.getKind() == kind::ADD);
@@ -537,13 +522,11 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
       const Rational& num = left.getConst<Rational>();
       return RewriteResponse(REWRITE_DONE, nm->mkConstReal(num / den));
     }
-    if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    if (rewriter::isRAN(left))
     {
-      const RealAlgebraicNumber& num =
-          left.getOperator().getConst<RealAlgebraicNumber>();
       return RewriteResponse(
           REWRITE_DONE,
-          nm->mkRealAlgebraicNumber(num / RealAlgebraicNumber(den)));
+          nm->mkRealAlgebraicNumber(rewriter::getRAN(left) / den));
     }
 
     Node result = nm->mkConstReal(den.inverse());
@@ -557,27 +540,22 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
       return RewriteResponse(REWRITE_AGAIN, mult);
     }
   }
-  if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+  if (rewriter::isRAN(right))
   {
-    NodeManager* nm = NodeManager::currentNM();
-    const RealAlgebraicNumber& den =
-        right.getOperator().getConst<RealAlgebraicNumber>();
+    const RealAlgebraicNumber& den = rewriter::getRAN(right);
+
     if (left.isConst())
     {
-      const Rational& num = left.getConst<Rational>();
       return RewriteResponse(
-          REWRITE_DONE,
-          nm->mkRealAlgebraicNumber(RealAlgebraicNumber(num) / den));
+          REWRITE_DONE, rewriter::mkConst(left.getConst<Rational>() / den));
     }
-    if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    if (rewriter::isRAN(left))
     {
-      const RealAlgebraicNumber& num =
-          left.getOperator().getConst<RealAlgebraicNumber>();
       return RewriteResponse(REWRITE_DONE,
-                             nm->mkRealAlgebraicNumber(num / den));
+                             rewriter::mkConst(rewriter::getRAN(left) / den));
     }
 
-    Node result = nm->mkRealAlgebraicNumber(inverse(den));
+    Node result = rewriter::mkConst(inverse(den));
     Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
     if (pre)
     {
@@ -607,10 +585,9 @@ RewriteResponse ArithRewriter::rewriteAbs(TNode t)
         REWRITE_DONE,
         NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
   }
-  if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+  if (rewriter::isRAN(t[0]))
   {
-    const RealAlgebraicNumber& ran =
-        t[0].getOperator().getConst<RealAlgebraicNumber>();
+    const RealAlgebraicNumber& ran = rewriter::getRAN(t[0]);
     if (ran >= RealAlgebraicNumber())
     {
       return RewriteResponse(REWRITE_DONE, t[0]);
@@ -646,37 +623,6 @@ RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
   return RewriteResponse(REWRITE_DONE, t);
 }
 
-RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t)
-{
-  Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER);
-  bool isPred = t.getKind() == kind::IS_INTEGER;
-  NodeManager* nm = NodeManager::currentNM();
-  if (t[0].isConst())
-  {
-    Node ret;
-    if (isPred)
-    {
-      ret = nm->mkConst(t[0].getConst<Rational>().isIntegral());
-    }
-    else
-    {
-      ret = nm->mkConstInt(Rational(t[0].getConst<Rational>().floor()));
-    }
-    return returnRewrite(t, ret, Rewrite::INT_EXT_CONST);
-  }
-  if (t[0].getType().isInteger())
-  {
-    Node ret = isPred ? nm->mkConst(true) : Node(t[0]);
-    return returnRewrite(t, ret, Rewrite::INT_EXT_INT);
-  }
-  if (t[0].getKind() == kind::PI)
-  {
-    Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3));
-    return returnRewrite(t, ret, Rewrite::INT_EXT_PI);
-  }
-  return RewriteResponse(REWRITE_DONE, t);
-}
-
 RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
 {
   if (pre)
@@ -785,6 +731,37 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
   return RewriteResponse(REWRITE_DONE, t);
 }
 
+RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t)
+{
+  Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER);
+  bool isPred = t.getKind() == kind::IS_INTEGER;
+  NodeManager* nm = NodeManager::currentNM();
+  if (t[0].isConst())
+  {
+    Node ret;
+    if (isPred)
+    {
+      ret = nm->mkConst(t[0].getConst<Rational>().isIntegral());
+    }
+    else
+    {
+      ret = nm->mkConstInt(Rational(t[0].getConst<Rational>().floor()));
+    }
+    return returnRewrite(t, ret, Rewrite::INT_EXT_CONST);
+  }
+  if (t[0].getType().isInteger())
+  {
+    Node ret = isPred ? nm->mkConst(true) : Node(t[0]);
+    return returnRewrite(t, ret, Rewrite::INT_EXT_INT);
+  }
+  if (t[0].getKind() == kind::PI)
+  {
+    Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3));
+    return returnRewrite(t, ret, Rewrite::INT_EXT_PI);
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
 RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
 {
   Assert(t.getKind() == kind::IAND);
index 3ffa90015b27f1f975e9ee0ac3c060581db5a835..6f2e06ef14dcb7d85f0e90a1ec03bbdb71ae7fe9 100644 (file)
@@ -48,7 +48,9 @@ class ArithRewriter : public TheoryRewriter
   /** postRewrite for atoms */
   static RewriteResponse postRewriteAtom(TNode t);
 
+  /** preRewrite for terms */
   static RewriteResponse preRewriteTerm(TNode t);
+  /** postRewrite for terms */
   static RewriteResponse postRewriteTerm(TNode t);
 
   /** rewrite real algebraic numbers */
@@ -64,16 +66,22 @@ class ArithRewriter : public TheoryRewriter
   static RewriteResponse preRewritePlus(TNode t);
   /** postRewrite addition */
   static RewriteResponse postRewritePlus(TNode t);
+  /** preRewrite multiplication */
+  static RewriteResponse preRewriteMult(TNode t);
+  /** postRewrite multiplication */
+  static RewriteResponse postRewriteMult(TNode t);
+
+  /** rewrite division */
   static RewriteResponse rewriteDiv(TNode t, bool pre);
+  /** rewrite absolute */
   static RewriteResponse rewriteAbs(TNode t);
+  /** rewrite integer division and modulus */
   static RewriteResponse rewriteIntsDivMod(TNode t, bool pre);
+  /** rewrite integer total division and total modulus */
   static RewriteResponse rewriteIntsDivModTotal(TNode t, bool pre);
-  /** Entry for applications of to_int and is_int */
+  /** rewrite to_int and is_int */
   static RewriteResponse rewriteExtIntegerOp(TNode t);
 
-  static RewriteResponse preRewriteMult(TNode t);
-  static RewriteResponse postRewriteMult(TNode t);
-
   /** postRewrite IAND */
   static RewriteResponse postRewriteIAnd(TNode t);
   /** postRewrite POW2 */