From 549329cd3803b1ebe6e59036e1d69fb21474ca2d Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Fri, 14 Jan 2022 07:57:24 -0800 Subject: [PATCH] Add support for RANs in rewriter for `MULT` (#7940) This PR refactors the post rewriter for multiplication. It now supports real algebraic numbers and tries a bit harder to avoid the overhead of the normal form abstraction. --- src/theory/arith/arith_rewriter.cpp | 47 ++++++++++++++++--- .../theory/theory_arith_rewriter_black.cpp | 8 ++++ 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 1c9555820..bef9f9a3e 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -606,17 +606,50 @@ RewriteResponse ArithRewriter::postRewritePlus(TNode t){ RewriteResponse ArithRewriter::postRewriteMult(TNode t){ Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT); + Assert(t.getNumChildren() >= 2); - Polynomial res = Polynomial::mkOne(); + Rational rational = Rational(1); + RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1)); + Polynomial poly = Polynomial::mkOne(); - for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){ - Node curr = *i; - Polynomial currPoly = Polynomial::parsePolynomial(curr); - - res = res * currPoly; + for (const auto& child : t) + { + if (child.isConst()) + { + if (child.getConst().isZero()) + { + return RewriteResponse(REWRITE_DONE, child); + } + rational *= child.getConst(); + } + else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + { + ran *= child.getOperator().getConst(); + } + else + { + poly = poly * Polynomial::parsePolynomial(child); + } } - return RewriteResponse(REWRITE_DONE, res.getNode()); + if (!rational.isOne()) + { + poly = poly * rational; + } + if (isOne(ran)) + { + return RewriteResponse(REWRITE_DONE, poly.getNode()); + } + auto* nm = NodeManager::currentNM(); + if (poly.isConstant()) + { + ran *= RealAlgebraicNumber(poly.getHead().getConstant().getValue()); + return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran)); + } + return RewriteResponse( + REWRITE_DONE, + nm->mkNode( + Kind::MULT, nm->mkRealAlgebraicNumber(ran), poly.getNode())); } RewriteResponse ArithRewriter::postRewritePow2(TNode t) diff --git a/test/unit/theory/theory_arith_rewriter_black.cpp b/test/unit/theory/theory_arith_rewriter_black.cpp index a836cf8bd..0147b591f 100644 --- a/test/unit/theory/theory_arith_rewriter_black.cpp +++ b/test/unit/theory/theory_arith_rewriter_black.cpp @@ -47,6 +47,14 @@ TEST_F(TestTheoryArithRewriterBlack, RealAlgebraicNumber) EXPECT_EQ(n.getKind(), Kind::REAL_ALGEBRAIC_NUMBER); EXPECT_EQ(n.getOperator().getConst(), twosqrt2); } + { + RealAlgebraicNumber sqrt2({-2, 0, 1}, 1, 3); + Node n = d_nodeManager->mkRealAlgebraicNumber(sqrt2); + n = d_nodeManager->mkNode(Kind::MULT, n, n); + n = d_slvEngine->getRewriter()->rewrite(n); + EXPECT_EQ(n.getKind(), Kind::CONST_RATIONAL); + EXPECT_EQ(n.getConst(), Rational(2)); + } { RealAlgebraicNumber sqrt2({-2, 0, 1}, 1, 3); Node n = d_nodeManager->mkRealAlgebraicNumber(sqrt2); -- 2.30.2