From c93eccb4706804ffc139548f9d2f2145fda820a9 Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Tue, 8 Feb 2022 13:26:21 -0800 Subject: [PATCH] Add addition utilities for the arithmetic rewriter (#8013) This PR adds methods to deal with sums in the arithmetic rewriter. Sums are stored as std::map (see the code for some reasoning), and we implement adding terms to sums, collecting sums into nodes and distributing a multiplication over nested sums. --- src/CMakeLists.txt | 2 + src/theory/arith/rewriter/addition.cpp | 265 +++++++++++++++++++++++++ src/theory/arith/rewriter/addition.h | 91 +++++++++ 3 files changed, 358 insertions(+) create mode 100644 src/theory/arith/rewriter/addition.cpp create mode 100644 src/theory/arith/rewriter/addition.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 87ba2bb94..dba712aed 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -498,6 +498,8 @@ libcvc5_add_sources( theory/arith/pp_rewrite_eq.h theory/arith/proof_checker.cpp theory/arith/proof_checker.h + theory/arith/rewriter/addition.cpp + theory/arith/rewriter/addition.h theory/arith/rewriter/node_utils.cpp theory/arith/rewriter/node_utils.h theory/arith/rewriter/ordering.h diff --git a/src/theory/arith/rewriter/addition.cpp b/src/theory/arith/rewriter/addition.cpp new file mode 100644 index 000000000..c9758fa1c --- /dev/null +++ b/src/theory/arith/rewriter/addition.cpp @@ -0,0 +1,265 @@ +/****************************************************************************** + * Top contributors (to current version): + * Gereon Kremer + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Addition utilities for the arithmetic rewriter. + */ + +#include "theory/arith/rewriter/addition.h" + +#include + +#include "base/check.h" +#include "expr/node.h" +#include "theory/arith/rewriter/node_utils.h" +#include "theory/arith/rewriter/ordering.h" +#include "util/real_algebraic_number.h" + +namespace cvc5 { +namespace theory { +namespace arith { +namespace rewriter { + +std::ostream& operator<<(std::ostream& os, const Sum& sum) +{ + for (auto it = sum.begin(); it != sum.end(); ++it) + { + if (it != sum.begin()) os << " + "; + if (it->first.isConst()) + { + Assert(it->first.getConst().isOne()); + os << it->second; + continue; + } + os << it->second << "*" << it->first; + } + return os; +} + +namespace +{ + +/** + * Adds a factor n to a product, consisting of the numerical multiplicity and + * the remaining (non-numerical) factors. If n is a product itself, its children + * are merged into the product. If n is a constant or a real algebraic number, + * it is multiplied to the multiplicity. Otherwise, n is added to product. + * + * Invariant: + * multiplicity' * multiply(product') = n * multiplicity * multiply(product) + */ +void addToProduct(std::vector& product, + RealAlgebraicNumber& multiplicity, + TNode n) +{ + switch (n.getKind()) + { + case Kind::MULT: + case Kind::NONLINEAR_MULT: + for (const auto& child : n) + { + // make sure constants are properly extracted. + // recursion is safe, as mult is already flattened + addToProduct(product, multiplicity, child); + } + break; + case Kind::REAL_ALGEBRAIC_NUMBER: multiplicity *= getRAN(n); break; + default: + if (n.isConst()) + { + multiplicity *= n.getConst(); + } + else + { + product.emplace_back(n); + } + } +} + +/** + * Add a new summand, consisting of the product and the multiplicity, to a sum. + * Either adds the summand as a new entry to the sum, or adds the multiplicity + * to an already existing summand. Removes the entry, if the multiplicity is + * zero afterwards. + * + * Invariant: + * add(s.n * s.ran for s in sum') + * = add(s.n * s.ran for s in sum) + multiplicity * product + */ +void addToSum(Sum& sum, TNode product, const RealAlgebraicNumber& multiplicity) +{ + if (isZero(multiplicity)) return; + auto it = sum.find(product); + if (it == sum.end()) + { + sum.emplace(product, multiplicity); + } + else + { + it->second += multiplicity; + if (isZero(it->second)) + { + sum.erase(it); + } + } +} + +/** + * Evaluates `basemultiplicity * baseproduct * sum` into a single node (of kind + * `ADD`, unless the sum has less than two summands). + */ +Node collectSumWithBase(const Sum& sum, + const RealAlgebraicNumber& basemultiplicity, + const std::vector& baseproduct) +{ + if (sum.empty()) return mkConst(Rational(0)); + // construct the sum as nodes. + NodeBuilder nb(Kind::ADD); + for (const auto& summand : sum) + { + Assert(!isZero(summand.second)); + RealAlgebraicNumber mult = summand.second * basemultiplicity; + std::vector product = baseproduct; + rewriter::addToProduct(product, mult, summand.first); + nb << mkMultTerm(mult, std::move(product)); + } + if (nb.getNumChildren() == 1) + { + return nb[0]; + } + return nb.constructNode(); +} +} + +void addToSum(Sum& sum, TNode n, bool negate) +{ + if (n.getKind() == Kind::ADD) + { + for (const auto& child : n) + { + addToSum(sum, child, negate); + } + return; + } + std::vector monomial; + RealAlgebraicNumber multiplicity(Integer(1)); + if (negate) + { + multiplicity = Integer(-1); + } + addToProduct(monomial, multiplicity, n); + addToSum(sum, mkNonlinearMult(monomial), multiplicity); +} + +Node collectSum(const Sum& sum) +{ + if (sum.empty()) return mkConst(Rational(0)); + // construct the sum as nodes. + NodeBuilder nb(Kind::ADD); + for (const auto& s : sum) + { + nb << mkMultTerm(s.second, s.first); + } + if (nb.getNumChildren() == 1) + { + return nb[0]; + } + return nb.constructNode(); +} + +Node distributeMultiplication(const std::vector& factors) +{ + if (Trace.isOn("arith-rewriter-distribute")) + { + Trace("arith-rewriter-distribute") << "Distributing" << std::endl; + for (const auto& f : factors) + { + Trace("arith-rewriter-distribute") << "\t" << f << std::endl; + } + } + // factors that are not sums, separated into numerical and non-numerical + RealAlgebraicNumber basemultiplicity(Integer(1)); + std::vector base; + // maps products to their (possibly real algebraic) multiplicities. + // The current (intermediate) value is the sum of these (multiplied by the + // base factors). + Sum sum; + // Add a base summand + sum.emplace(mkConst(Rational(1)), RealAlgebraicNumber(Integer(1))); + + // multiply factors one by one to basmultiplicity * base * sum + for (const auto& factor : factors) + { + // Subtractions are rewritten already, we only need to care about additions + Assert(factor.getKind() != Kind::SUB); + Assert(factor.getKind() != Kind::NEG + || (factor[0].isConst() || isRAN(factor[0]))); + if (factor.getKind() != Kind::ADD) + { + Assert(!(factor.isConst() && factor.getConst().isZero())); + addToProduct(base, basemultiplicity, factor); + continue; + } + // temporary to store factor * sum, will be moved to sum at the end + Sum newsum; + + for (const auto& summand : sum) + { + for (const auto& child : factor) + { + // add summand * child to newsum + RealAlgebraicNumber multiplicity = summand.second; + if (child.isConst()) + { + multiplicity *= child.getConst(); + addToSum(newsum, summand.first, multiplicity); + continue; + } + if (isRAN(child)) + { + multiplicity *= getRAN(child); + addToSum(newsum, summand.first, multiplicity); + continue; + } + + // construct the new product + std::vector newProduct; + addToProduct(newProduct, multiplicity, summand.first); + addToProduct(newProduct, multiplicity, child); + std::sort(newProduct.begin(), newProduct.end(), LeafNodeComparator()); + addToSum(newsum, mkNonlinearMult(newProduct), multiplicity); + } + } + if (Trace.isOn("arith-rewriter-distribute")) + { + Trace("arith-rewriter-distribute") + << "multiplied with " << factor << std::endl; + Trace("arith-rewriter-distribute") + << "base: " << basemultiplicity << " * " << base << std::endl; + Trace("arith-rewriter-distribute") << "sum:" << std::endl; + for (const auto& summand : newsum) + { + Trace("arith-rewriter-distribute") + << "\t" << summand.second << " * " << summand.first << std::endl; + } + } + + sum = std::move(newsum); + } + // now mult(factors) == base * add(sum) + + return collectSumWithBase(sum, basemultiplicity, base); +} + +} // namespace rewriter +} // namespace arith +} // namespace theory +} // namespace cvc5 \ No newline at end of file diff --git a/src/theory/arith/rewriter/addition.h b/src/theory/arith/rewriter/addition.h new file mode 100644 index 000000000..8ab5cb2a1 --- /dev/null +++ b/src/theory/arith/rewriter/addition.h @@ -0,0 +1,91 @@ +/****************************************************************************** + * Top contributors (to current version): + * Gereon Kremer + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Addition utilities for the arithmetic rewriter. + */ + +#include "cvc5_private.h" + +#ifndef CVC5__THEORY__ARITH__REWRITER__ADDITION_H +#define CVC5__THEORY__ARITH__REWRITER__ADDITION_H + +#include +#include + +#include "expr/node.h" +#include "theory/arith/rewriter/ordering.h" +#include "util/real_algebraic_number.h" + +namespace cvc5 { +namespace theory { +namespace arith { +namespace rewriter { + +/** + * Intermediate representation for a sum of terms, mapping monomials to their + * multiplicities. A sum implicitly represents the expression + * SUM(s.second * s.first for s in sum) + * Using a map allows to easily check whether a monomial is already present and + * then merge two terms (i.e. add their multiplicities). We use a std::map with + * a proper comparator (instead of std::unordered_map) to allow easy + * identification of the leading term. As we need to sort the terms anyway when + * constructing a node, a std::unordered_map may only be faster if we experience + * a lot of nullification (and thus paying the logarithmic overhead when working + * with the map, but not having it when sorting in the end). Usually, though, + * this saves us additional memory allocations for sorting the terms as it is + * done in-place instead of copying the result out of the std::unordered_map + * into a sortable container. + */ +using Sum = std::map; + +/** + * Print a sum. Does not use a particularly useful syntax and is thus only meant + * for debugging. + */ +std::ostream& operator<<(std::ostream& os, const Sum& sum); + +/** + * Add the arithmetic term `n` to the given sum. If negate is true, actually add + * `-n`. If `n` is itself a sum, it automatically flattens it into `sum` (though + * it should not be a deeply nested sum, as it simply recurses). Otherwise, `n` + * is treated as a single summand, that is a (possibly unary) product. + * It does not consider sums within the product. + */ +void addToSum(Sum& sum, TNode n, bool negate = false); + +/** + * Evaluates the sum object (mapping monomials to their multiplicities) into a + * single node (of kind `ADD`, unless the sum has less than two summands). + */ +Node collectSum(const Sum& sum); + +/** + * Distribute a multiplication over one or more additions. The multiplication + * is given as the list of its factors. Though this method also works if none + * of these factors is an addition, there is no point of calling this method + * in this case. The result is the resulting sum after expanding the product + * and pushing the multiplication inside the addition. + * + * The method maintains a `sum` as a mapping from Node to RealAlgebraicNumber. + * The nodes can be understood as monomials, or generally non-value parts of + * the product, while the real algebraic numbers are the multiplicities of these + * monomials or products. This allows to combine summands with identical + * monomials immediately and avoid a potential blow-up. + */ +Node distributeMultiplication(const std::vector& factors); + +} // namespace rewriter +} // namespace arith +} // namespace theory +} // namespace cvc5 + +#endif -- 2.30.2