From 95a9b6e26bac6a822248563681c1943d35367528 Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Tue, 1 Mar 2022 00:15:27 +0100 Subject: [PATCH] Refactor rewriting of arithmetic atoms (#8175) This PR uses the new utilities for atom rewriting in the arithmetic rewriter. --- src/theory/arith/arith_rewriter.cpp | 163 ++++++++++++---------------- src/theory/arith/arith_rewriter.h | 7 +- 2 files changed, 72 insertions(+), 98 deletions(-) diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 38fb0d162..863eaee1d 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -33,6 +33,7 @@ #include "theory/arith/rewriter/addition.h" #include "theory/arith/rewriter/node_utils.h" #include "theory/arith/rewriter/ordering.h" +#include "theory/arith/rewriter/rewrite_atom.h" #include "theory/theory.h" #include "util/bitvector.h" #include "util/divisible.h" @@ -47,20 +48,6 @@ namespace arith { namespace { -template -bool evaluateRelation(Kind rel, const L& l, const R& r) -{ - switch (rel) - { - case Kind::LT: return l < r; - case Kind::LEQ: return l <= r; - case Kind::EQUAL: return l == r; - case Kind::GEQ: return l >= r; - case Kind::GT: return l > r; - default: Unreachable(); return false; - } -} - /** * Check whether the parent has a child that is a constant zero. * If so, return this child. Otherwise, return std::nullopt. @@ -85,7 +72,7 @@ ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {} RewriteResponse ArithRewriter::preRewrite(TNode t) { Trace("arith-rewriter") << "preRewrite(" << t << ")" << std::endl; - if (isAtom(t)) + if (rewriter::isAtom(t)) { auto res = preRewriteAtom(t); Trace("arith-rewriter") @@ -100,7 +87,7 @@ RewriteResponse ArithRewriter::preRewrite(TNode t) RewriteResponse ArithRewriter::postRewrite(TNode t) { Trace("arith-rewriter") << "postRewrite(" << t << ")" << std::endl; - if (isAtom(t)) + if (rewriter::isAtom(t)) { auto res = postRewriteAtom(t); Trace("arith-rewriter") @@ -114,41 +101,33 @@ RewriteResponse ArithRewriter::postRewrite(TNode t) RewriteResponse ArithRewriter::preRewriteAtom(TNode atom) { - Assert(isAtom(atom)); - - NodeManager* nm = NodeManager::currentNM(); + Assert(rewriter::isAtom(atom)); - if (isRelationOperator(atom.getKind()) && atom[0] == atom[1]) + if (auto response = rewriter::tryEvaluateRelationReflexive(atom); response) { - switch (atom.getKind()) - { - case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false)); - case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); - case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); - case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); - case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false)); - default:; - } + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response)); } switch (atom.getKind()) { case Kind::GT: - return RewriteResponse(REWRITE_DONE, - nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode()); + return RewriteResponse( + REWRITE_DONE, + rewriter::buildRelation(kind::LEQ, atom[0], atom[1], true)); case Kind::LT: - return RewriteResponse(REWRITE_DONE, - nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode()); + return RewriteResponse( + REWRITE_DONE, + rewriter::buildRelation(kind::GEQ, atom[0], atom[1], true)); case Kind::IS_INTEGER: if (atom[0].getType().isInteger()) { - return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(true)); } break; case Kind::DIVISIBLE: if (atom.getOperator().getConst().k.isOne()) { - return RewriteResponse(REWRITE_DONE, nm->mkConst(true)); + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(true)); } break; default:; @@ -159,93 +138,91 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom) RewriteResponse ArithRewriter::postRewriteAtom(TNode atom) { - Assert(isAtom(atom)); + Assert(rewriter::isAtom(atom)); + if (atom.getKind() == kind::IS_INTEGER) { return rewriteExtIntegerOp(atom); } else if (atom.getKind() == kind::DIVISIBLE) { + const Integer& k = atom.getOperator().getConst().k; if (atom[0].isConst()) { + const Rational& num = atom[0].getConst(); return RewriteResponse(REWRITE_DONE, - NodeManager::currentNM()->mkConst(bool( - (atom[0].getConst() - / atom.getOperator().getConst().k) - .isIntegral()))); + rewriter::mkConst((num / k).isIntegral())); } - if (atom.getOperator().getConst().k.isOne()) + if (k.isOne()) { - return RewriteResponse(REWRITE_DONE, - NodeManager::currentNM()->mkConst(true)); + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(true)); } NodeManager* nm = NodeManager::currentNM(); return RewriteResponse( REWRITE_AGAIN, - nm->mkNode(kind::EQUAL, - nm->mkNode(kind::INTS_MODULUS_TOTAL, - atom[0], - nm->mkConstInt(Rational( - atom.getOperator().getConst().k))), - nm->mkConstInt(Rational(0)))); + nm->mkNode( + kind::EQUAL, + nm->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], rewriter::mkConst(k)), + rewriter::mkConst(Integer(0)))); + } + + if (auto response = rewriter::tryEvaluateRelationReflexive(atom); response) + { + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response)); } // left |><| right + Kind kind = atom.getKind(); TNode left = atom[0]; TNode right = atom[1]; + Assert(isRelationOperator(kind)); - auto* nm = NodeManager::currentNM(); - if (left.isConst()) + if (auto response = rewriter::tryEvaluateRelation(kind, left, right); + response) { - const Rational& l = left.getConst(); - if (right.isConst()) - { - const Rational& r = right.getConst(); - return RewriteResponse( - REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r))); - } - else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) - { - const RealAlgebraicNumber& r = - right.getOperator().getConst(); - return RewriteResponse( - REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r))); - } + return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response)); } - else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + + bool negate = false; + + switch (atom.getKind()) { - const RealAlgebraicNumber& l = - left.getOperator().getConst(); - if (right.isConst()) + case Kind::LEQ: + kind = Kind::GEQ; + negate = true; + break; + case Kind::LT: + kind = Kind::GT; + negate = true; + break; + default: break; + } + + rewriter::Sum sum; + rewriter::addToSum(sum, left, negate); + rewriter::addToSum(sum, right, !negate); + + // Now we have (rsum 0) + if (rewriter::isIntegral(atom)) + { + if (kind == Kind::EQUAL) { - const Rational& r = right.getConst(); - return RewriteResponse( - REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r))); + return RewriteResponse(REWRITE_DONE, + rewriter::buildIntegerEquality(std::move(sum))); } - else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER) + return RewriteResponse( + REWRITE_DONE, rewriter::buildIntegerInequality(std::move(sum), kind)); + } + else + { + if (kind == Kind::EQUAL) { - const RealAlgebraicNumber& r = - right.getOperator().getConst(); - return RewriteResponse( - REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r))); + return RewriteResponse(REWRITE_DONE, + rewriter::buildRealEquality(std::move(sum))); } + return RewriteResponse(REWRITE_DONE, + rewriter::buildRealInequality(std::move(sum), kind)); } - - Polynomial pleft = Polynomial::parsePolynomial(left); - Polynomial pright = Polynomial::parsePolynomial(right); - - Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl; - Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl; - - Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright); - Assert(cmp.isNormalForm()); - return RewriteResponse(REWRITE_DONE, cmp.getNode()); -} - -bool ArithRewriter::isAtom(TNode n) { - Kind k = n.getKind(); - return arith::isRelationOperator(k) || k == kind::IS_INTEGER - || k == kind::DIVISIBLE; } RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre) diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index c4590ffdc..3ffa90015 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -43,7 +43,9 @@ class ArithRewriter : public TheoryRewriter TrustNode expandDefinition(Node node) override; private: + /** preRewrite for atoms */ static RewriteResponse preRewriteAtom(TNode t); + /** postRewrite for atoms */ static RewriteResponse postRewriteAtom(TNode t); static RewriteResponse preRewriteTerm(TNode t); @@ -82,11 +84,6 @@ class ArithRewriter : public TheoryRewriter /** postRewrite transcendental functions */ static RewriteResponse postRewriteTranscendental(TNode t); - static bool isAtom(TNode n); - - static inline bool isTerm(TNode n) { - return !isAtom(n); - } /** return rewrite */ static RewriteResponse returnRewrite(TNode t, Node ret, Rewrite r); /** The operator elimination utility */ -- 2.30.2