Refactor how arith rewriting checks for mult-by-zero (#7962)
authorGereon Kremer <gkremer@stanford.edu>
Mon, 24 Jan 2022 19:23:41 +0000 (11:23 -0800)
committerGitHub <noreply@github.com>
Mon, 24 Jan 2022 19:23:41 +0000 (19:23 +0000)
This adds an explicit check for a zero factor in the post rewriter for multiplication.

src/theory/arith/arith_rewriter.cpp

index 4eba5dbfc006b92dc532a7a26ec960ae66323539..9f1fa40bb8280ee841c9b35cd7d659a52b25d565 100644 (file)
@@ -18,6 +18,7 @@
 
 #include "theory/arith/arith_rewriter.h"
 
+#include <optional>
 #include <set>
 #include <sstream>
 #include <stack>
@@ -56,6 +57,23 @@ bool evaluateRelation(Kind rel, const L& l, const R& r)
   }
 }
 
+/**
+ * Check whether the parent has a child that is a constant zero.
+ * If so, return this child. Otherwise, return std::nullopt.
+ */
+template <typename Iterable>
+std::optional<TNode> getZeroChild(const Iterable& parent)
+{
+  for (const auto& node : parent)
+  {
+    if (node.isConst() && node.template getConst<Rational>().isZero())
+    {
+      return node;
+    }
+  }
+  return std::nullopt;
+}
+
 }  // namespace
 
 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
@@ -437,21 +455,6 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
   }
 }
 
-RewriteResponse ArithRewriter::preRewriteMult(TNode node)
-{
-  Assert(node.getKind() == kind::MULT
-         || node.getKind() == kind::NONLINEAR_MULT);
-
-  for (const auto& child : node)
-  {
-    if (child.isConst() && child.getConst<Rational>().isZero())
-    {
-      return RewriteResponse(REWRITE_DONE, child);
-    }
-  }
-  return RewriteResponse(REWRITE_DONE, node);
-}
-
 static bool canFlatten(Kind k, TNode t){
   for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
     TNode child = *i;
@@ -563,10 +566,28 @@ RewriteResponse ArithRewriter::postRewritePlus(TNode t){
       nm->mkNode(Kind::PLUS, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
 }
 
+RewriteResponse ArithRewriter::preRewriteMult(TNode node)
+{
+  Assert(node.getKind() == kind::MULT
+         || node.getKind() == kind::NONLINEAR_MULT);
+
+  auto res = getZeroChild(node);
+  if (res)
+  {
+    return RewriteResponse(REWRITE_DONE, *res);
+  }
+  return RewriteResponse(REWRITE_DONE, node);
+}
+
 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
   Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
   Assert(t.getNumChildren() >= 2);
 
+  if (auto res = getZeroChild(t); res)
+  {
+    return RewriteResponse(REWRITE_DONE, *res);
+  }
+
   Rational rational = Rational(1);
   RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
   Polynomial poly = Polynomial::mkOne();