From: Andrew Reynolds Date: Tue, 14 Dec 2021 17:14:04 +0000 (-0600) Subject: Eliminate use of rewrite, CONST_RATIONAL in ArithMSum (#7808) X-Git-Tag: cvc5-1.0.0~671 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=2abdb475ad265c33f1b1658b965bc5b2387313ed;p=cvc5.git Eliminate use of rewrite, CONST_RATIONAL in ArithMSum (#7808) --- diff --git a/src/theory/arith/arith_msum.cpp b/src/theory/arith/arith_msum.cpp index a8edb0e79..0621c1391 100644 --- a/src/theory/arith/arith_msum.cpp +++ b/src/theory/arith/arith_msum.cpp @@ -81,7 +81,8 @@ bool ArithMSum::getMonomialSum(Node n, std::map& msum) bool ArithMSum::getMonomialSumLit(Node lit, std::map& msum) { - if (lit.getKind() == GEQ || lit.getKind() == EQUAL) + if (lit.getKind() == GEQ + || (lit.getKind() == EQUAL && lit[0].getType().isRealOrInt())) { if (getMonomialSum(lit[0], msum)) { @@ -96,6 +97,7 @@ bool ArithMSum::getMonomialSumLit(Node lit, std::map& msum) NodeManager* nm = NodeManager::currentNM(); if (getMonomialSum(lit[1], msum2)) { + TypeNode tn = lit[0].getType(); for (std::map::iterator it = msum2.begin(); it != msum2.end(); ++it) @@ -103,20 +105,20 @@ bool ArithMSum::getMonomialSumLit(Node lit, std::map& msum) std::map::iterator it2 = msum.find(it->first); if (it2 != msum.end()) { - Node r = nm->mkNode(MINUS, - it2->second.isNull() - ? nm->mkConst(CONST_RATIONAL, Rational(1)) - : it2->second, - it->second.isNull() - ? nm->mkConst(CONST_RATIONAL, Rational(1)) - : it->second); - msum[it->first] = Rewriter::rewrite(r); + Rational r1 = it2->second.isNull() + ? Rational(1) + : it2->second.getConst(); + Rational r2 = it->second.isNull() + ? Rational(1) + : it->second.getConst(); + msum[it->first] = nm->mkConstRealOrInt(tn, r1 - r2); } else { msum[it->first] = it->second.isNull() - ? nm->mkConst(CONST_RATIONAL, Rational(-1)) - : negate(it->second); + ? nm->mkConstRealOrInt(tn, Rational(-1)) + : nm->mkConstRealOrInt( + tn, -it->second.getConst()); } } return true; @@ -127,7 +129,7 @@ bool ArithMSum::getMonomialSumLit(Node lit, std::map& msum) return false; } -Node ArithMSum::mkNode(const std::map& msum) +Node ArithMSum::mkNode(TypeNode tn, const std::map& msum) { NodeManager* nm = NodeManager::currentNM(); std::vector children; @@ -146,10 +148,10 @@ Node ArithMSum::mkNode(const std::map& msum) } children.push_back(m); } - return children.size() > 1 ? nm->mkNode(PLUS, children) - : (children.size() == 1 - ? children[0] - : nm->mkConst(CONST_RATIONAL, Rational(0))); + return children.size() > 1 + ? nm->mkNode(PLUS, children) + : (children.size() == 1 ? children[0] + : nm->mkConstRealOrInt(tn, Rational(0))); } int ArithMSum::isolate( @@ -159,11 +161,13 @@ int ArithMSum::isolate( std::map::const_iterator itv = msum.find(v); if (itv != msum.end()) { + NodeManager* nm = NodeManager::currentNM(); std::vector children; Rational r = itv->second.isNull() ? Rational(1) : itv->second.getConst(); if (r.sgn() != 0) { + TypeNode vtn = v.getType(); for (std::map::const_iterator it = msum.begin(); it != msum.end(); ++it) @@ -182,27 +186,25 @@ int ArithMSum::isolate( children.push_back(m); } } - val = children.size() > 1 - ? NodeManager::currentNM()->mkNode(PLUS, children) - : (children.size() == 1 ? children[0] - : NodeManager::currentNM()->mkConst( - CONST_RATIONAL, Rational(0))); + val = + children.size() > 1 + ? nm->mkNode(PLUS, children) + : (children.size() == 1 ? children[0] + : nm->mkConstRealOrInt(vtn, Rational(0))); if (!r.isOne() && !r.isNegativeOne()) { - if (v.getType().isInteger()) + if (vtn.isInteger()) { - veq_c = NodeManager::currentNM()->mkConst(CONST_RATIONAL, r.abs()); + veq_c = nm->mkConstInt(r.abs()); } else { - val = NodeManager::currentNM()->mkNode( - MULT, - val, - NodeManager::currentNM()->mkConst(CONST_RATIONAL, - Rational(1) / r.abs())); + val = nm->mkNode(MULT, val, nm->mkConstReal(Rational(1) / r.abs())); } } - val = r.sgn() == 1 ? negate(val) : Rewriter::rewrite(val); + val = r.sgn() == 1 + ? nm->mkNode(MULT, nm->mkConstRealOrInt(vtn, Rational(-1)), val) + : val; return (r.sgn() == 1 || k == EQUAL) ? 1 : -1; } } @@ -284,29 +286,13 @@ bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem) { coeff = it->second; msum.erase(v); - rem = mkNode(msum); + rem = mkNode(n.getType(), msum); return true; } } return false; } -Node ArithMSum::negate(Node t) -{ - Node tt = NodeManager::currentNM()->mkNode( - MULT, NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(-1)), t); - tt = Rewriter::rewrite(tt); - return tt; -} - -Node ArithMSum::offset(Node t, int i) -{ - Node tt = NodeManager::currentNM()->mkNode( - PLUS, NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(i)), t); - tt = Rewriter::rewrite(tt); - return tt; -} - void ArithMSum::debugPrintMonomialSum(std::map& msum, const char* c) { for (std::map::iterator it = msum.begin(); it != msum.end(); ++it) diff --git a/src/theory/arith/arith_msum.h b/src/theory/arith/arith_msum.h index 87f56e64f..ae57ee1cb 100644 --- a/src/theory/arith/arith_msum.h +++ b/src/theory/arith/arith_msum.h @@ -103,8 +103,13 @@ class ArithMSum * * Make the Node corresponding to the interpretation of msum, [msum], where: * [msum] = sum_{( v, c ) \in msum } [c]*[v] + * + * @param tn The type of the node to return, which is used only if msum is + * empty + * @param msum The monomial sum + * @return The node corresponding to the monomial sum */ - static Node mkNode(const std::map& msum); + static Node mkNode(TypeNode tn, const std::map& msum); /** make coefficent term * @@ -173,12 +178,6 @@ class ArithMSum */ static bool decompose(Node n, Node v, Node& coeff, Node& rem); - /** return the rewritten form of (UMINUS t) */ - static Node negate(Node t); - - /** return the rewritten form of (PLUS t (CONST_RATIONAL i)) */ - static Node offset(Node t, int i); - /** debug print for a monmoial sum, prints to Trace(c) */ static void debugPrintMonomialSum(std::map& msum, const char* c); }; diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 4c01f25f3..af6f23c1f 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -516,7 +516,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { msum.erase(pi); if (!msum.empty()) { - rem = ArithMSum::mkNode(msum); + rem = ArithMSum::mkNode(t[0].getType(), msum); } } } diff --git a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp index 2d483d502..56debbbac 100644 --- a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp +++ b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp @@ -818,7 +818,7 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci, // multiply by the coefficient we will isolate for if (itv->second.isNull()) { - vts_coeff[t] = ArithMSum::negate(vts_coeff[t]); + vts_coeff[t] = negate(vts_coeff[t]); } else { @@ -833,7 +833,7 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci, } else if (itv->second.getConst().sgn() == 1) { - vts_coeff[t] = ArithMSum::negate(vts_coeff[t]); + vts_coeff[t] = negate(vts_coeff[t]); } } } @@ -1040,6 +1040,13 @@ Node ArithInstantiator::getModelBasedProjectionValue(CegInstantiator* ci, return val; } +Node ArithInstantiator::negate(const Node& t) const +{ + NodeManager* nm = NodeManager::currentNM(); + return rewrite( + nm->mkNode(MULT, nm->mkConstRealOrInt(t.getType(), Rational(-1)), t)); +} + } // namespace quantifiers } // namespace theory } // namespace cvc5 diff --git a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.h b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.h index e102b834e..d44ab4993 100644 --- a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.h +++ b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.h @@ -206,6 +206,8 @@ class ArithInstantiator : public Instantiator Node theta, Node inf_coeff, Node delta_coeff); + /** Return the rewritten form of the negation of t */ + Node negate(const Node& t) const; }; } // namespace quantifiers diff --git a/src/theory/quantifiers/fmf/bounded_integers.cpp b/src/theory/quantifiers/fmf/bounded_integers.cpp index 18a63d245..5c0283863 100644 --- a/src/theory/quantifiers/fmf/bounded_integers.cpp +++ b/src/theory/quantifiers/fmf/bounded_integers.cpp @@ -223,6 +223,7 @@ void BoundedIntegers::process( Node q, Node n, bool pol, std::map< Node, Node > msum; if (ArithMSum::getMonomialSumLit(n, msum)) { + NodeManager* nm = NodeManager::currentNM(); Trace("bound-int-debug") << "literal (polarity = " << pol << ") " << n << " is monomial sum : " << std::endl; ArithMSum::debugPrintMonomialSum(msum, "bound-int-debug"); for( std::map< Node, Node >::iterator it = msum.begin(); it != msum.end(); ++it ){ @@ -239,11 +240,11 @@ void BoundedIntegers::process( Node q, Node n, bool pol, n1 = veq[1]; n2 = veq[0]; if( n1.getKind()==BOUND_VARIABLE ){ - n2 = ArithMSum::offset(n2, 1); + n2 = nm->mkNode(PLUS, n2, nm->mkConstInt(Rational(1))); }else{ - n1 = ArithMSum::offset(n1, -1); + n1 = nm->mkNode(PLUS, n1, nm->mkConstInt(Rational(-1))); } - veq = NodeManager::currentNM()->mkNode( GEQ, n1, n2 ); + veq = nm->mkNode(GEQ, n1, n2); } Trace("bound-int-debug") << "Isolated for " << it->first << " : (" << n1 << " >= " << n2 << ")" << std::endl; Node t = n1==it->first ? n2 : n1; diff --git a/src/theory/quantifiers/relevant_domain.cpp b/src/theory/quantifiers/relevant_domain.cpp index 0f3699990..f0684f04a 100644 --- a/src/theory/quantifiers/relevant_domain.cpp +++ b/src/theory/quantifiers/relevant_domain.cpp @@ -24,6 +24,7 @@ #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_registry.h" #include "theory/quantifiers/term_util.h" +#include "util/rational.h" using namespace cvc5::kind; @@ -301,6 +302,7 @@ void RelevantDomain::computeRelevantDomainOpCh( RDomain * rf, Node n ) { void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, Node n ) { if( d_rel_dom_lit[hasPol][pol].find( n )==d_rel_dom_lit[hasPol][pol].end() ){ + NodeManager* nm = NodeManager::currentNM(); RDomainLit& rdl = d_rel_dom_lit[hasPol][pol][n]; rdl.d_merge = false; int varCount = 0; @@ -405,10 +407,14 @@ void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, No if( ( !hasPol || pol ) && n[0].getType().isInteger() ){ if( n.getKind()==EQUAL ){ for( unsigned i=0; i<2; i++ ){ - rdl.d_val.push_back(ArithMSum::offset(r_add, i == 0 ? 1 : -1)); + Node roff = nm->mkNode( + PLUS, r_add, nm->mkConstInt(Rational(i == 0 ? 1 : -1))); + rdl.d_val.push_back(roff); } }else if( n.getKind()==GEQ ){ - rdl.d_val.push_back(ArithMSum::offset(r_add, varLhs ? 1 : -1)); + Node roff = nm->mkNode( + PLUS, r_add, nm->mkConstInt(Rational(varLhs ? 1 : -1))); + rdl.d_val.push_back(roff); } } }