Eliminate use of rewrite, CONST_RATIONAL in ArithMSum (#7808)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 14 Dec 2021 17:14:04 +0000 (11:14 -0600)
committerGitHub <noreply@github.com>
Tue, 14 Dec 2021 17:14:04 +0000 (11:14 -0600)
src/theory/arith/arith_msum.cpp
src/theory/arith/arith_msum.h
src/theory/arith/arith_rewriter.cpp
src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp
src/theory/quantifiers/cegqi/ceg_arith_instantiator.h
src/theory/quantifiers/fmf/bounded_integers.cpp
src/theory/quantifiers/relevant_domain.cpp

index a8edb0e7908bf06ece2a49016589d04b14486c69..0621c139150d50a53716179779c909a9b3bee050 100644 (file)
@@ -81,7 +81,8 @@ bool ArithMSum::getMonomialSum(Node n, std::map<Node, Node>& msum)
 
 bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
 {
-  if (lit.getKind() == GEQ || lit.getKind() == EQUAL)
+  if (lit.getKind() == GEQ
+      || (lit.getKind() == EQUAL && lit[0].getType().isRealOrInt()))
   {
     if (getMonomialSum(lit[0], msum))
     {
@@ -96,6 +97,7 @@ bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
         NodeManager* nm = NodeManager::currentNM();
         if (getMonomialSum(lit[1], msum2))
         {
+          TypeNode tn = lit[0].getType();
           for (std::map<Node, Node>::iterator it = msum2.begin();
                it != msum2.end();
                ++it)
@@ -103,20 +105,20 @@ bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
             std::map<Node, Node>::iterator it2 = msum.find(it->first);
             if (it2 != msum.end())
             {
-              Node r = nm->mkNode(MINUS,
-                                  it2->second.isNull()
-                                      ? nm->mkConst(CONST_RATIONAL, Rational(1))
-                                      : it2->second,
-                                  it->second.isNull()
-                                      ? nm->mkConst(CONST_RATIONAL, Rational(1))
-                                      : it->second);
-              msum[it->first] = Rewriter::rewrite(r);
+              Rational r1 = it2->second.isNull()
+                                ? Rational(1)
+                                : it2->second.getConst<Rational>();
+              Rational r2 = it->second.isNull()
+                                ? Rational(1)
+                                : it->second.getConst<Rational>();
+              msum[it->first] = nm->mkConstRealOrInt(tn, r1 - r2);
             }
             else
             {
               msum[it->first] = it->second.isNull()
-                                    ? nm->mkConst(CONST_RATIONAL, Rational(-1))
-                                    : negate(it->second);
+                                    ? nm->mkConstRealOrInt(tn, Rational(-1))
+                                    : nm->mkConstRealOrInt(
+                                          tn, -it->second.getConst<Rational>());
             }
           }
           return true;
@@ -127,7 +129,7 @@ bool ArithMSum::getMonomialSumLit(Node lit, std::map<Node, Node>& msum)
   return false;
 }
 
-Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
+Node ArithMSum::mkNode(TypeNode tn, const std::map<Node, Node>& msum)
 {
   NodeManager* nm = NodeManager::currentNM();
   std::vector<Node> children;
@@ -146,10 +148,10 @@ Node ArithMSum::mkNode(const std::map<Node, Node>& msum)
     }
     children.push_back(m);
   }
-  return children.size() > 1 ? nm->mkNode(PLUS, children)
-                             : (children.size() == 1
-                                    ? children[0]
-                                    : nm->mkConst(CONST_RATIONAL, Rational(0)));
+  return children.size() > 1
+             ? nm->mkNode(PLUS, children)
+             : (children.size() == 1 ? children[0]
+                                     : nm->mkConstRealOrInt(tn, Rational(0)));
 }
 
 int ArithMSum::isolate(
@@ -159,11 +161,13 @@ int ArithMSum::isolate(
   std::map<Node, Node>::const_iterator itv = msum.find(v);
   if (itv != msum.end())
   {
+    NodeManager* nm = NodeManager::currentNM();
     std::vector<Node> children;
     Rational r =
         itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
     if (r.sgn() != 0)
     {
+      TypeNode vtn = v.getType();
       for (std::map<Node, Node>::const_iterator it = msum.begin();
            it != msum.end();
            ++it)
@@ -182,27 +186,25 @@ int ArithMSum::isolate(
           children.push_back(m);
         }
       }
-      val = children.size() > 1
-                ? NodeManager::currentNM()->mkNode(PLUS, children)
-                : (children.size() == 1 ? children[0]
-                                        : NodeManager::currentNM()->mkConst(
-                                            CONST_RATIONAL, Rational(0)));
+      val =
+          children.size() > 1
+              ? nm->mkNode(PLUS, children)
+              : (children.size() == 1 ? children[0]
+                                      : nm->mkConstRealOrInt(vtn, Rational(0)));
       if (!r.isOne() && !r.isNegativeOne())
       {
-        if (v.getType().isInteger())
+        if (vtn.isInteger())
         {
-          veq_c = NodeManager::currentNM()->mkConst(CONST_RATIONAL, r.abs());
+          veq_c = nm->mkConstInt(r.abs());
         }
         else
         {
-          val = NodeManager::currentNM()->mkNode(
-              MULT,
-              val,
-              NodeManager::currentNM()->mkConst(CONST_RATIONAL,
-                                                Rational(1) / r.abs()));
+          val = nm->mkNode(MULT, val, nm->mkConstReal(Rational(1) / r.abs()));
         }
       }
-      val = r.sgn() == 1 ? negate(val) : Rewriter::rewrite(val);
+      val = r.sgn() == 1
+                ? nm->mkNode(MULT, nm->mkConstRealOrInt(vtn, Rational(-1)), val)
+                : val;
       return (r.sgn() == 1 || k == EQUAL) ? 1 : -1;
     }
   }
@@ -284,29 +286,13 @@ bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem)
     {
       coeff = it->second;
       msum.erase(v);
-      rem = mkNode(msum);
+      rem = mkNode(n.getType(), msum);
       return true;
     }
   }
   return false;
 }
 
