Add support for RANs in rewriter for `MULT` (#7940)
authorGereon Kremer <gkremer@stanford.edu>
Fri, 14 Jan 2022 15:57:24 +0000 (07:57 -0800)
committerGitHub <noreply@github.com>
Fri, 14 Jan 2022 15:57:24 +0000 (15:57 +0000)
This PR refactors the post rewriter for multiplication. It now supports real algebraic numbers and tries a bit harder to avoid the overhead of the normal form abstraction.

src/theory/arith/arith_rewriter.cpp
test/unit/theory/theory_arith_rewriter_black.cpp

index 1c955582099e356082b5aec092cf5a29e1aa6a4f..bef9f9a3ee4ce9a1c5fb71b9b11b458f23b58e40 100644 (file)
@@ -606,17 +606,50 @@ RewriteResponse ArithRewriter::postRewritePlus(TNode t){
 
 RewriteResponse ArithRewriter::postRewriteMult(TNode t){
   Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
+  Assert(t.getNumChildren() >= 2);
 
-  Polynomial res = Polynomial::mkOne();
+  Rational rational = Rational(1);
+  RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
+  Polynomial poly = Polynomial::mkOne();
 
-  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
-    Node curr = *i;
-    Polynomial currPoly = Polynomial::parsePolynomial(curr);
-
-    res = res * currPoly;
+  for (const auto& child : t)
+  {
+    if (child.isConst())
+    {
+      if (child.getConst<Rational>().isZero())
+      {
+        return RewriteResponse(REWRITE_DONE, child);
+      }
+      rational *= child.getConst<Rational>();
+    }
+    else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    {
+      ran *= child.getOperator().getConst<RealAlgebraicNumber>();
+    }
+    else
+    {
+      poly = poly * Polynomial::parsePolynomial(child);
+    }
   }
 
-  return RewriteResponse(REWRITE_DONE, res.getNode());
+  if (!rational.isOne())
+  {
+    poly = poly * rational;
+  }
+  if (isOne(ran))
+  {
+    return RewriteResponse(REWRITE_DONE, poly.getNode());
+  }
+  auto* nm = NodeManager::currentNM();
+  if (poly.isConstant())
+  {
+    ran *= RealAlgebraicNumber(poly.getHead().getConstant().getValue());
+    return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
+  }
+  return RewriteResponse(
+      REWRITE_DONE,
+      nm->mkNode(
+          Kind::MULT, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
 }
 
 RewriteResponse ArithRewriter::postRewritePow2(TNode t)
index a836cf8bda9fedab11d3edaf826bd91fee5d902a..0147b591f60b49e9753347ee1593e403e574ee73 100644 (file)
@@ -47,6 +47,14 @@ TEST_F(TestTheoryArithRewriterBlack, RealAlgebraicNumber)
     EXPECT_EQ(n.getKind(), Kind::REAL_ALGEBRAIC_NUMBER);
     EXPECT_EQ(n.getOperator().getConst<RealAlgebraicNumber>(), twosqrt2);
   }
+  {
+    RealAlgebraicNumber sqrt2({-2, 0, 1}, 1, 3);
+    Node n = d_nodeManager->mkRealAlgebraicNumber(sqrt2);
+    n = d_nodeManager->mkNode(Kind::MULT, n, n);
+    n = d_slvEngine->getRewriter()->rewrite(n);
+    EXPECT_EQ(n.getKind(), Kind::CONST_RATIONAL);
+    EXPECT_EQ(n.getConst<Rational>(), Rational(2));
+  }
   {
     RealAlgebraicNumber sqrt2({-2, 0, 1}, 1, 3);
     Node n = d_nodeManager->mkRealAlgebraicNumber(sqrt2);