#include "theory/arith/arith_rewriter.h"
+#include <optional>
#include <set>
#include <sstream>
#include <stack>
}
}
+/**
+ * Check whether the parent has a child that is a constant zero.
+ * If so, return this child. Otherwise, return std::nullopt.
+ */
+template <typename Iterable>
+std::optional<TNode> getZeroChild(const Iterable& parent)
+{
+ for (const auto& node : parent)
+ {
+ if (node.isConst() && node.template getConst<Rational>().isZero())
+ {
+ return node;
+ }
+ }
+ return std::nullopt;
+}
+
} // namespace
ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
}
}
-RewriteResponse ArithRewriter::preRewriteMult(TNode node)
-{
- Assert(node.getKind() == kind::MULT
- || node.getKind() == kind::NONLINEAR_MULT);
-
- for (const auto& child : node)
- {
- if (child.isConst() && child.getConst<Rational>().isZero())
- {
- return RewriteResponse(REWRITE_DONE, child);
- }
- }
- return RewriteResponse(REWRITE_DONE, node);
-}
-
static bool canFlatten(Kind k, TNode t){
for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
TNode child = *i;
nm->mkNode(Kind::PLUS, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
}
+RewriteResponse ArithRewriter::preRewriteMult(TNode node)
+{
+ Assert(node.getKind() == kind::MULT
+ || node.getKind() == kind::NONLINEAR_MULT);
+
+ auto res = getZeroChild(node);
+ if (res)
+ {
+ return RewriteResponse(REWRITE_DONE, *res);
+ }
+ return RewriteResponse(REWRITE_DONE, node);
+}
+
RewriteResponse ArithRewriter::postRewriteMult(TNode t){
Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
Assert(t.getNumChildren() >= 2);
+ if (auto res = getZeroChild(t); res)
+ {
+ return RewriteResponse(REWRITE_DONE, *res);
+ }
+
Rational rational = Rational(1);
RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
Polynomial poly = Polynomial::mkOne();