Eliminate more uses of CONST_RATIONAL (#7816)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 17 Dec 2021 03:47:43 +0000 (21:47 -0600)
committerGitHub <noreply@github.com>
Fri, 17 Dec 2021 03:47:43 +0000 (03:47 +0000)
14 files changed:
src/parser/smt2/smt2.cpp
src/parser/smt2/smt2.h
src/preprocessing/passes/real_to_int.cpp
src/preprocessing/passes/unconstrained_simplifier.cpp
src/theory/arith/arith_poly_norm.h
src/theory/arith/arith_utilities.cpp
src/theory/arith/arith_utilities.h
src/theory/bags/bag_reduction.cpp
src/theory/fp/theory_fp.cpp
src/theory/fp/theory_fp_rewriter.cpp
src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp
src/theory/quantifiers/cegqi/ceg_instantiator.cpp
src/theory/quantifiers/ematching/inst_match_generator.cpp
src/theory/quantifiers/ematching/relational_match_generator.cpp

index 4e1a8aae8486d80fd19a8b0201ae42eb37231bf0..1fca42634535198f4cdd17cf2e5af50bc65bf67b 100644 (file)
@@ -1013,21 +1013,22 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector<api::Term>& args)
     // integer constants. We must ensure numerator and denominator are
     // constant and the denominator is non-zero. A similar issue happens for
     // negative integers and reals, with unary minus.
+    // NOTE this should be applied more eagerly when UMINUS/DIVISION is
+    // constructed.
     bool isNeg = false;
     if (constVal.getKind() == api::UMINUS)
     {
       isNeg = true;
       constVal = constVal[0];
     }
-    if (constVal.getKind() == api::DIVISION
-        && constVal[0].getKind() == api::CONST_RATIONAL
-        && constVal[1].getKind() == api::CONST_RATIONAL)
+    if (constVal.getKind() == api::DIVISION && isConstInt(constVal[0])
+        && isConstInt(constVal[1]))
     {
       std::stringstream sdiv;
       sdiv << (isNeg ? "-" : "") << constVal[0] << "/" << constVal[1];
       constVal = d_solver->mkReal(sdiv.str());
     }
-    else if (constVal.getKind() == api::CONST_RATIONAL && isNeg)
+    else if (isConstInt(constVal) && isNeg)
     {
       std::stringstream sneg;
       sneg << "-" << constVal;
@@ -1229,7 +1230,7 @@ void Smt2::notifyNamedExpression(api::Term& expr, std::string name)
   setLastNamedTerm(expr, name);
 }
 
-api::Term Smt2::mkAnd(const std::vector<api::Term>& es)
+api::Term Smt2::mkAnd(const std::vector<api::Term>& es) const
 {
   if (es.size() == 0)
   {
@@ -1239,10 +1240,15 @@ api::Term Smt2::mkAnd(const std::vector<api::Term>& es)
   {
     return es[0];
   }
-  else
-  {
-    return d_solver->mkTerm(api::AND, es);
-  }
+  return d_solver->mkTerm(api::AND, es);
+}
+
+bool Smt2::isConstInt(const api::Term& t)
+{
+  api::Kind k = t.getKind();
+  // !!! Note when arithmetic subtyping is eliminated, this will update to
+  // CONST_INTEGER.
+  return k == api::CONST_RATIONAL;
 }
 
 }  // namespace parser
index 58a20cb2790a1f974c7b926e4a4afcebd5e28eaf..6df62d787d819e8b13ecf57cd199c8b25cb642cb 100644 (file)
@@ -422,7 +422,11 @@ class Smt2 : public Parser
    * @return True if `es` is empty, `e` if `es` consists of a single element
    *         `e`, the conjunction of expressions otherwise.
    */
-  api::Term mkAnd(const std::vector<api::Term>& es);
+  api::Term mkAnd(const std::vector<api::Term>& es) const;
+  /**
+   * Is term t a constant integer?
+   */
+  static bool isConstInt(const api::Term& t);
 }; /* class Smt2 */
 
 }  // namespace parser