-Node ArithMSum::negate(Node t)
-{
-  Node tt = NodeManager::currentNM()->mkNode(
-      MULT, NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(-1)), t);
-  tt = Rewriter::rewrite(tt);
-  return tt;
-}
-
-Node ArithMSum::offset(Node t, int i)
-{
-  Node tt = NodeManager::currentNM()->mkNode(
-      PLUS, NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(i)), t);
-  tt = Rewriter::rewrite(tt);
-  return tt;
-}
-
 void ArithMSum::debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c)
 {
   for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end(); ++it)
index 87f56e64f8f106f625a3252fc2c4685be0ff6fe1..ae57ee1cb9953c67fe2142d849fb9c51da48ecc9 100644 (file)
@@ -103,8 +103,13 @@ class ArithMSum
    *
    * Make the Node corresponding to the interpretation of msum, [msum], where:
    *   [msum] = sum_{( v, c ) \in msum } [c]*[v]
+   *
+   * @param tn The type of the node to return, which is used only if msum is
+   * empty
+   * @param msum The monomial sum
+   * @return The node corresponding to the monomial sum
    */
-  static Node mkNode(const std::map<Node, Node>& msum);
+  static Node mkNode(TypeNode tn, const std::map<Node, Node>& msum);
 
   /** make coefficent term
    *
@@ -173,12 +178,6 @@ class ArithMSum
   */
   static bool decompose(Node n, Node v, Node& coeff, Node& rem);
 
-  /** return the rewritten form of (UMINUS t) */
-  static Node negate(Node t);
-
-  /** return the rewritten form of (PLUS t (CONST_RATIONAL i)) */
-  static Node offset(Node t, int i);
-
   /** debug print for a monmoial sum, prints to Trace(c) */
   static void debugPrintMonomialSum(std::map<Node, Node>& msum, const char* c);
 };
index 4c01f25f3e3c759d1ffef7c09a8282d18833c66d..af6f23c1fd2ab3bea34ea2c40ebc36d9bbc99d13 100644 (file)
@@ -516,7 +516,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
           msum.erase(pi);
           if (!msum.empty())
           {
-            rem = ArithMSum::mkNode(msum);
+            rem = ArithMSum::mkNode(t[0].getType(), msum);
           }
         }
       }
index 2d483d50225eb8c805b2ec2bf57143478aba3d44..56debbbac029beea2e559fc65a3f4be084c79295 100644 (file)
@@ -818,7 +818,7 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci,
           // multiply by the coefficient we will isolate for
           if (itv->second.isNull())
           {
-            vts_coeff[t] = ArithMSum::negate(vts_coeff[t]);
+            vts_coeff[t] = negate(vts_coeff[t]);
           }
           else
           {
@@ -833,7 +833,7 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci,
             }
             else if (itv->second.getConst<Rational>().sgn() == 1)
             {
-              vts_coeff[t] = ArithMSum::negate(vts_coeff[t]);
+              vts_coeff[t] = negate(vts_coeff[t]);
             }
           }
         }
