From: Gereon Kremer Date: Mon, 24 Jan 2022 19:23:41 +0000 (-0800) Subject: Refactor how arith rewriting checks for mult-by-zero (#7962) X-Git-Tag: cvc5-1.0.0~515 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=dfc685b86237a56b1a54801bafa343b5139aaac0;p=cvc5.git Refactor how arith rewriting checks for mult-by-zero (#7962) This adds an explicit check for a zero factor in the post rewriter for multiplication. --- diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 4eba5dbfc..9f1fa40bb 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -18,6 +18,7 @@ #include "theory/arith/arith_rewriter.h" +#include #include #include #include @@ -56,6 +57,23 @@ bool evaluateRelation(Kind rel, const L& l, const R& r) } } +/** + * Check whether the parent has a child that is a constant zero. + * If so, return this child. Otherwise, return std::nullopt. + */ +template +std::optional getZeroChild(const Iterable& parent) +{ + for (const auto& node : parent) + { + if (node.isConst() && node.template getConst().isZero()) + { + return node; + } + } + return std::nullopt; +} + } // namespace ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {} @@ -437,21 +455,6 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ } } -RewriteResponse ArithRewriter::preRewriteMult(TNode node) -{ - Assert(node.getKind() == kind::MULT - || node.getKind() == kind::NONLINEAR_MULT); - - for (const auto& child : node) - { - if (child.isConst() && child.getConst().isZero()) - { - return RewriteResponse(REWRITE_DONE, child); - } - } - return RewriteResponse(REWRITE_DONE, node); -} - static bool canFlatten(Kind k, TNode t){ for(TNode::iterator i = t.begin(); i != t.end(); ++i) { TNode child = *i; @@ -563,10 +566,28 @@ RewriteResponse ArithRewriter::postRewritePlus(TNode t){ nm->mkNode(Kind::PLUS, nm->mkRealAlgebraicNumber(ran), poly.getNode())); } +RewriteResponse ArithRewriter::preRewriteMult(TNode node) +{ + Assert(node.getKind() == kind::MULT + || node.getKind() == kind::NONLINEAR_MULT); + + auto res = getZeroChild(node); + if (res) + { + return RewriteResponse(REWRITE_DONE, *res); + } + return RewriteResponse(REWRITE_DONE, node); +} + RewriteResponse ArithRewriter::postRewriteMult(TNode t){ Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT); Assert(t.getNumChildren() >= 2); + if (auto res = getZeroChild(t); res) + { + return RewriteResponse(REWRITE_DONE, *res); + } + Rational rational = Rational(1); RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1)); Polynomial poly = Polynomial::mkOne();