From 477d211e3d0e15d434aed932cdd5e636a455f8ee Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Thu, 13 Jan 2022 13:53:16 -0800 Subject: [PATCH] Add arithmetic rewriter for RAN (#7929) This adds a rewriter for real arithmetic constants, possibly rewriting them to a rational. --- src/theory/arith/arith_rewriter.cpp | 22 ++++++++++++++++++- src/theory/arith/arith_rewriter.h | 4 +++- .../theory/theory_arith_rewriter_black.cpp | 9 ++++++++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 3a355839d..9ccb6407f 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -231,6 +231,21 @@ RewriteResponse ArithRewriter::rewriteConstant(TNode t){ return RewriteResponse(REWRITE_DONE, t); } +RewriteResponse ArithRewriter::rewriteRAN(TNode t) +{ + Assert(t.getKind() == REAL_ALGEBRAIC_NUMBER); + + const RealAlgebraicNumber& r = + t.getOperator().getConst(); + if (r.isRational()) + { + return RewriteResponse( + REWRITE_DONE, NodeManager::currentNM()->mkConstReal(r.toRational())); + } + + return RewriteResponse(REWRITE_DONE, t); +} + RewriteResponse ArithRewriter::rewriteVariable(TNode t){ Assert(t.isVar()); @@ -279,6 +294,7 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ return rewriteVariable(t); }else{ switch(Kind k = t.getKind()){ + case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t); case kind::MINUS: return rewriteMinus(t); case kind::UMINUS: return rewriteUMinus(t, true); case kind::DIVISION: @@ -341,6 +357,7 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ }else{ Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl; switch(t.getKind()){ + case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t); case kind::MINUS: return rewriteMinus(t); case kind::UMINUS: return rewriteUMinus(t, false); case kind::DIVISION: @@ -564,7 +581,10 @@ RewriteResponse ArithRewriter::postRewritePlus(TNode t){ if (poly.containsConstant()) { ran += RealAlgebraicNumber(poly.getHead().getConstant().getValue()); - poly = poly.getTail(); + if (!poly.isConstant()) + { + poly = poly.getTail(); + } } auto* nm = NodeManager::currentNM(); diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index ad9e17145..2e89432f8 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -52,8 +52,10 @@ class ArithRewriter : public TheoryRewriter static RewriteResponse preRewriteTerm(TNode t); static RewriteResponse postRewriteTerm(TNode t); - static RewriteResponse rewriteVariable(TNode t); static RewriteResponse rewriteConstant(TNode t); + static RewriteResponse rewriteRAN(TNode t); + static RewriteResponse rewriteVariable(TNode t); + static RewriteResponse rewriteMinus(TNode t); static RewriteResponse rewriteUMinus(TNode t, bool pre); static RewriteResponse rewriteDiv(TNode t, bool pre); diff --git a/test/unit/theory/theory_arith_rewriter_black.cpp b/test/unit/theory/theory_arith_rewriter_black.cpp index 99766e994..ab3895cbb 100644 --- a/test/unit/theory/theory_arith_rewriter_black.cpp +++ b/test/unit/theory/theory_arith_rewriter_black.cpp @@ -38,6 +38,15 @@ TEST_F(TestTheoryArithRewriterBlack, RealAlgebraicNumber) EXPECT_EQ(n.getKind(), Kind::CONST_RATIONAL); EXPECT_EQ(n.getConst(), Rational(2)); } + { + RealAlgebraicNumber twosqrt2({-8, 0, 1}, 2, 3); + RealAlgebraicNumber sqrt2({-2, 0, 1}, 1, 3); + Node n = d_nodeManager->mkRealAlgebraicNumber(sqrt2); + n = d_nodeManager->mkNode(Kind::PLUS, n, n); + n = d_slvEngine->getRewriter()->rewrite(n); + EXPECT_EQ(n.getKind(), Kind::REAL_ALGEBRAIC_NUMBER); + EXPECT_EQ(n.getOperator().getConst(), twosqrt2); + } { RealAlgebraicNumber sqrt2({-2, 0, 1}, 1, 3); Node n = d_nodeManager->mkRealAlgebraicNumber(sqrt2); -- 2.30.2