Refactor multiplication in arithmetic rewriter (#7965)
authorGereon Kremer <gkremer@cs.stanford.edu>
Wed, 23 Feb 2022 14:57:53 +0000 (15:57 +0100)
committerGitHub <noreply@github.com>
Wed, 23 Feb 2022 14:57:53 +0000 (14:57 +0000)
This PR refactors the rewriting of multiplication. Most importantly, we explicitly deal with distributivity of addition and multiplication explicitly using the new utilities.

src/theory/arith/arith_rewriter.cpp
src/theory/arith/rewriter/addition.h
src/theory/arith/rewriter/ordering.h

index ffd7427af668f7b647ace8518737909aea9d16b8..86835f447184d1d95662821c50590acd4fee2489 100644 (file)
@@ -30,6 +30,9 @@
 #include "theory/arith/arith_utilities.h"
 #include "theory/arith/normal_form.h"
 #include "theory/arith/operator_elim.h"
+#include "theory/arith/rewriter/addition.h"
+#include "theory/arith/rewriter/node_utils.h"
+#include "theory/arith/rewriter/ordering.h"
 #include "theory/theory.h"
 #include "util/bitvector.h"
 #include "util/divisible.h"
@@ -44,88 +47,6 @@ namespace arith {
 
 namespace {
 
-/**
- * Implements an ordering on arithmetic leaf nodes, excluding rationals. As this
- * comparator is meant to be used on children of Kind::NONLINEAR_MULT, we expect
- * rationals to be handled separately. Furthermore, we expect there to be only a
- * single real algebraic number.
- * It broadly categorizes leaf nodes into real algebraic numbers, integers,
- * variables, and the rest. The ordering is built as follows:
- * - real algebraic numbers come first
- * - real terms come before integer terms
- * - variables come before non-variable terms
- * - finally, fall back to node ordering
- */
-struct LeafNodeComparator
-{
-  /** Implements operator<(a, b) as described above */
-  bool operator()(TNode a, TNode b)
-  {
-    if (a == b) return false;
-
-    bool aIsRAN = a.getKind() == Kind::REAL_ALGEBRAIC_NUMBER;
-    bool bIsRAN = b.getKind() == Kind::REAL_ALGEBRAIC_NUMBER;
-    if (aIsRAN != bIsRAN) return aIsRAN;
-    Assert(!aIsRAN && !bIsRAN) << "real algebraic numbers should be combined";
-
-    bool aIsInt = a.getType().isInteger();
-    bool bIsInt = b.getType().isInteger();
-    if (aIsInt != bIsInt) return !aIsInt;
-
-    bool aIsVar = a.isVar();
-    bool bIsVar = b.isVar();
-    if (aIsVar != bIsVar) return aIsVar;
-
-    return a < b;
-  }
-};
-
-/**
- * Implements an ordering on arithmetic nonlinear multiplications. As we assume
- * rationals to be handled separately, we only consider Kind::NONLINEAR_MULT as
- * multiplication terms. For individual factors of the product, we rely on the
- * ordering from LeafNodeComparator. Furthermore, we expect products to be
- * sorted according to LeafNodeComparator. The ordering is built as follows:
- * - single factors come first (everything that is not NONLINEAR_MULT)
- * - multiplications with less factors come first
- * - multiplications are compared lexicographically
- */
-struct ProductNodeComparator
-{
-  /** Implements operator<(a, b) as described above */
-  bool operator()(TNode a, TNode b)
-  {
-    if (a == b) return false;
-
-    Assert(a.getKind() != Kind::MULT);
-    Assert(b.getKind() != Kind::MULT);
-
-    bool aIsMult = a.getKind() == Kind::NONLINEAR_MULT;
-    bool bIsMult = b.getKind() == Kind::NONLINEAR_MULT;
-    if (aIsMult != bIsMult) return !aIsMult;
-
-    if (!aIsMult)
-    {
-      return LeafNodeComparator()(a, b);
-    }
-
-    size_t aLen = a.getNumChildren();
-    size_t bLen = b.getNumChildren();
-    if (aLen != bLen) return aLen < bLen;
-
-    for (size_t i = 0; i < aLen; ++i)
-    {
-      if (a[i] != b[i])
-      {
-        return LeafNodeComparator()(a[i], b[i]);
-      }
-    }
-    Unreachable() << "Nodes are different, but have the same content";
-    return false;
-  }
-};
-
-
 template <typename L, typename R>
 bool evaluateRelation(Kind rel, const L& l, const R& r)
 {
@@ -626,8 +547,7 @@ RewriteResponse ArithRewriter::preRewriteMult(TNode node)
   Assert(node.getKind() == kind::MULT
          || node.getKind() == kind::NONLINEAR_MULT);
 
-  auto res = getZeroChild(node);
-  if (res)
+  if (auto res = rewriter::getZeroChild(node); res)
   {
     return RewriteResponse(REWRITE_DONE, *res);
   }
@@ -638,16 +558,27 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){
   Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
   Assert(t.getNumChildren() >= 2);
 
-  if (auto res = getZeroChild(t); res)
+  std::vector<TNode> children;
+  expr::algorithm::flatten(t, children, Kind::MULT, Kind::NONLINEAR_MULT);
+
+  if (auto res = rewriter::getZeroChild(children); res)
   {
     return RewriteResponse(REWRITE_DONE, *res);
   }
 
-  Rational rational = Rational(1);
+  // Distribute over addition
+  if (std::any_of(children.begin(), children.end(), [](TNode child) {
+        return child.getKind() == Kind::ADD;
+      }))
+  {
+    return RewriteResponse(REWRITE_DONE,
+                           rewriter::distributeMultiplication(children));
+  }
+
   RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
-  Polynomial poly = Polynomial::mkOne();
+  std::vector<Node> leafs;
 
-  for (const auto& child : t)
+  for (const auto& child : children)
   {
     if (child.isConst())
     {
@@ -655,36 +586,20 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){
       {
         return RewriteResponse(REWRITE_DONE, child);
       }
-      rational *= child.getConst<Rational>();
+      ran *= child.getConst<Rational>();
     }
-    else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    else if (rewriter::isRAN(child))
     {
-      ran *= child.getOperator().getConst<RealAlgebraicNumber>();
+      ran *= rewriter::getRAN(child);
     }
     else
     {
-      poly = poly * Polynomial::parsePolynomial(child);
+      leafs.emplace_back(child);
     }
   }
 
-  if (!rational.isOne())
-  {
-    poly = poly * rational;
-  }
-  if (isOne(ran))
-  {
-    return RewriteResponse(REWRITE_DONE, poly.getNode());
-  }
-  auto* nm = NodeManager::currentNM();
-  if (poly.isConstant())
-  {
-    ran *= RealAlgebraicNumber(poly.getHead().getConstant().getValue());
-    return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
-  }
-  return RewriteResponse(
-      REWRITE_DONE,
-      nm->mkNode(
-          Kind::MULT, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
+  return RewriteResponse(REWRITE_DONE,
+                         rewriter::mkMultTerm(ran, std::move(leafs)));
 }
 
 RewriteResponse ArithRewriter::postRewritePow2(TNode t)
index 8ab5cb2a13c031f9b8483404b39e17f050342170..fd25a79d3fb87ce67d630cc31705784a88d62c45 100644 (file)
@@ -18,8 +18,8 @@
 #ifndef CVC5__THEORY__ARITH__REWRITER__ADDITION_H
 #define CVC5__THEORY__ARITH__REWRITER__ADDITION_H
 
-#include <map>
 #include <iosfwd>
+#include <map>
 
 #include "expr/node.h"
 #include "theory/arith/rewriter/ordering.h"
index 529bd14cabc20cde844e0146b8166579ae8a6f7f..4b7d6c6fa67a60fb963d3a02979eea5da390d75f 100644 (file)
@@ -82,8 +82,10 @@ struct TermComparator
   {
     if (a == b) return false;
 
-    bool aIsMult = a.getKind() == Kind::MULT || a.getKind() == Kind::NONLINEAR_MULT;
-    bool bIsMult = b.getKind() == Kind::MULT || b.getKind() == Kind::NONLINEAR_MULT;
+    bool aIsMult =
+        a.getKind() == Kind::MULT || a.getKind() == Kind::NONLINEAR_MULT;
+    bool bIsMult =
+        b.getKind() == Kind::MULT || b.getKind() == Kind::NONLINEAR_MULT;
     if (aIsMult != bIsMult) return !aIsMult;
 
     if (!aIsMult)