More cleaning uses of arithmetic subtyping (#8595)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 11 Apr 2022 22:13:46 +0000 (17:13 -0500)
committerGitHub <noreply@github.com>
Mon, 11 Apr 2022 22:13:46 +0000 (22:13 +0000)
Towards eliminating arithmetic subtyping.

src/theory/arith/arith_utilities.cpp
src/theory/arith/arith_utilities.h
src/theory/arith/bound_inference.cpp
src/theory/arith/congruence_manager.cpp
src/theory/arith/constraint.cpp
src/theory/arith/nl/ext/monomial_check.cpp
src/theory/arith/theory_arith_private.cpp
src/theory/quantifiers/sygus/sygus_grammar_cons.cpp
src/theory/rewriter.cpp

index b0709c1b35a931b18966ba652289ba9f814767f7..d6bec5b9aee61ca2b6c6d5a8516913874561f083 100644 (file)
@@ -99,6 +99,11 @@ Kind transKinds(Kind k1, Kind k2)
   return UNDEFINED_KIND;
 }
 
+Node mkZero(const TypeNode& tn)
+{
+  return NodeManager::currentNM()->mkConstRealOrInt(tn, 0);
+}
+
 bool isZero(const Node& n)
 {
   Assert(n.getType().isRealOrInt());
index 98afea98f9446a0be05eda4dc25421f39b828684..c1412257c2351827e5a4402d39964bb91672c876 100644 (file)
@@ -222,12 +222,18 @@ inline Node flattenAnd(Node n){
   return NodeManager::currentNM()->mkNode(kind::AND, out);
 }
 
+/** Make zero of the given type */
+Node mkZero(const TypeNode& tn);
+
+/** Is n (integer or real) zero? */
+bool isZero(const Node& n);
+
 // Returns an node that is the identity of a select few kinds.
 inline Node getIdentityType(const TypeNode& tn, Kind k)
 {
   switch (k)
   {
-    case kind::ADD: return NodeManager::currentNM()->mkConstRealOrInt(tn, 0);
+    case kind::ADD: return mkZero(tn);
     case kind::MULT:
     case kind::NONLINEAR_MULT:
       return NodeManager::currentNM()->mkConstRealOrInt(tn, 1);
@@ -277,14 +283,14 @@ inline Node mkInRange(Node term, Node start, Node end) {
 // when n is 0 or not. Useful for division by 0 logic.
 //   (ite (= n 0) (= q if_zero) (= q not_zero))
 inline Node mkOnZeroIte(Node n, Node q, Node if_zero, Node not_zero) {
-  Node zero = NodeManager::currentNM()->mkConstRealOrInt(n.getType(), 0);
+  Node zero = mkZero(n.getType());
   return n.eqNode(zero).iteNode(q.eqNode(if_zero), q.eqNode(not_zero));
 }
 
 inline Node mkPi()
 {
-  return NodeManager::currentNM()->mkNullaryOperator(
-      NodeManager::currentNM()->realType(), kind::PI);
+  NodeManager* nm = NodeManager::currentNM();
+  return nm->mkNullaryOperator(nm->realType(), kind::PI);
 }
 /** Join kinds, where k1 and k2 are arithmetic relations returns an
  * arithmetic relation ret such that
@@ -298,9 +304,6 @@ Kind joinKinds(Kind k1, Kind k2);
  */
 Kind transKinds(Kind k1, Kind k2);
 
-/** Is n (integer or real) zero? */
-bool isZero(const Node& n);
-
 /** Is k a transcendental function kind? */
 bool isTranscendentalKind(Kind k);
 /**
index acb87b4068de6bb3c7d89f269cab4b58ac7dffba..68b3a7b2a62c7e7ed2c0c3ae410893448b66362e 100644 (file)
@@ -81,21 +81,20 @@ bool BoundInference::add(const Node& n, bool onlyVariables)
     auto* nm = NodeManager::currentNM();
     switch (relation)
     {
-      case Kind::LEQ:
-        bound = nm->mkConst<Rational>(CONST_RATIONAL, br.floor());
-        break;
+      case Kind::LEQ: bound = nm->mkConstInt(br.floor()); break;
       case Kind::LT:
-        bound = nm->mkConst<Rational>(CONST_RATIONAL, (br - 1).ceiling());
+        bound = nm->mkConstInt((br - 1).ceiling());
         relation = Kind::LEQ;
         break;
       case Kind::GT:
-        bound = nm->mkConst<Rational>(CONST_RATIONAL, (br + 1).floor());
+        bound = nm->mkConstInt((br + 1).floor());
         relation = Kind::GEQ;
         break;
-      case Kind::GEQ:
-        bound = nm->mkConst<Rational>(CONST_RATIONAL, br.ceiling());
+      case Kind::GEQ: bound = nm->mkConstInt(br.ceiling()); break;
+      default:
+        // always ensure integer
+        bound = nm->mkConstInt(br);
         break;
-      default:;
     }
     Trace("bound-inf") << "Strengthened " << n << " to " << lhs << " "
                        << relation << " " << bound << std::endl;
index 45cfe65a6b52c2fc4f2e6484783cb96b6d0d87c5..a9a09620046a5fab373e3b420835a9c85d6f457b 100644 (file)
@@ -329,14 +329,15 @@ void ArithCongruenceManager::watchedVariableCannotBeZero(ConstraintCP c){
                                         && c->getValue().sgn() > 0);
       const int cSign = scaleCNegatively ? -1 : 1;
       TNode isZero = d_watchedEqualities[s];
+      TypeNode type = isZero[0].getType();
       const auto isZeroPf = d_pnm->mkAssume(isZero);
       const auto nm = NodeManager::currentNM();
       const auto sumPf =
           d_pnm->mkNode(PfRule::MACRO_ARITH_SCALE_SUM_UB,
                         {isZeroPf, pf},
                         // Trick for getting correct, opposing signs.
-                        {nm->mkConst(CONST_RATIONAL, Rational(-1 * cSign)),
-                         nm->mkConst(CONST_RATIONAL, Rational(cSign))});
+                        {nm->mkConstRealOrInt(type, Rational(-1 * cSign)),
+                         nm->mkConstRealOrInt(type, Rational(cSign))});
       const auto botPf = d_pnm->mkNode(
           PfRule::MACRO_SR_PRED_TRANSFORM, {sumPf}, {nm->mkConst(false)});
       std::vector<Node> assumption = {isZero};
index 7344cf3efe44b1349285cd3a95b1249497ce8ebd..dc1f430c85d9d6d6da1fe27b3780fcaaef28e81a 100644 (file)
@@ -1114,6 +1114,7 @@ TrustNode Constraint::split()
   TrustNode trustedLemma;
   if (d_database->isProofEnabled())
   {
+    TypeNode type = lhs.getType();
     // Farkas proof that this works.
     auto nm = NodeManager::currentNM();
     auto nLeqPf = d_database->d_pnm->mkAssume(leqNode.negate());
@@ -1125,8 +1126,8 @@ TrustNode Constraint::split()
     auto sumPf =
         d_database->d_pnm->mkNode(PfRule::MACRO_ARITH_SCALE_SUM_UB,
                                   {gtPf, ltPf},
-                                  {nm->mkConst(CONST_RATIONAL, Rational(-1)),
-                                   nm->mkConst(CONST_RATIONAL, Rational(1))});
+                                  {nm->mkConstRealOrInt(type, Rational(-1)),
+                                   nm->mkConstRealOrInt(type, Rational(1))});
     auto botPf = d_database->d_pnm->mkNode(
         PfRule::MACRO_SR_PRED_TRANSFORM, {sumPf}, {nm->mkConst(false)});
     std::vector<Node> a = {leqNode.negate(), geqNode.negate()};
@@ -1794,9 +1795,9 @@ std::shared_ptr<ProofNode> Constraint::externalExplain(
 
           // Enumerate child proofs (negation included) in d_farkasCoefficients
           // order
+          Node plit = getNegation()->getProofLiteral();
           std::vector<std::shared_ptr<ProofNode>> farkasChildren;
-          farkasChildren.push_back(
-              pnm->mkAssume(getNegation()->getProofLiteral()));
+          farkasChildren.push_back(pnm->mkAssume(plit));
           farkasChildren.insert(
               farkasChildren.end(), children.rbegin(), children.rend());
 
@@ -1804,9 +1805,10 @@ std::shared_ptr<ProofNode> Constraint::externalExplain(
 
           // Enumerate d_farkasCoefficients as nodes.
           std::vector<Node> farkasCoeffs;
+          TypeNode type = plit[0].getType();
           for (Rational r : *getFarkasCoefficients())
           {
-            farkasCoeffs.push_back(nm->mkConst(CONST_RATIONAL, Rational(r)));
+            farkasCoeffs.push_back(nm->mkConstRealOrInt(type, Rational(r)));
           }
 
           // Apply the scaled-sum rule.
@@ -1819,7 +1821,7 @@ std::shared_ptr<ProofNode> Constraint::externalExplain(
 
           // Scope out the negated constraint, yielding a proof of the
           // constraint.
-          std::vector<Node> assump{getNegation()->getProofLiteral()};
+          std::vector<Node> assump{plit};
           auto maybeDoubleNotPf = pnm->mkScope(botPf, assump, false);
 
           // No need to ensure that the expected node aggrees with `assump`
@@ -2086,8 +2088,8 @@ Node Constraint::getProofLiteral() const
     default: Unreachable() << d_type;
   }
   NodeManager* nm = NodeManager::currentNM();
-  Node constPart =
-      nm->mkConst(CONST_RATIONAL, Rational(d_value.getNoninfinitesimalPart()));
+  Node constPart = nm->mkConstRealOrInt(
+      varPart.getType(), Rational(d_value.getNoninfinitesimalPart()));
   Node posLit = nm->mkNode(cmp, varPart, constPart);
   return neg ? posLit.negate() : posLit;
 }
@@ -2104,19 +2106,22 @@ void ConstraintDatabase::proveOr(std::vector<TrustNode>& out,
   {
     Assert(b->getNegation()->getType() != ConstraintType::Disequality);
     auto nm = NodeManager::currentNM();
+    Node alit = a->getNegation()->getProofLiteral();
+    TypeNode type = alit[0].getType();
     auto pf_neg_la = d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM,
                                    {d_pnm->mkAssume(la.negate())},
-                                   {a->getNegation()->getProofLiteral()});
+                                   {alit});
+    Node blit = b->getNegation()->getProofLiteral();
     auto pf_neg_lb = d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM,
                                    {d_pnm->mkAssume(lb.negate())},
-                                   {b->getNegation()->getProofLiteral()});
+                                   {blit});
     int sndSign = negateSecond ? -1 : 1;
     auto bot_pf = d_pnm->mkNode(
         PfRule::MACRO_SR_PRED_TRANSFORM,
         {d_pnm->mkNode(PfRule::MACRO_ARITH_SCALE_SUM_UB,
                        {pf_neg_la, pf_neg_lb},
-                       {nm->mkConst(CONST_RATIONAL, Rational(-1 * sndSign)),
-                        nm->mkConst(CONST_RATIONAL, Rational(sndSign))})},
+                       {nm->mkConstRealOrInt(type, Rational(-1 * sndSign)),
+                        nm->mkConstRealOrInt(type, Rational(sndSign))})},
         {nm->mkConst(false)});
     std::vector<Node> as;
     std::transform(orN.begin(), orN.end(), std::back_inserter(as), [](Node n) {
index 2bda60f25c10f64d78e239b10937938c9a46005c..f2cdd5ca7db0e73786e9f471fd0a52325a11d99f 100644 (file)
@@ -18,6 +18,7 @@
 #include "expr/node.h"
 #include "proof/proof.h"
 #include "theory/arith/arith_msum.h"
+#include "theory/arith/arith_utilities.h"
 #include "theory/arith/inference_manager.h"
 #include "theory/arith/nl/ext/ext_state.h"
 #include "theory/arith/nl/nl_lemma_utils.h"
@@ -302,8 +303,8 @@ int MonomialCheck::compareSign(
   {
     if (mvaoa.getConst<Rational>().sgn() != status)
     {
-      Node lemma =
-          nm->mkAnd(exp).impNode(mkLit(oa, d_data->d_zero, status * 2));
+      Node zero = mkZero(oa.getType());
+      Node lemma = nm->mkAnd(exp).impNode(mkLit(oa, zero, status * 2));
       CDProof* proof = nullptr;
       if (d_data->isProofEnabled())
       {
@@ -318,6 +319,7 @@ int MonomialCheck::compareSign(
   }
   Assert(a_index < vla.size());
   Node av = vla[a_index];
+  Node zero = mkZero(av.getType());
   unsigned aexp = d_data->d_mdb.getExponent(a, av);
   // take current sign in model
   Node mvaav = d_data->d_model.computeAbstractModelValue(av);
@@ -328,8 +330,8 @@ int MonomialCheck::compareSign(
   {
     if (mvaoa.getConst<Rational>().sgn() != 0)
     {
-      Node prem = av.eqNode(d_data->d_zero);
-      Node conc = oa.eqNode(d_data->d_zero);
+      Node prem = av.eqNode(zero);
+      Node conc = oa.eqNode(zero);
       Node lemma = prem.impNode(conc);
       CDProof* proof = nullptr;
       if (d_data->isProofEnabled())
@@ -344,10 +346,10 @@ int MonomialCheck::compareSign(
   }
   if (aexp % 2 == 0)
   {
-    exp.push_back(av.eqNode(d_data->d_zero).negate());
+    exp.push_back(av.eqNode(zero).negate());
     return compareSign(oa, a, a_index + 1, status, exp);
   }
-  exp.push_back(nm->mkNode(sgn == 1 ? Kind::GT : Kind::LT, av, d_data->d_zero));
+  exp.push_back(nm->mkNode(sgn == 1 ? Kind::GT : Kind::LT, av, zero));
   return compareSign(oa, a, a_index + 1, status * sgn, exp);
 }
 
@@ -417,9 +419,10 @@ bool MonomialCheck::compareMonomial(
       if (status == 2)
       {
         // must state that all variables are non-zero
-        for (unsigned j = 0; j < vla.size(); j++)
+        Node zero = mkZero(oa.getType());
+        for (const Node& v : vla)
         {
-          exp.push_back(vla[j].eqNode(d_data->d_zero).negate());
+          exp.push_back(v.eqNode(zero).negate());
         }
       }
       NodeManager* nm = NodeManager::currentNM();
@@ -714,6 +717,7 @@ void MonomialCheck::assignOrderIds(std::vector<Node>& vars,
 }
 Node MonomialCheck::mkLit(Node a, Node b, int status, bool isAbsolute) const
 {
+  Assert(a.getType().isComparableTo(b.getType()));
   if (status == 0)
   {
     Node a_eq_b = a.eqNode(b);
@@ -736,8 +740,9 @@ Node MonomialCheck::mkLit(Node a, Node b, int status, bool isAbsolute) const
     return nm->mkNode(greater_op, a, b);
   }
   // return nm->mkNode( greater_op, mkAbs( a ), mkAbs( b ) );
-  Node a_is_nonnegative = nm->mkNode(Kind::GEQ, a, d_data->d_zero);
-  Node b_is_nonnegative = nm->mkNode(Kind::GEQ, b, d_data->d_zero);
+  Node zero = mkZero(a.getType());
+  Node a_is_nonnegative = nm->mkNode(Kind::GEQ, a, zero);
+  Node b_is_nonnegative = nm->mkNode(Kind::GEQ, b, zero);
   Node negate_a = nm->mkNode(Kind::NEG, a);
   Node negate_b = nm->mkNode(Kind::NEG, b);
   return a_is_nonnegative.iteNode(
index 4aa3a8330ebfe60b52a5e6fdeab9d31428d7dbf5..58ec5ee19234ff436c0aa0ed848dc46296bb51c2 100644 (file)
@@ -1399,6 +1399,7 @@ TrustNode TheoryArithPrivate::dioCutting()
       NodeManager* nm = NodeManager::currentNM();
       Node gt = nm->mkNode(kind::GT, p.getNode(), c.getNode());
       Node lt = nm->mkNode(kind::LT, p.getNode(), c.getNode());
+      TypeNode type = gt[0].getType();
 
       Pf pfNotLeq = d_pnm->mkAssume(leq.getNode().negate());
       Pf pfGt =
@@ -1406,10 +1407,10 @@ TrustNode TheoryArithPrivate::dioCutting()
       Pf pfNotGeq = d_pnm->mkAssume(geq.getNode().negate());
       Pf pfLt =
           d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM, {pfNotGeq}, {lt});
-      Pf pfSum = d_pnm->mkNode(PfRule::MACRO_ARITH_SCALE_SUM_UB,
-                               {pfGt, pfLt},
-                               {nm->mkConst<Rational>(CONST_RATIONAL, -1),
-                                nm->mkConst<Rational>(CONST_RATIONAL, 1)});
+      Pf pfSum = d_pnm->mkNode(
+          PfRule::MACRO_ARITH_SCALE_SUM_UB,
+          {pfGt, pfLt},
+          {nm->mkConstRealOrInt(type, -1), nm->mkConstRealOrInt(type, 1)});
       Pf pfBot = d_pnm->mkNode(
           PfRule::MACRO_SR_PRED_TRANSFORM, {pfSum}, {nm->mkConst<bool>(false)});
       std::vector<Node> assumptions = {leq.getNode().negate(),
@@ -3748,7 +3749,7 @@ DeltaRational TheoryArithPrivate::getDeltaValue(TNode term) const
 
   switch (Kind kind = term.getKind()) {
     case kind::CONST_RATIONAL:
-      return term.getConst<Rational>();
+    case kind::CONST_INTEGER: return term.getConst<Rational>();
 
     case kind::ADD:
     {  // 2+ args
@@ -4440,12 +4441,14 @@ bool TheoryArithPrivate::rowImplicationCanBeApplied(RowIndex ridx, bool rowUp, C
       {
         // We can prove this lemma from Farkas...
         std::vector<std::shared_ptr<ProofNode>> conflictPfs;
+        Node pfLit = implied->getNegation()->getProofLiteral();
+        TypeNode type = pfLit[0].getType();
         // Assume the negated getLiteral version of the implied constaint
         // then rewrite it into proof normal form.
         conflictPfs.push_back(
             d_pnm->mkNode(PfRule::MACRO_SR_PRED_TRANSFORM,
                           {d_pnm->mkAssume(implied->getLiteral().negate())},
-                          {implied->getNegation()->getProofLiteral()}));
+                          {pfLit}));
         // Add the explaination proofs.
         for (const auto constraint : explain)
         {
@@ -4459,8 +4462,8 @@ bool TheoryArithPrivate::rowImplicationCanBeApplied(RowIndex ridx, bool rowUp, C
         std::transform(coeffs->begin(),
                        coeffs->end(),
                        std::back_inserter(farkasCoefficients),
-                       [nm](const Rational& r) {
-                         return nm->mkConst<Rational>(CONST_RATIONAL, r);
+                       [nm, type](const Rational& r) {
+                         return nm->mkConstRealOrInt(type, r);
                        });
 
         // Prove bottom.
@@ -4580,7 +4583,8 @@ std::pair<bool, Node> TheoryArithPrivate::entailmentCheck(TNode lit, const Arith
   bool successful = decomposeLiteral(lit, k, primDir, lm, lp, rm, rp, dm, dp, sep);
   if(!successful) { return make_pair(false, Node::null()); }
 
-  if(dp.getKind() == CONST_RATIONAL){
+  if (dp.isConst())
+  {
     Node eval = rewrite(lit);
     Assert(eval.getKind() == kind::CONST_BOOLEAN);
     // if true, true is an acceptable explaination
@@ -4922,11 +4926,14 @@ void TheoryArithPrivate::entailmentCheckBoundLookup(std::pair<Node, DeltaRationa
   if(sgn == 0){ return; }
 
   Assert(Polynomial::isMember(tp));
-  if(tp.getKind() == CONST_RATIONAL){
+  if (tp.isConst())
+  {
     tmp.first = mkBoolNode(true);
     tmp.second = DeltaRational(tp.getConst<Rational>());
-  }else if(d_partialModel.hasArithVar(tp)){
-    Assert(tp.getKind() != CONST_RATIONAL);
+  }
+  else if (d_partialModel.hasArithVar(tp))
+  {
+    Assert(!tp.isConst());
     ArithVar v = d_partialModel.asArithVar(tp);
     Assert(v != ARITHVAR_SENTINEL);
     ConstraintP c = (sgn > 0)
index 996a93deaa253e19435c5a3bfd7f735ecc83ab64..81c75a1763e7b7837900224e9f810f7947c39f4d 100644 (file)
@@ -783,27 +783,27 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
       {
         Trace("sygus-grammar-def")
             << "  ...create auxiliary Positive Integers grammar\n";
-        // Creating type for positive integers. Notice we can't use the any
-        // constant constructor here, since it admits zero.
+        // Creating type for positive integral reals. Notice we can't use the
+        // any constant constructor here, since it admits zero.
         std::stringstream ss;
-        ss << fun << "_PosInt";
-        std::string pos_int_name = ss.str();
+        ss << fun << "_PosIReal";
+        std::string posIRealName = ss.str();
         // make unresolved type
-        TypeNode unresPosInt = mkUnresolvedType(pos_int_name, unres);
-        unres_types.push_back(unresPosInt);
-        // make data type for positive constant integers
-        sdts.push_back(SygusDatatypeGenerator(pos_int_name));
+        TypeNode unresPosIReal = mkUnresolvedType(posIRealName, unres);
+        unres_types.push_back(unresPosIReal);
+        // make data type for positive constant integral reals
+        sdts.push_back(SygusDatatypeGenerator(posIRealName));
         /* Add operator 1 */
-        Trace("sygus-grammar-def") << "\t...add for 1 to Pos_Int\n";
+        Trace("sygus-grammar-def") << "\t...add for 1.0 to PosIReal\n";
         std::vector<TypeNode> cargsEmpty;
         sdts.back().addConstructor(
-            nm->mkConstInt(Rational(1)), "1", cargsEmpty);
+            nm->mkConstReal(Rational(1)), "1", cargsEmpty);
         /* Add operator ADD */
         Kind kind = ADD;
-        Trace("sygus-grammar-def") << "\t...add for ADD to Pos_Int\n";
+        Trace("sygus-grammar-def") << "\t...add for ADD to PosIReal\n";
         std::vector<TypeNode> cargsPlus;
-        cargsPlus.push_back(unresPosInt);
-        cargsPlus.push_back(unresPosInt);
+        cargsPlus.push_back(unresPosIReal);
+        cargsPlus.push_back(unresPosIReal);
         sdts.back().addConstructor(kind, cargsPlus);
         sdts.back().d_sdt.initializeDatatype(types[i], bvl, true, true);
         Trace("sygus-grammar-def")
@@ -813,7 +813,7 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
         Trace("sygus-grammar-def") << "\t...add for " << kind << std::endl;
         std::vector<TypeNode> cargsDiv;
         cargsDiv.push_back(unres_t);
-        cargsDiv.push_back(unresPosInt);
+        cargsDiv.push_back(unresPosIReal);
         sdts[i].addConstructor(kind, cargsDiv);
       }
     }
index 0bf92baa299cda15aac119e3022ef2c0b81ef131..22b852b92c61aae7744680fb91565e0570167dac 100644 (file)
@@ -219,6 +219,10 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId,
           TheoryId newTheory = theoryOf(newNode);
           rewriteStackTop.d_node = newNode;
           rewriteStackTop.d_theoryId = newTheory;
+          Assert(
+              newNode.getType().isSubtypeOf(rewriteStackTop.d_node.getType()))
+              << "Pre-rewriting " << rewriteStackTop.d_node
+              << " does not preserve type";
           // In the pre-rewrite, if changing theories, we just call the other
           // theories pre-rewrite. If the kind of the node was changed, then we
           // pre-rewrite again.
@@ -306,6 +310,9 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId,
         // We continue with the response we got
         TNode newNode = response.d_node;
         TheoryId newTheoryId = theoryOf(newNode);
+        Assert(newNode.getType().isSubtypeOf(rewriteStackTop.d_node.getType()))
+            << "Post-rewriting " << rewriteStackTop.d_node
+            << " does not preserve type";
         if (newTheoryId != rewriteStackTop.getTheoryId()
             || response.d_status == REWRITE_AGAIN_FULL)
         {