Add arithmetic rewriter for RAN (#7929)
authorGereon Kremer <gkremer@stanford.edu>
Thu, 13 Jan 2022 21:53:16 +0000 (13:53 -0800)
committerGitHub <noreply@github.com>
Thu, 13 Jan 2022 21:53:16 +0000 (21:53 +0000)
This adds a rewriter for real arithmetic constants, possibly rewriting them to a rational.

src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h
test/unit/theory/theory_arith_rewriter_black.cpp

index 3a355839de29966f1a6c14553c5342d234175472..9ccb6407f3c8bba1d467f42f92d58cb29fa7a3ae 100644 (file)
@@ -231,6 +231,21 @@ RewriteResponse ArithRewriter::rewriteConstant(TNode t){
   return RewriteResponse(REWRITE_DONE, t);
 }
 
+RewriteResponse ArithRewriter::rewriteRAN(TNode t)
+{
+  Assert(t.getKind() == REAL_ALGEBRAIC_NUMBER);
+
+  const RealAlgebraicNumber& r =
+      t.getOperator().getConst<RealAlgebraicNumber>();
+  if (r.isRational())
+  {
+    return RewriteResponse(
+        REWRITE_DONE, NodeManager::currentNM()->mkConstReal(r.toRational()));
+  }
+
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
   Assert(t.isVar());
 
@@ -279,6 +294,7 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
     return rewriteVariable(t);
   }else{
     switch(Kind k = t.getKind()){
+      case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
       case kind::MINUS: return rewriteMinus(t);
       case kind::UMINUS: return rewriteUMinus(t, true);
       case kind::DIVISION:
@@ -341,6 +357,7 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
   }else{
     Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl;
     switch(t.getKind()){
+      case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
       case kind::MINUS: return rewriteMinus(t);
       case kind::UMINUS: return rewriteUMinus(t, false);
       case kind::DIVISION:
@@ -564,7 +581,10 @@ RewriteResponse ArithRewriter::postRewritePlus(TNode t){
   if (poly.containsConstant())
   {
     ran += RealAlgebraicNumber(poly.getHead().getConstant().getValue());
-    poly = poly.getTail();
+    if (!poly.isConstant())
+    {
+      poly = poly.getTail();
+    }
   }
 
   auto* nm = NodeManager::currentNM();
index ad9e171451aee732b91e73a0f5cd6c7783145182..2e89432f8f40be68c7a4977f307a3d2921a4538b 100644 (file)
@@ -52,8 +52,10 @@ class ArithRewriter : public TheoryRewriter
   static RewriteResponse preRewriteTerm(TNode t);
   static RewriteResponse postRewriteTerm(TNode t);
 
-  static RewriteResponse rewriteVariable(TNode t);
   static RewriteResponse rewriteConstant(TNode t);
+  static RewriteResponse rewriteRAN(TNode t);
+  static RewriteResponse rewriteVariable(TNode t);
+
   static RewriteResponse rewriteMinus(TNode t);
   static RewriteResponse rewriteUMinus(TNode t, bool pre);
   static RewriteResponse rewriteDiv(TNode t, bool pre);
index 99766e9941c56ddc8b02d3b5c8690a6e22981d47..ab3895cbbd054bbae711aef186757948d12507e3 100644 (file)
@@ -38,6 +38,15 @@ TEST_F(TestTheoryArithRewriterBlack, RealAlgebraicNumber)
     EXPECT_EQ(n.getKind(), Kind::CONST_RATIONAL);
     EXPECT_EQ(n.getConst<Rational>(), Rational(2));
   }
+  {
+    RealAlgebraicNumber twosqrt2({-8, 0, 1}, 2, 3);
+    RealAlgebraicNumber sqrt2({-2, 0, 1}, 1, 3);
+    Node n = d_nodeManager->mkRealAlgebraicNumber(sqrt2);
+    n = d_nodeManager->mkNode(Kind::PLUS, n, n);
+    n = d_slvEngine->getRewriter()->rewrite(n);
+    EXPECT_EQ(n.getKind(), Kind::REAL_ALGEBRAIC_NUMBER);
+    EXPECT_EQ(n.getOperator().getConst<RealAlgebraicNumber>(), twosqrt2);
+  }
   {
     RealAlgebraicNumber sqrt2({-2, 0, 1}, 1, 3);
     Node n = d_nodeManager->mkRealAlgebraicNumber(sqrt2);