index d2cde7b46d3fe659c92c3792cdf5f83e4d67c875..ef9077770d559a51f029f6e8ccbb9daaac26ed00 100644 (file)
@@ -98,9 +98,9 @@ Node RealToInt::realToIntInternal(TNode n, NodeMap& cache, std::vector<Node>& va
               Node s;
               if (c.isNull())
               {
-                c = cc.isNull() ? NodeManager::currentNM()->mkConst(
-                        CONST_RATIONAL, Rational(1))
-                                : cc;
+                c = cc.isNull()
+                        ? NodeManager::currentNM()->mkConstInt(Rational(1))
+                        : cc;
               }
               else
               {
index 027be232ba36d65f1b913dfc5201fa2a8f32a307..7a58fc231cf25e1640487ae84170b5591b9ba505 100644 (file)
@@ -530,8 +530,7 @@ void UnconstrainedSimplifier::processUnconstrained()
             else
             {
               // TODO(#2377): could build ITE here
-              Node test =
-                  other.eqNode(nm->mkConst(CONST_RATIONAL, Rational(0)));
+              Node test = other.eqNode(nm->mkConstReal(Rational(0)));
               if (rewrite(test) != nm->mkConst<bool>(false))
               {
                 break;
index 9c3cbcf958540d9264aabccf94a05ca19ebc70d2..fafa94ee345631fa4d15c75485a2f5438f669f15 100644 (file)
@@ -40,7 +40,7 @@ class PolyNorm
    */
   void addMonomial(TNode x, const Rational& c, bool isNeg = false);
   /**
-   * Multiply this polynomial by the monomial x*c, where c is a CONST_RATIONAL.
+   * Multiply this polynomial by the monomial x*c, where c is a constant.
    * If x is null, then x*c is treated as c.
    */
   void multiplyMonomial(TNode x, const Rational& c);
index a9fd970796588a516a112f177058aeae16ed511b..6f43cfc1bf5453dc7b3d53ef8c30745e64a83837 100644 (file)
@@ -319,6 +319,22 @@ Node negateProofLiteral(TNode n)
   }
 }
 
+Node multConstants(const Node& c1, const Node& c2)
+{
+  Assert(!c1.isNull() && c1.isConst());
+  Assert(!c2.isNull() && c2.isConst());
+  NodeManager* nm = NodeManager::currentNM();
+  // real type if either has type real
+  TypeNode tn = c1.getType();
+  if (tn.isInteger())
+  {
+    tn = c2.getType();
+  }
+  Assert(tn.isRealOrInt());
+  return nm->mkConstRealOrInt(
+      tn, Rational(c1.getConst<Rational>() * c2.getConst<Rational>()));
+}
+
 }  // namespace arith
 }  // namespace theory
 }  // namespace cvc5
index b926af2e08756ed5954a22ac590fc7a09f05f897..027f7a65aea836129009ab9420ffaa6e7e8fc20c 100644 (file)
@@ -332,6 +332,12 @@ Rational greatestIntLessThan(const Rational&);
 /** Negates a node in arithmetic proof normal form. */
 Node negateProofLiteral(TNode n);
 
+/**
+ * Return the result of multiplying constant integer or real nodes c1 and c2.
+ * The returned type is real if either have type real.
+ */
+Node multConstants(const Node& c1, const Node& c2);
+
 }  // namespace arith
 }  // namespace theory
 }  // namespace cvc5
index 3e6544882212e45e551c55f2527067b8a64c3a42..a6895e4b12b742a849e8a8b2082b0da24c95df98 100644 (file)
@@ -134,8 +134,8 @@ Node BagReduction::reduceCardOperator(Node node, std::vector<Node>& asserts)
   NodeManager* nm = NodeManager::currentNM();
   SkolemManager* sm = nm->getSkolemManager();
   Node A = node[0];
-  Node zero = nm->mkConst(CONST_RATIONAL, Rational(0));
-  Node one = nm->mkConst(CONST_RATIONAL, Rational(1));
+  Node zero = nm->mkConstInt(Rational(0));
+  Node one = nm->mkConstInt(Rational(1));
   // types
   TypeNode bagType = A.getType();
   TypeNode elementType = A.getType().getBagElementType();
