From f68764d01a070be71dda42e5e26483b2cfb1281a Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Thu, 20 Jan 2022 10:29:25 -0800 Subject: [PATCH] Refactor abs rewriting (#7935) This PR refactors rewriting for ABS to also support real algebraic number. We also generalize the ABS operator to real arguments in general, instead of integer arguments. --- src/theory/arith/arith_rewriter.cpp | 64 +++++++++---------- src/theory/arith/arith_rewriter.h | 1 + src/theory/arith/kinds | 2 +- .../theory/theory_arith_rewriter_black.cpp | 36 +++++++++++ 4 files changed, 70 insertions(+), 33 deletions(-) diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index c698490b0..4eba5dbfc 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -331,22 +331,7 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true); case kind::INTS_DIVISION_TOTAL: case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true); - case kind::ABS: - if (t[0].isConst()) - { - const Rational& rat = t[0].getConst(); - if (rat >= 0) - { - return RewriteResponse(REWRITE_DONE, t[0]); - } - else - { - return RewriteResponse(REWRITE_DONE, - NodeManager::currentNM()->mkConstRealOrInt( - t[0].getType(), -rat)); - } - } - return RewriteResponse(REWRITE_DONE, t); + case kind::ABS: return rewriteAbs(t); case kind::IS_INTEGER: case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t); case kind::TO_REAL: @@ -394,22 +379,7 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false); case kind::INTS_DIVISION_TOTAL: case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false); - case kind::ABS: - if (t[0].isConst()) - { - const Rational& rat = t[0].getConst(); - if (rat >= 0) - { - return RewriteResponse(REWRITE_DONE, t[0]); - } - else - { - return RewriteResponse(REWRITE_DONE, - NodeManager::currentNM()->mkConstRealOrInt( - t[0].getType(), -rat)); - } - } - return RewriteResponse(REWRITE_DONE, t); + case kind::ABS: return rewriteAbs(t); case kind::TO_REAL: case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]); case kind::TO_INTEGER: return rewriteExtIntegerOp(t); @@ -997,6 +967,36 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ return RewriteResponse(REWRITE_DONE, t); } +RewriteResponse ArithRewriter::rewriteAbs(TNode t) +{ + Assert(t.getKind() == Kind::ABS); + Assert(t.getNumChildren() == 1); + + if (t[0].isConst()) + { + const Rational& rat = t[0].getConst(); + if (rat >= 0) + { + return RewriteResponse(REWRITE_DONE, t[0]); + } + return RewriteResponse( + REWRITE_DONE, + NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat)); + } + if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + { + const RealAlgebraicNumber& ran = + t[0].getOperator().getConst(); + if (ran >= RealAlgebraicNumber()) + { + return RewriteResponse(REWRITE_DONE, t[0]); + } + return RewriteResponse( + REWRITE_DONE, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran)); + } + return RewriteResponse(REWRITE_DONE, t); +} + RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre) { NodeManager* nm = NodeManager::currentNM(); diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index 2e89432f8..90140cc18 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -59,6 +59,7 @@ class ArithRewriter : public TheoryRewriter static RewriteResponse rewriteMinus(TNode t); static RewriteResponse rewriteUMinus(TNode t, bool pre); static RewriteResponse rewriteDiv(TNode t, bool pre); + static RewriteResponse rewriteAbs(TNode t); static RewriteResponse rewriteIntsDivMod(TNode t, bool pre); static RewriteResponse rewriteIntsDivModTotal(TNode t, bool pre); /** Entry for applications of to_int and is_int */ diff --git a/src/theory/arith/kinds b/src/theory/arith/kinds index 40557bc65..ba326ba56 100644 --- a/src/theory/arith/kinds +++ b/src/theory/arith/kinds @@ -145,7 +145,7 @@ typerule CAST_TO_REAL ::cvc5::theory::arith::ArithOperatorTypeRule typerule TO_INTEGER ::cvc5::theory::arith::ArithOperatorTypeRule typerule IS_INTEGER "SimpleTypeRule" -typerule ABS "SimpleTypeRule" +typerule ABS ::cvc5::theory::arith::ArithOperatorTypeRule typerule INTS_DIVISION "SimpleTypeRule" typerule INTS_MODULUS "SimpleTypeRule" typerule DIVISIBLE "SimpleTypeRule" diff --git a/test/unit/theory/theory_arith_rewriter_black.cpp b/test/unit/theory/theory_arith_rewriter_black.cpp index 0147b591f..3fcd74356 100644 --- a/test/unit/theory/theory_arith_rewriter_black.cpp +++ b/test/unit/theory/theory_arith_rewriter_black.cpp @@ -79,5 +79,41 @@ TEST_F(TestTheoryArithRewriterBlack, RealAlgebraicNumber) } } +TEST_F(TestTheoryArithRewriterBlack, Abs) +{ + { + Node a = d_nodeManager->mkConstReal(10); + Node b = d_nodeManager->mkConstReal(-10); + Node m = d_nodeManager->mkNode(Kind::ABS, a); + Node n = d_nodeManager->mkNode(Kind::ABS, b); + m = d_slvEngine->getRewriter()->rewrite(m); + n = d_slvEngine->getRewriter()->rewrite(n); + EXPECT_EQ(m, a); + EXPECT_EQ(n, a); + } + { + Node a = d_nodeManager->mkConstReal(Rational(3,2)); + Node b = d_nodeManager->mkConstReal(Rational(-3,2)); + Node m = d_nodeManager->mkNode(Kind::ABS, a); + Node n = d_nodeManager->mkNode(Kind::ABS, b); + m = d_slvEngine->getRewriter()->rewrite(m); + n = d_slvEngine->getRewriter()->rewrite(n); + EXPECT_EQ(m, a); + EXPECT_EQ(n, a); + } + { + RealAlgebraicNumber msqrt2({-2, 0, 1}, -2, -1); + RealAlgebraicNumber sqrt2({-2, 0, 1}, 1, 2); + Node a = d_nodeManager->mkRealAlgebraicNumber(msqrt2); + Node b = d_nodeManager->mkRealAlgebraicNumber(sqrt2); + Node m = d_nodeManager->mkNode(Kind::ABS, a); + Node n = d_nodeManager->mkNode(Kind::ABS, b); + m = d_slvEngine->getRewriter()->rewrite(m); + n = d_slvEngine->getRewriter()->rewrite(n); + EXPECT_EQ(m, b); + EXPECT_EQ(n, b); + } +} + } // namespace test } // namespace cvc5 -- 2.30.2