Refactor rewriting of arithmetic addition (#8180)
authorGereon Kremer <gkremer@cs.stanford.edu>
Mon, 28 Feb 2022 15:22:48 +0000 (16:22 +0100)
committerGitHub <noreply@github.com>
Mon, 28 Feb 2022 15:22:48 +0000 (09:22 -0600)
This PR uses the new addition utilities to refactor rewriting of arithmetic addition. This properly handles real algebraic numbers now, eliminating a few more edge cases where the previous solution might allow for non-idempotent rewrites.

src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h
src/theory/arith/rewriter/addition.cpp
src/theory/arith/rewriter/node_utils.cpp

index aae34abb62fcf18f95e6862a9fe697999814e638..a251d9e08cf92b76df31446d17c9d58670dcde09 100644 (file)
@@ -459,86 +459,26 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
 }
 
 
-RewriteResponse ArithRewriter::preRewritePlus(TNode t){
+RewriteResponse ArithRewriter::preRewritePlus(TNode t)
+{
   Assert(t.getKind() == kind::ADD);
   return RewriteResponse(REWRITE_DONE, expr::algorithm::flatten(t));
 }
 
-RewriteResponse ArithRewriter::postRewritePlus(TNode t){
+RewriteResponse ArithRewriter::postRewritePlus(TNode t)
+{
   Assert(t.getKind() == kind::ADD);
   Assert(t.getNumChildren() > 1);
 
-  {
-    Node flat = expr::algorithm::flatten(t);
-    if (flat != t)
-    {
-      return RewriteResponse(REWRITE_AGAIN, flat);
-    }
-  }
-
-  Rational rational;
-  RealAlgebraicNumber ran;
-  std::vector<Monomial> monomials;
-  std::vector<Polynomial> polynomials;
-
-  for (const auto& child : t)
-  {
-    if (child.isConst())
-    {
-      if (child.getConst<Rational>().isZero())
-      {
-        continue;
-      }
-      rational += child.getConst<Rational>();
-    }
-    else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
-    {
-      ran += child.getOperator().getConst<RealAlgebraicNumber>();
-    }
-    else if (Monomial::isMember(child))
-    {
-      monomials.emplace_back(Monomial::parseMonomial(child));
-    }
-    else
-    {
-      polynomials.emplace_back(Polynomial::parsePolynomial(child));
-    }
-  }
-
-  if(!monomials.empty()){
-    Monomial::sort(monomials);
-    Monomial::combineAdjacentMonomials(monomials);
-    polynomials.emplace_back(Polynomial::mkPolynomial(monomials));
-  }
-  if (!rational.isZero())
-  {
-    polynomials.emplace_back(
-        Polynomial::mkPolynomial(Constant::mkConstant(rational)));
-  }
-
-  Polynomial poly = Polynomial::sumPolynomials(polynomials);
-
-  if (isZero(ran))
-  {
-    return RewriteResponse(REWRITE_DONE, poly.getNode());
-  }
-  if (poly.containsConstant())
-  {
-    ran += RealAlgebraicNumber(poly.getHead().getConstant().getValue());
-    if (!poly.isConstant())
-    {
-      poly = poly.getTail();
-    }
-  }
+  std::vector<TNode> children;
+  expr::algorithm::flatten(t, children);
 
-  auto* nm = NodeManager::currentNM();
-  if (poly.isConstant())
+  rewriter::Sum sum;
+  for (const auto& child : children)
   {
-    return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
+    rewriter::addToSum(sum, child);
   }
-  return RewriteResponse(
-      REWRITE_DONE,
-      nm->mkNode(Kind::ADD, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
+  return RewriteResponse(REWRITE_DONE, rewriter::collectSum(sum));
 }
 
 RewriteResponse ArithRewriter::preRewriteMult(TNode node)
index a1079e60f5c3a911b82c92f815071697fea07067..7cf9a4f5b29362cc43a7a0a847ade537be7dcea6 100644 (file)
@@ -57,6 +57,10 @@ class ArithRewriter : public TheoryRewriter
   static RewriteResponse rewriteNeg(TNode t, bool pre);
   /** rewrite binary minus */
   static RewriteResponse rewriteSub(TNode t);
+  /** preRewrite addition */
+  static RewriteResponse preRewritePlus(TNode t);
+  /** postRewrite addition */
+  static RewriteResponse postRewritePlus(TNode t);
   static RewriteResponse rewriteDiv(TNode t, bool pre);
   static RewriteResponse rewriteAbs(TNode t);
   static RewriteResponse rewriteIntsDivMod(TNode t, bool pre);
@@ -64,9 +68,6 @@ class ArithRewriter : public TheoryRewriter
   /** Entry for applications of to_int and is_int */
   static RewriteResponse rewriteExtIntegerOp(TNode t);
 
-  static RewriteResponse preRewritePlus(TNode t);
-  static RewriteResponse postRewritePlus(TNode t);
-
   static RewriteResponse preRewriteMult(TNode t);
   static RewriteResponse postRewriteMult(TNode t);
 
index 16e2889208b0097f9dc2c2c0fb3dd0f1d194530b..e8743a43aca55c61da5ea1d72a325d1087529f56 100644 (file)
@@ -162,6 +162,7 @@ void addToSum(Sum& sum, TNode n, bool negate)
 Node collectSum(const Sum& sum)
 {
   if (sum.empty()) return mkConst(Rational(0));
+  Trace("arith-rewriter") << "Collecting sum " << sum << std::endl;
   // construct the sum as nodes.
   NodeBuilder nb(Kind::ADD);
   for (const auto& s : sum)
index 34c3849a09e0e3aa655f50864eeeb0dda323b4d0..3d1856e0b3cebe4cfbc77ba688a147d29a154243 100644 (file)
@@ -82,7 +82,14 @@ Node mkMultTerm(const RealAlgebraicNumber& multiplicity, TNode monomial)
   }
   std::vector<Node> prod;
   prod.emplace_back(mkConst(multiplicity));
-  prod.insert(prod.end(), monomial.begin(), monomial.end());
+  if (monomial.getKind() == Kind::MULT || monomial.getKind() == Kind::NONLINEAR_MULT)
+  {
+    prod.insert(prod.end(), monomial.begin(), monomial.end());
+  }
+  else
+  {
+    prod.emplace_back(monomial);
+  }
   Assert(prod.size() >= 2);
   return NodeManager::currentNM()->mkNode(Kind::NONLINEAR_MULT, prod);
 }