index 972fac5a34d8f6506e6694c8d5f7e8d75ddba7ff..bcbd9e297324ef370d5d763dd44ff396271570a6 100644 (file)
@@ -364,7 +364,7 @@ bool TheoryFp::refineAbstraction(TheoryModel *m, TNode abstract, TNode concrete)
         Node realValueOfAbstract =
             rewrite(nm->mkNode(kind::FLOATINGPOINT_TO_REAL_TOTAL,
                                abstractValue,
-                               nm->mkConst(CONST_RATIONAL, Rational(0U))));
+                               nm->mkConstReal(Rational(0U))));
 
         Node bg = nm->mkNode(
             kind::IMPLIES,
@@ -570,8 +570,7 @@ void TheoryFp::registerTerm(TNode node)
     Node z = nm->mkNode(
         kind::IMPLIES,
         nm->mkNode(kind::FLOATINGPOINT_ISZ, node[0]),
-        nm->mkNode(
-            kind::EQUAL, node, nm->mkConst(CONST_RATIONAL, Rational(0U))));
+        nm->mkNode(kind::EQUAL, node, nm->mkConstReal(Rational(0U))));
     handleLemma(z, InferenceId::FP_REGISTER_TERM);
     return;
 
@@ -592,8 +591,7 @@ void TheoryFp::registerTerm(TNode node)
 
     Node z = nm->mkNode(
         kind::IMPLIES,
-        nm->mkNode(
-            kind::EQUAL, node[1], nm->mkConst(CONST_RATIONAL, Rational(0U))),
+        nm->mkNode(kind::EQUAL, node[1], nm->mkConstReal(Rational(0U))),
         nm->mkNode(kind::EQUAL,
                    node,
                    nm->mkConst(FloatingPoint::makeZero(
index 32c3cff41219c97fc5a86cfa7c3645f13982b2d3..779d02ab379ef1fa8f33f250ca9736adaaaf50f7 100644 (file)
@@ -915,7 +915,7 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
     FloatingPoint::PartialRational res(arg.convertToRational());
 
     if (res.second) {
-      Node lit = NodeManager::currentNM()->mkConst(CONST_RATIONAL, res.first);
+      Node lit = NodeManager::currentNM()->mkConstReal(res.first);
       return RewriteResponse(REWRITE_DONE, lit);
     } else {
       // Can't constant fold the underspecified case
@@ -998,14 +998,14 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite)
       Rational partialValue(node[1].getConst<Rational>());
 
       Rational folded(arg.convertToRationalTotal(partialValue));
-      Node lit = NodeManager::currentNM()->mkConst(CONST_RATIONAL, folded);
+      Node lit = NodeManager::currentNM()->mkConstReal(folded);
       return RewriteResponse(REWRITE_DONE, lit);
 
     } else {
       FloatingPoint::PartialRational res(arg.convertToRational());
 
       if (res.second) {
-        Node lit = NodeManager::currentNM()->mkConst(CONST_RATIONAL, res.first);
+        Node lit = NodeManager::currentNM()->mkConstReal(res.first);
         return RewriteResponse(REWRITE_DONE, lit);
       } else {
        // Can't constant fold the underspecified case
index 56debbbac029beea2e559fc65a3f4be084c79295..ecf2d9a484426ace53530d8d6e9f7768a745f431 100644 (file)
@@ -36,8 +36,8 @@ namespace quantifiers {
 ArithInstantiator::ArithInstantiator(Env& env, TypeNode tn, VtsTermCache* vtc)
     : Instantiator(env, tn), d_vtc(vtc)
 {
-  d_zero = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(0));
-  d_one = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(1));
+  d_zero = NodeManager::currentNM()->mkConstRealOrInt(tn, Rational(0));
+  d_one = NodeManager::currentNM()->mkConstRealOrInt(tn, Rational(1));
 }
 
 void ArithInstantiator::reset(CegInstantiator* ci,
@@ -185,8 +185,7 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci,
           uval = nm->mkNode(
               PLUS,
               val,
-              nm->mkConst(CONST_RATIONAL,
-                          Rational(isUpperBoundCTT(uires) ? 1 : -1)));
+              nm->mkConstInt(Rational(isUpperBoundCTT(uires) ? 1 : -1)));
           uval = rewrite(uval);
         }
         else
@@ -253,11 +252,10 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci,
       if (d_type.isInteger())
       {
         uires = is_upper ? CEG_TT_LOWER : CEG_TT_UPPER;
-        uval =
-            nm->mkNode(PLUS,
-                       val,
-                       nm->mkConst(CONST_RATIONAL,
-                                   Rational(isUpperBoundCTT(uires) ? 1 : -1)));
+        uval = nm->mkNode(
+            PLUS,
+            val,
+            nm->mkConstInt(Rational(isUpperBoundCTT(uires) ? 1 : -1)));
         uval = rewrite(uval);
       }
       else
@@ -278,8 +276,8 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci,
     {
       if (options().quantifiers.cegqiModel)
       {
-        Node delta_coeff = nm->mkConst(
-            CONST_RATIONAL, Rational(isUpperBoundCTT(uires) ? 1 : -1));
+        Node delta_coeff = nm->mkConstRealOrInt(
+            d_type, Rational(isUpperBoundCTT(uires) ? 1 : -1));
         if (vts_coeff_delta.isNull())
         {
           vts_coeff_delta = delta_coeff;
@@ -455,9 +453,8 @@ bool ArithInstantiator::processAssertions(CegInstantiator* ci,
             Assert(d_mbp_coeff[rr][j].isConst());
             value[t] = nm->mkNode(
                 MULT,
-                nm->mkConst(
-                    CONST_RATIONAL,
-                    Rational(1) / d_mbp_coeff[rr][j].getConst<Rational>()),
+                nm->mkConstReal(Rational(1)
+                                / d_mbp_coeff[rr][j].getConst<Rational>()),
                 value[t]);
             value[t] = rewrite(value[t]);
           }
@@ -611,10 +608,9 @@ bool ArithInstantiator::processAssertions(CegInstantiator* ci,
       }
       else
       {
-        val =
-            nm->mkNode(MULT,
-                       nm->mkNode(PLUS, vals[0], vals[1]),
-                       nm->mkConst(CONST_RATIONAL, Rational(1) / Rational(2)));
+        val = nm->mkNode(MULT,
+                         nm->mkNode(PLUS, vals[0], vals[1]),
+                         nm->mkConstReal(Rational(1) / Rational(2)));
         val = rewrite(val);
       }
     }
@@ -809,7 +805,7 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci,
         vts_coeff[t] = itminf->second;
         if (vts_coeff[t].isNull())
         {
-          vts_coeff[t] = nm->mkConst(CONST_RATIONAL, Rational(1));
+          vts_coeff[t] = nm->mkConstRealOrInt(d_type, Rational(1));
         }
         // negate if coefficient on variable is positive
         std::map<Node, Node>::iterator itv = msum.find(pv);
@@ -826,8 +822,8 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci,
             {
               vts_coeff[t] = nm->mkNode(
                   MULT,
-                  nm->mkConst(CONST_RATIONAL,
-                              Rational(-1) / itv->second.getConst<Rational>()),
+                  nm->mkConstReal(Rational(-1)
+                                  / itv->second.getConst<Rational>()),
                   vts_coeff[t]);
               vts_coeff[t] = rewrite(vts_coeff[t]);
             }
@@ -887,7 +883,7 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci,
       }
     }
     // multiply everything by this coefficient
-    Node rcoeff = nm->mkConst(CONST_RATIONAL, Rational(coeff));
+    Node rcoeff = nm->mkConstInt(Rational(coeff));
     std::vector<Node> real_part;
     for (std::map<Node, Node>::iterator it = msum.begin(); it != msum.end();
          ++it)
index 81ae18f4fa257350c898d82f9706a8d75c6dc8ee..45ac899e1ebad4ba605c37f7ec2977b37d8b5b17 100644 (file)
@@ -20,6 +20,7 @@
 #include "expr/node_algorithm.h"
 #include "options/quantifiers_options.h"
 #include "theory/arith/arith_msum.h"
+#include "theory/arith/arith_utilities.h"
 #include "theory/quantifiers/cegqi/ceg_arith_instantiator.h"
 #include "theory/quantifiers/cegqi/ceg_bv_instantiator.h"
 #include "theory/quantifiers/cegqi/ceg_dt_instantiator.h"
@@ -137,10 +138,7 @@ void TermProperties::composeProperty(TermProperties& p)
   }
   else
   {
-    NodeManager* nm = NodeManager::currentNM();
-    d_coeff = nm->mkConst(CONST_RATIONAL,
-                          Rational(d_coeff.getConst<Rational>()
-                                   * p.d_coeff.getConst<Rational>()));
+    d_coeff = arith::multConstants(d_coeff, p.d_coeff);
   }
 }
 
@@ -163,12 +161,7 @@ void SolvedForm::push_back(Node pv, Node n, TermProperties& pv_prop)
   }
   else
   {
-    Assert(new_theta.isConst());
-    Assert(pv_prop.d_coeff.isConst());
-    NodeManager* nm = NodeManager::currentNM();
-    new_theta = nm->mkConst(CONST_RATIONAL,
-                            Rational(new_theta.getConst<Rational>()
-                                     * pv_prop.d_coeff.getConst<Rational>()));
+    new_theta = arith::multConstants(new_theta, pv_prop.d_coeff);
   }
   d_theta.push_back(new_theta);
 }
