From 281bda09e88c568d151464559e00e4ea73c7192f Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Wed, 23 Feb 2022 15:57:53 +0100 Subject: [PATCH] Refactor multiplication in arithmetic rewriter (#7965) This PR refactors the rewriting of multiplication. Most importantly, we explicitly deal with distributivity of addition and multiplication explicitly using the new utilities. --- src/theory/arith/arith_rewriter.cpp | 135 +++++---------------------- src/theory/arith/rewriter/addition.h | 2 +- src/theory/arith/rewriter/ordering.h | 6 +- 3 files changed, 30 insertions(+), 113 deletions(-) diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index ffd7427af..86835f447 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -30,6 +30,9 @@ #include "theory/arith/arith_utilities.h" #include "theory/arith/normal_form.h" #include "theory/arith/operator_elim.h" +#include "theory/arith/rewriter/addition.h" +#include "theory/arith/rewriter/node_utils.h" +#include "theory/arith/rewriter/ordering.h" #include "theory/theory.h" #include "util/bitvector.h" #include "util/divisible.h" @@ -44,88 +47,6 @@ namespace arith { namespace { -/** - * Implements an ordering on arithmetic leaf nodes, excluding rationals. As this - * comparator is meant to be used on children of Kind::NONLINEAR_MULT, we expect - * rationals to be handled separately. Furthermore, we expect there to be only a - * single real algebraic number. - * It broadly categorizes leaf nodes into real algebraic numbers, integers, - * variables, and the rest. The ordering is built as follows: - * - real algebraic numbers come first - * - real terms come before integer terms - * - variables come before non-variable terms - * - finally, fall back to node ordering - */ -struct LeafNodeComparator -{ - /** Implements operator<(a, b) as described above */ - bool operator()(TNode a, TNode b) - { - if (a == b) return false; - - bool aIsRAN = a.getKind() == Kind::REAL_ALGEBRAIC_NUMBER; - bool bIsRAN = b.getKind() == Kind::REAL_ALGEBRAIC_NUMBER; - if (aIsRAN != bIsRAN) return aIsRAN; - Assert(!aIsRAN && !bIsRAN) << "real algebraic numbers should be combined"; - - bool aIsInt = a.getType().isInteger(); - bool bIsInt = b.getType().isInteger(); - if (aIsInt != bIsInt) return !aIsInt; - - bool aIsVar = a.isVar(); - bool bIsVar = b.isVar(); - if (aIsVar != bIsVar) return aIsVar; - - return a < b; - } -}; - -/** - * Implements an ordering on arithmetic nonlinear multiplications. As we assume - * rationals to be handled separately, we only consider Kind::NONLINEAR_MULT as - * multiplication terms. For individual factors of the product, we rely on the - * ordering from LeafNodeComparator. Furthermore, we expect products to be - * sorted according to LeafNodeComparator. The ordering is built as follows: - * - single factors come first (everything that is not NONLINEAR_MULT) - * - multiplications with less factors come first - * - multiplications are compared lexicographically - */ -struct ProductNodeComparator -{ - /** Implements operator<(a, b) as described above */ - bool operator()(TNode a, TNode b) - { - if (a == b) return false; - - Assert(a.getKind() != Kind::MULT); - Assert(b.getKind() != Kind::MULT); - - bool aIsMult = a.getKind() == Kind::NONLINEAR_MULT; - bool bIsMult = b.getKind() == Kind::NONLINEAR_MULT; - if (aIsMult != bIsMult) return !aIsMult; - - if (!aIsMult) - { - return LeafNodeComparator()(a, b); - } - - size_t aLen = a.getNumChildren(); - size_t bLen = b.getNumChildren(); - if (aLen != bLen) return aLen < bLen; - - for (size_t i = 0; i < aLen; ++i) - { - if (a[i] != b[i]) - { - return LeafNodeComparator()(a[i], b[i]); - } - } - Unreachable() << "Nodes are different, but have the same content"; - return false; - } -}; - - template bool evaluateRelation(Kind rel, const L& l, const R& r) { @@ -626,8 +547,7 @@ RewriteResponse ArithRewriter::preRewriteMult(TNode node) Assert(node.getKind() == kind::MULT || node.getKind() == kind::NONLINEAR_MULT); - auto res = getZeroChild(node); - if (res) + if (auto res = rewriter::getZeroChild(node); res) { return RewriteResponse(REWRITE_DONE, *res); } @@ -638,16 +558,27 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){ Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT); Assert(t.getNumChildren() >= 2); - if (auto res = getZeroChild(t); res) + std::vector children; + expr::algorithm::flatten(t, children, Kind::MULT, Kind::NONLINEAR_MULT); + + if (auto res = rewriter::getZeroChild(children); res) { return RewriteResponse(REWRITE_DONE, *res); } - Rational rational = Rational(1); + // Distribute over addition + if (std::any_of(children.begin(), children.end(), [](TNode child) { + return child.getKind() == Kind::ADD; + })) + { + return RewriteResponse(REWRITE_DONE, + rewriter::distributeMultiplication(children)); + } + RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1)); - Polynomial poly = Polynomial::mkOne(); + std::vector leafs; - for (const auto& child : t) + for (const auto& child : children) { if (child.isConst()) { @@ -655,36 +586,20 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){ { return RewriteResponse(REWRITE_DONE, child); } - rational *= child.getConst(); + ran *= child.getConst(); } - else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + else if (rewriter::isRAN(child)) { - ran *= child.getOperator().getConst(); + ran *= rewriter::getRAN(child); } else { - poly = poly * Polynomial::parsePolynomial(child); + leafs.emplace_back(child); } } - if (!rational.isOne()) - { - poly = poly * rational; - } - if (isOne(ran)) - { - return RewriteResponse(REWRITE_DONE, poly.getNode()); - } - auto* nm = NodeManager::currentNM(); - if (poly.isConstant()) - { - ran *= RealAlgebraicNumber(poly.getHead().getConstant().getValue()); - return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran)); - } - return RewriteResponse( - REWRITE_DONE, - nm->mkNode( - Kind::MULT, nm->mkRealAlgebraicNumber(ran), poly.getNode())); + return RewriteResponse(REWRITE_DONE, + rewriter::mkMultTerm(ran, std::move(leafs))); } RewriteResponse ArithRewriter::postRewritePow2(TNode t) diff --git a/src/theory/arith/rewriter/addition.h b/src/theory/arith/rewriter/addition.h index 8ab5cb2a1..fd25a79d3 100644 --- a/src/theory/arith/rewriter/addition.h +++ b/src/theory/arith/rewriter/addition.h @@ -18,8 +18,8 @@ #ifndef CVC5__THEORY__ARITH__REWRITER__ADDITION_H #define CVC5__THEORY__ARITH__REWRITER__ADDITION_H -#include #include +#include #include "expr/node.h" #include "theory/arith/rewriter/ordering.h" diff --git a/src/theory/arith/rewriter/ordering.h b/src/theory/arith/rewriter/ordering.h index 529bd14ca..4b7d6c6fa 100644 --- a/src/theory/arith/rewriter/ordering.h +++ b/src/theory/arith/rewriter/ordering.h @@ -82,8 +82,10 @@ struct TermComparator { if (a == b) return false; - bool aIsMult = a.getKind() == Kind::MULT || a.getKind() == Kind::NONLINEAR_MULT; - bool bIsMult = b.getKind() == Kind::MULT || b.getKind() == Kind::NONLINEAR_MULT; + bool aIsMult = + a.getKind() == Kind::MULT || a.getKind() == Kind::NONLINEAR_MULT; + bool bIsMult = + b.getKind() == Kind::MULT || b.getKind() == Kind::NONLINEAR_MULT; if (aIsMult != bIsMult) return !aIsMult; if (!aIsMult) -- 2.30.2