@@ -1040,6 +1040,13 @@ Node ArithInstantiator::getModelBasedProjectionValue(CegInstantiator* ci,
   return val;
 }
 
+Node ArithInstantiator::negate(const Node& t) const
+{
+  NodeManager* nm = NodeManager::currentNM();
+  return rewrite(
+      nm->mkNode(MULT, nm->mkConstRealOrInt(t.getType(), Rational(-1)), t));
+}
+
 }  // namespace quantifiers
 }  // namespace theory
 }  // namespace cvc5
index e102b834e2904ddbcee964e204818cc9b39e1763..d44ab4993d0a3cebc20e5368665a23a922cf29e0 100644 (file)
@@ -206,6 +206,8 @@ class ArithInstantiator : public Instantiator
                                     Node theta,
                                     Node inf_coeff,
                                     Node delta_coeff);
+  /** Return the rewritten form of the negation of t */
+  Node negate(const Node& t) const;
 };
 
 }  // namespace quantifiers
index 18a63d245c7f3ad1bfd12a7b20d0b14063ff8f36..5c0283863ed2ae10d858b1b3d5738a8dd4cda0a1 100644 (file)
@@ -223,6 +223,7 @@ void BoundedIntegers::process( Node q, Node n, bool pol,
       std::map< Node, Node > msum;
       if (ArithMSum::getMonomialSumLit(n, msum))
       {
+        NodeManager* nm = NodeManager::currentNM();
         Trace("bound-int-debug") << "literal (polarity = " << pol << ") " << n << " is monomial sum : " << std::endl;
         ArithMSum::debugPrintMonomialSum(msum, "bound-int-debug");
         for( std::map< Node, Node >::iterator it = msum.begin(); it != msum.end(); ++it ){
@@ -239,11 +240,11 @@ void BoundedIntegers::process( Node q, Node n, bool pol,
                   n1 = veq[1];
                   n2 = veq[0];
                   if( n1.getKind()==BOUND_VARIABLE ){
-                    n2 = ArithMSum::offset(n2, 1);
+                    n2 = nm->mkNode(PLUS, n2, nm->mkConstInt(Rational(1)));
                   }else{
-                    n1 = ArithMSum::offset(n1, -1);
+                    n1 = nm->mkNode(PLUS, n1, nm->mkConstInt(Rational(-1)));
                   }
-                  veq = NodeManager::currentNM()->mkNode( GEQ, n1, n2 );
+                  veq = nm->mkNode(GEQ, n1, n2);
                 }
                 Trace("bound-int-debug") << "Isolated for " << it->first << " : (" << n1 << " >= " << n2 << ")" << std::endl;
                 Node t = n1==it->first ? n2 : n1;
index 0f3699990992b19e87270e36c7cb60c42a27f248..f0684f04a1e8c5a2ac7364d8367d432e400c1dc7 100644 (file)
@@ -24,6 +24,7 @@
 #include "theory/quantifiers/term_database.h"
 #include "theory/quantifiers/term_registry.h"
 #include "theory/quantifiers/term_util.h"
+#include "util/rational.h"
 
 using namespace cvc5::kind;
 
@@ -301,6 +302,7 @@ void RelevantDomain::computeRelevantDomainOpCh( RDomain * rf, Node n ) {
 
 void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, Node n ) {
   if( d_rel_dom_lit[hasPol][pol].find( n )==d_rel_dom_lit[hasPol][pol].end() ){
+    NodeManager* nm = NodeManager::currentNM();
     RDomainLit& rdl = d_rel_dom_lit[hasPol][pol][n];
     rdl.d_merge = false;
     int varCount = 0;
@@ -405,10 +407,14 @@ void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, No
       if( ( !hasPol || pol ) && n[0].getType().isInteger() ){
         if( n.getKind()==EQUAL ){
           for( unsigned i=0; i<2; i++ ){
-            rdl.d_val.push_back(ArithMSum::offset(r_add, i == 0 ? 1 : -1));
+            Node roff = nm->mkNode(
+                PLUS, r_add, nm->mkConstInt(Rational(i == 0 ? 1 : -1)));
+            rdl.d_val.push_back(roff);
           }
         }else if( n.getKind()==GEQ ){
-          rdl.d_val.push_back(ArithMSum::offset(r_add, varLhs ? 1 : -1));
+          Node roff = nm->mkNode(
+              PLUS, r_add, nm->mkConstInt(Rational(varLhs ? 1 : -1)));
+          rdl.d_val.push_back(roff);
         }
       }
     }