index e3dd246a7cc4d8d8b9fefe518c8d1d88b3a6b12a..ab4bbc91b60dda91d19ba4d4dcab9aa346c5aa4c 100644 (file)
@@ -364,8 +364,8 @@ int InstMatchGenerator::getMatch(Node f, Node t, InstMatch& m)
     {
       if (pat.getKind() == GT)
       {
-        t_match =
-            nm->mkNode(MINUS, t, nm->mkConst(CONST_RATIONAL, Rational(1)));
+        t_match = nm->mkNode(
+            MINUS, t, nm->mkConstRealOrInt(t.getType(), Rational(1)));
       }else{
         t_match = t;
       }
@@ -374,20 +374,21 @@ int InstMatchGenerator::getMatch(Node f, Node t, InstMatch& m)
     {
       if (pat.getKind() == EQUAL)
       {
-        if (t.getType().isBoolean())
+        TypeNode tn = t.getType();
+        if (tn.isBoolean())
         {
           t_match = nm->mkConst(!d_qstate.areEqual(nm->mkConst(true), t));
         }
         else
         {
-          Assert(t.getType().isRealOrInt());
-          t_match =
-              nm->mkNode(PLUS, t, nm->mkConst(CONST_RATIONAL, Rational(1)));
+          Assert(tn.isRealOrInt());
+          t_match = nm->mkNode(PLUS, t, nm->mkConstRealOrInt(tn, Rational(1)));
         }
       }
       else if (pat.getKind() == GEQ)
       {
-        t_match = nm->mkNode(PLUS, t, nm->mkConst(CONST_RATIONAL, Rational(1)));
+        t_match =
+            nm->mkNode(PLUS, t, nm->mkConstRealOrInt(t.getType(), Rational(1)));
       }
       else if (pat.getKind() == GT)
       {
index 5cf9079e8e7c719684f5d30a11a3d70d3af38a27..6ec3334cdf3b7e7c405d440245a58291c498e566 100644 (file)
@@ -97,7 +97,7 @@ int RelationalMatchGenerator::getNextMatch(Node q, InstMatch& m)
       s = nm->mkNode(
           PLUS,
           s,
-          nm->mkConst(CONST_RATIONAL, Rational(d_rel == GEQ ? -1 : 1)));
+          nm->mkConstRealOrInt(s.getType(), Rational(d_rel == GEQ ? -1 : 1)));
     }
     d_counter++;
     Trace("relational-match-gen")