From: Gereon Kremer Date: Wed, 12 Jan 2022 22:51:29 +0000 (-0800) Subject: Refactor atom rewriting to be RAN-aware (#7928) X-Git-Tag: cvc5-1.0.0~555 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=defb0de81171b6633e5ef745f927217614a4fe54;p=cvc5.git Refactor atom rewriting to be RAN-aware (#7928) This PR starts refactoring the arithmetic rewriter by making rewriting of atoms aware of real algebraic numbers. It also is slightly more aggressive now, directly rewriting relational operators where lhs = rhs in the preRewrite stage. --- diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 2c3fcdf48..d865eebe9 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -32,6 +32,7 @@ #include "util/bitvector.h" #include "util/divisible.h" #include "util/iand.h" +#include "util/real_algebraic_number.h" using namespace cvc5::kind; @@ -39,8 +40,184 @@ namespace cvc5 { namespace theory { namespace arith { +namespace { + +template +bool evaluateRelation(Kind rel, const L& l, const R& r) +{ + switch (rel) + { + case Kind::LT: return l < r; + case Kind::LEQ: return l <= r; + case Kind::EQUAL: return l == r; + case Kind::GEQ: return l >= r; + case Kind::GT: return l > r; + default: Unreachable(); return false; + } +} + +} // namespace + ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {} +RewriteResponse ArithRewriter::preRewrite(TNode t) +{ + if (isAtom(t)) + { + return preRewriteAtom(t); + } + return preRewriteTerm(t); +} + +RewriteResponse ArithRewriter::postRewrite(TNode t) +{ + if (isAtom(t)) + { + RewriteResponse response = postRewriteAtom(t); + if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE) + { + Comparison::parseNormalForm(response.d_node); + } + return response; + } + RewriteResponse response = postRewriteTerm(t); + if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE) + { + Polynomial::parsePolynomial(response.d_node); + } + return response; +} + +RewriteResponse ArithRewriter::preRewriteAtom(TNode atom) +{ + Assert(isAtom(atom)); + + NodeManager* nm = NodeManager::currentNM(); + + if (isRelationOperator(atom.getKind()) && atom[0] == atom[1]) + { + switch (atom.getKind()) + { + case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false)); + case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); + case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); + case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); + case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false)); + default:; + } + } + + switch (atom.getKind()) + { + case Kind::GT: + return RewriteResponse(REWRITE_DONE, + nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode()); + case Kind::LT: + return RewriteResponse(REWRITE_DONE, + nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode()); + case Kind::IS_INTEGER: + if (atom[0].getType().isInteger()) + { + return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); + } + break; + case Kind::DIVISIBLE: + if (atom.getOperator().getConst().k.isOne()) + { + return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); + } + break; + default:; + } + + return RewriteResponse(REWRITE_DONE, atom); +} + +RewriteResponse ArithRewriter::postRewriteAtom(TNode atom) +{ + Assert(isAtom(atom)); + if (atom.getKind() == kind::IS_INTEGER) + { + return rewriteExtIntegerOp(atom); + } + else if (atom.getKind() == kind::DIVISIBLE) + { + if (atom[0].isConst()) + { + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConst(bool( + (atom[0].getConst() + / atom.getOperator().getConst().k) + .isIntegral()))); + } + if (atom.getOperator().getConst().k.isOne()) + { + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConst(true)); + } + NodeManager* nm = NodeManager::currentNM(); + return RewriteResponse( + REWRITE_AGAIN, + nm->mkNode(kind::EQUAL, + nm->mkNode(kind::INTS_MODULUS_TOTAL, + atom[0], + nm->mkConstInt(Rational( + atom.getOperator().getConst().k))), + nm->mkConstInt(Rational(0)))); + } + + // left |><| right + TNode left = atom[0]; + TNode right = atom[1]; + + auto* nm = NodeManager::currentNM(); + if (left.isConst()) + { + const Rational& l = left.getConst(); + if (right.isConst()) + { + const Rational& r = right.getConst(); + return RewriteResponse( + REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r))); + } + else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + { + const RealAlgebraicNumber& r = + right.getOperator().getConst(); + return RewriteResponse( + REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r))); + } + } + else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + { + const RealAlgebraicNumber& l = + left.getOperator().getConst(); + if (right.isConst()) + { + const Rational& r = right.getConst(); + return RewriteResponse( + REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r))); + } + else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + { + const RealAlgebraicNumber& r = + right.getOperator().getConst(); + return RewriteResponse( + REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r))); + } + } + + Polynomial pleft = Polynomial::parsePolynomial(left); + Polynomial pright = Polynomial::parsePolynomial(right); + + Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl; + Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl; + + Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright); + Assert(cmp.isNormalForm()); + return RewriteResponse(REWRITE_DONE, cmp.getNode()); +} + bool ArithRewriter::isAtom(TNode n) { Kind k = n.getKind(); return arith::isRelationOperator(k) || k == kind::IS_INTEGER @@ -648,100 +825,6 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { return RewriteResponse(REWRITE_DONE, t); } -RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){ - if(atom.getKind() == kind::IS_INTEGER) { - return rewriteExtIntegerOp(atom); - } else if(atom.getKind() == kind::DIVISIBLE) { - if(atom[0].isConst()) { - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst() / atom.getOperator().getConst().k).isIntegral()))); - } - if(atom.getOperator().getConst().k.isOne()) { - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); - } - NodeManager* nm = NodeManager::currentNM(); - return RewriteResponse( - REWRITE_AGAIN, - nm->mkNode(kind::EQUAL, - nm->mkNode(kind::INTS_MODULUS_TOTAL, - atom[0], - nm->mkConstInt(Rational( - atom.getOperator().getConst().k))), - nm->mkConstInt(Rational(0)))); - } - - // left |><| right - TNode left = atom[0]; - TNode right = atom[1]; - - Polynomial pleft = Polynomial::parsePolynomial(left); - Polynomial pright = Polynomial::parsePolynomial(right); - - Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl; - Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl; - - Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright); - Assert(cmp.isNormalForm()); - return RewriteResponse(REWRITE_DONE, cmp.getNode()); -} - -RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){ - Assert(isAtom(atom)); - - NodeManager* currNM = NodeManager::currentNM(); - - if(atom.getKind() == kind::EQUAL) { - if(atom[0] == atom[1]) { - return RewriteResponse(REWRITE_DONE, currNM->mkConst(true)); - } - }else if(atom.getKind() == kind::GT){ - Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]); - return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq)); - }else if(atom.getKind() == kind::LT){ - Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]); - return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq)); - }else if(atom.getKind() == kind::IS_INTEGER){ - if(atom[0].getType().isInteger()){ - return RewriteResponse(REWRITE_DONE, currNM->mkConst(true)); - } - }else if(atom.getKind() == kind::DIVISIBLE){ - if(atom.getOperator().getConst().k.isOne()){ - return RewriteResponse(REWRITE_DONE, currNM->mkConst(true)); - } - } - - return RewriteResponse(REWRITE_DONE, atom); -} - -RewriteResponse ArithRewriter::postRewrite(TNode t){ - if(isTerm(t)){ - RewriteResponse response = postRewriteTerm(t); - if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE) - { - Polynomial::parsePolynomial(response.d_node); - } - return response; - }else if(isAtom(t)){ - RewriteResponse response = postRewriteAtom(t); - if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE) - { - Comparison::parseNormalForm(response.d_node); - } - return response; - }else{ - Unreachable(); - } -} - -RewriteResponse ArithRewriter::preRewrite(TNode t){ - if(isTerm(t)){ - return preRewriteTerm(t); - }else if(isAtom(t)){ - return preRewriteAtom(t); - }else{ - Unreachable(); - } -} - Node ArithRewriter::makeUnaryMinusNode(TNode n){ NodeManager* nm = NodeManager::currentNM(); Rational qNegOne(-1); diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index 5fcb628b9..ad9e17145 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -43,6 +43,9 @@ class ArithRewriter : public TheoryRewriter TrustNode expandDefinition(Node node) override; private: + static RewriteResponse preRewriteAtom(TNode t); + static RewriteResponse postRewriteAtom(TNode t); + static Node makeSubtractionNode(TNode l, TNode r); static Node makeUnaryMinusNode(TNode n); @@ -71,9 +74,6 @@ class ArithRewriter : public TheoryRewriter static RewriteResponse preRewriteTranscendental(TNode t); static RewriteResponse postRewriteTranscendental(TNode t); - static RewriteResponse preRewriteAtom(TNode t); - static RewriteResponse postRewriteAtom(TNode t); - static bool isAtom(TNode n); static inline bool isTerm(TNode n) {