More preparation for strict type rules (#8733)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 7 May 2022 02:22:12 +0000 (21:22 -0500)
committerGitHub <noreply@github.com>
Sat, 7 May 2022 02:22:12 +0000 (02:22 +0000)
This is work towards making equalities and substitutions between terms of equal types.

20 files changed:
src/preprocessing/passes/unconstrained_simplifier.cpp
src/theory/arith/arith_msum.cpp
src/theory/arith/arith_utilities.cpp
src/theory/arith/arith_utilities.h
src/theory/arith/linear/theory_arith_private.cpp
src/theory/arith/nl/ext/monomial_bounds_check.cpp
src/theory/arith/nl/ext/monomial_check.cpp
src/theory/arith/nl/nl_model.cpp
src/theory/arith/nl/transcendental/exponential_solver.cpp
src/theory/arith/nl/transcendental/proof_checker.cpp
src/theory/arith/operator_elim.cpp
src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp
src/theory/quantifiers/sygus_sampler.cpp
src/theory/substitutions.cpp
test/api/cpp/reset_assertions.cpp
test/api/python/reset_assertions.py
test/regress/cli/CMakeLists.txt
test/regress/cli/regress0/datatypes/dd.pair-real-bool-const-conf.smt2 [new file with mode: 0644]
test/regress/cli/regress1/abduction/abd-real-const.smt2
test/regress/cli/regress1/ho/issue4758.smt2

index 4a9ba46a267506fd023bc5b9a3da35e724409f81..3f2a9d9ed9bb0665c7d83bbbf281c7afc4e2649c 100644 (file)
@@ -530,7 +530,8 @@ void UnconstrainedSimplifier::processUnconstrained()
             else
             {
               // TODO(#2377): could build ITE here
-              Node test = other.eqNode(nm->mkConstReal(Rational(0)));
+              Node test = other.eqNode(
+                  nm->mkConstRealOrInt(other.getType(), Rational(0)));
               if (rewrite(test) != nm->mkConst<bool>(false))
               {
                 break;
index 6d34996112e5817eaf85728e4ecea70fdab38f72..c3b18ac6b95095ff7db80a40ad17ab179de620ab 100644 (file)
@@ -219,12 +219,13 @@ int ArithMSum::isolate(
   int ires = isolate(v, msum, veq_c, val, k);
   if (ires != 0)
   {
+    NodeManager* nm = NodeManager::currentNM();
     Node vc = v;
     if (!veq_c.isNull())
     {
       if (doCoeff)
       {
-        vc = NodeManager::currentNM()->mkNode(MULT, veq_c, vc);
+        vc = nm->mkNode(MULT, veq_c, vc);
       }
       else
       {
@@ -232,8 +233,17 @@ int ArithMSum::isolate(
       }
     }
     bool inOrder = ires == 1;
-    veq = NodeManager::currentNM()->mkNode(
-        k, inOrder ? vc : val, inOrder ? val : vc);
+    // ensure type is correct for equality
+    if (k == EQUAL)
+    {
+      if (!vc.getType().isInteger() && val.getType().isInteger())
+      {
+        val = nm->mkNode(TO_REAL, val);
+      }
+      // note that conversely this utility will never use a real value as
+      // the solution for an integer, thus the types should match now
+    }
+    veq = nm->mkNode(k, inOrder ? vc : val, inOrder ? val : vc);
   }
   return ires;
 }
index 00cbae056cca87c3bd6a1a0c7e4ae74e10b19f42..76ca33ce1717b23d493ff8fcac0407f97a9c55d1 100644 (file)
@@ -351,7 +351,7 @@ Node multConstants(const Node& c1, const Node& c2)
       tn, Rational(c1.getConst<Rational>() * c2.getConst<Rational>()));
 }
 
-Node mkEquality(Node a, Node b)
+Node mkEquality(const Node& a, const Node& b)
 {
   NodeManager* nm = NodeManager::currentNM();
   Assert(a.getType().isRealOrInt());
@@ -366,6 +366,23 @@ Node mkEquality(Node a, Node b)
   return nm->mkNode(EQUAL, diff, mkZero(diff.getType()));
 }
 
+std::pair<Node,Node> mkSameType(const Node& a, const Node& b)
+{
+  TypeNode at = a.getType();
+  TypeNode bt = b.getType();
+  if (at == bt)
+  {
+    return {a, b};
+  }
+  NodeManager* nm = NodeManager::currentNM();
+  if (at.isInteger() && bt.isReal())
+  {
+    return {nm->mkNode(kind::TO_REAL, a), b};
+  }
+  Assert(at.isReal() && bt.isInteger());
+  return {a, nm->mkNode(kind::TO_REAL, b)};
+}
+
 }  // namespace arith
 }  // namespace theory
 }  // namespace cvc5::internal
index 584e1125ad5303e3ddc78a9b8bb2043fa522e603..92151d688370bb55ab30383ded4f481c3dffcbac 100644 (file)
@@ -343,7 +343,14 @@ Node multConstants(const Node& c1, const Node& c2);
  * types, where zero has the same type as (- a b).
  * Use this utility to ensure an equality is properly typed.
  */
-Node mkEquality(Node a, Node b);
+Node mkEquality(const Node& a, const Node& b);
+
+/**
+ * Ensures that the returned pair has equal type, where a and b have
+ * real or integer type. We add TO_REAL if not.
+ */
+std::pair<Node,Node> mkSameType(const Node& a, const Node& b);
+
 
 }  // namespace arith
 }  // namespace theory
index aafeb43cbc719ff6482568cf17ec03fe0933fe8e..29bec9c2820ddc6caf945929ab4d06538500855b 100644 (file)
@@ -989,6 +989,11 @@ Theory::PPAssertStatus TheoryArithPrivate::ppAssert(
         // substitution is integral
         Trace("simplify") << "TheoryArithPrivate::solve(): substitution "
                           << minVar << " |-> " << elim << endl;
+        if (elim.getType().isInteger() && !minVar.getType().isInteger())
+        {
+          elim = NodeManager::currentNM()->mkNode(kind::TO_REAL, elim);
+        }
+        Assert(elim.getType() == minVar.getType());
         outSubstitutions.addSubstitutionSolved(minVar, elim, tin);
         return Theory::PP_ASSERT_STATUS_SOLVED;
       }
index 2d9ef5b8d35f315314dbb3d7c6b637f5cef2af2f..5b79c59d95ee6283ab6dfe5725ebe4afda3c0003 100644 (file)
@@ -294,19 +294,28 @@ void MonomialBoundsCheck::checkBounds(const std::vector<Node>& asserts,
                 << "     ...coefficient " << mult << " is zero." << std::endl;
             continue;
           }
+          Node lhsTgt = t;
+          Node rhsTgt = rhs;
+          // if we are making an equality below, we require making it
+          // well-typed so that lhs/rhs have the same type. We use the
+          // mkSameType utility to do this
+          if (type == kind::EQUAL)
+          {
+            std::tie(lhsTgt, rhsTgt) = mkSameType(lhsTgt, rhsTgt);
+          }
           Trace("nl-ext-bound-debug")
               << "  from " << x << " * " << mult << " = " << y << " and " << t
               << " " << type << " " << rhs << ", infer : " << std::endl;
           Kind infer_type = mmv_sign == -1 ? reverseRelationKind(type) : type;
-          Node infer_lhs = nm->mkNode(Kind::MULT, mult, t);
-          Node infer_rhs = nm->mkNode(Kind::MULT, mult, rhs);
+          Node infer_lhs = nm->mkNode(Kind::MULT, mult, lhsTgt);
+          Node infer_rhs = nm->mkNode(Kind::MULT, mult, rhsTgt);
           Node infer = nm->mkNode(infer_type, infer_lhs, infer_rhs);
           Trace("nl-ext-bound-debug") << "     " << infer << std::endl;
           Node infer_mv =
               d_data->d_model.computeAbstractModelValue(rewrite(infer));
           Trace("nl-ext-bound-debug")
               << "       ...infer model value is " << infer_mv << std::endl;
-          if (infer_mv == d_data->d_false)
+          if (infer_mv.isConst() && !infer_mv.getConst<bool>())
           {
             Node exp = nm->mkNode(
                 Kind::AND,
@@ -324,7 +333,7 @@ void MonomialBoundsCheck::checkBounds(const std::vector<Node>& asserts,
             if (d_data->isProofEnabled())
             {
               proof = d_data->getProof();
-              Node simpleeq = nm->mkNode(type, t, rhs);
+              Node simpleeq = nm->mkNode(type, lhsTgt, rhsTgt);
               // this is iblem, but uses (type t rhs) instead of the original
               // variant (which is identical under rewriting)
               // we first infer the "clean" version of the lemma and then
index 8f2b23f44ea5f802ed278ee730d8985230c46e2c..479b77130562471a3a5bfe0871afa1be66b08264 100644 (file)
@@ -331,7 +331,7 @@ int MonomialCheck::compareSign(
     if (mvaoa.getConst<Rational>().sgn() != 0)
     {
       Node prem = av.eqNode(zero);
-      Node conc = oa.eqNode(zero);
+      Node conc = oa.eqNode(mkZero(oa.getType()));
       Node lemma = prem.impNode(conc);
       CDProof* proof = nullptr;
       if (d_data->isProofEnabled())
@@ -420,10 +420,9 @@ bool MonomialCheck::compareMonomial(
       if (status == 2)
       {
         // must state that all variables are non-zero
-        Node zero = mkZero(oa.getType());
         for (const Node& v : vla)
         {
-          exp.push_back(v.eqNode(zero).negate());
+          exp.push_back(v.eqNode(mkZero(v.getType())).negate());
         }
       }
       Node clem = nm->mkNode(
index b7c09244cf63a2d6aa49587956fd9d1bb37e1839..4add4e798a11f695e4950f4ce14f66a858ba75ee 100644 (file)
@@ -1013,6 +1013,7 @@ void NlModel::printModelValue(const char* c, Node n, unsigned prec) const
 
 void NlModel::getModelValueRepair(std::map<Node, Node>& arithModel)
 {
+  NodeManager* nm = NodeManager::currentNM();
   Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl;
   // If we extended the model with entries x -> 0 for unconstrained values,
   // we first update the map to the extended one.
@@ -1037,9 +1038,11 @@ void NlModel::getModelValueRepair(std::map<Node, Node>& arithModel)
     }
     else
     {
-      // overwrite
-      arithModel[v] = l;
-      Trace("nl-model") << v << " exact approximation is " << l << std::endl;
+      // overwrite, ensure the type is correct
+      Assert(l.isConst());
+      Node ll = nm->mkConstRealOrInt(v.getType(), l.getConst<Rational>());
+      arithModel[v] = ll;
+      Trace("nl-model") << v << " exact approximation is " << ll << std::endl;
     }
   }
   // Also record the exact values we used. An exact value can be seen as a
@@ -1048,10 +1051,18 @@ void NlModel::getModelValueRepair(std::map<Node, Node>& arithModel)
   // is eliminated.
   for (size_t i = 0; i < d_substitutions.size(); ++i)
   {
-    // overwrite
-    arithModel[d_substitutions.d_vars[i]] = d_substitutions.d_subs[i];
-    Trace("nl-model") << d_substitutions.d_vars[i] << " solved is "
-                      << d_substitutions.d_subs[i] << std::endl;
+    // overwrite, ensure the type is correct
+    Node v = d_substitutions.d_vars[i];
+    Node s = d_substitutions.d_subs[i];
+    Node ss = s;
+    // If its a rational constant, ensure it has the proper type now. It
+    // also may be a RAN, in which case v should be a real.
+    if (s.isConst())
+    {
+      ss = nm->mkConstRealOrInt(v.getType(), s.getConst<Rational>());
+    }
+    arithModel[v] = ss;
+    Trace("nl-model") << v << " solved is " << ss << std::endl;
   }
 
   // multiplication terms should not be given values; their values are
index e80b12641ff0c9efad2faf651da172b37e6100f6..0a8d71d46638b29a5f17030514c6f79023b7ced2 100644 (file)
@@ -98,7 +98,7 @@ void ExponentialSolver::checkInitialRefine()
         }
         {
           // must use real one/zero in equalities
-          Node rzero = nm->mkConstReal(Rational(0));
+          Node rzero = mkZero(t[0].getType());
           Node rone = nm->mkConstReal(Rational(1));
           // exp at zero: (t = 0.0) <=> (exp(t) = 1.0)
           Node lem =
index 5f6c166b8474f3cbd934531d1771caafda8413b3..a9c0807960b84beb44ee16c6b082b390185b14b3 100644 (file)
@@ -136,7 +136,7 @@ Node TranscendentalProofRuleChecker::checkInternal(
     Assert(children.empty());
     Assert(args.size() == 1);
     Node e = nm->mkNode(Kind::EXPONENTIAL, args[0]);
-    Node rzero = nm->mkConstReal(Rational(0));
+    Node rzero = nm->mkConstRealOrInt(args[0].getType(), Rational(0));
     Node rone = nm->mkConstReal(Rational(1));
     return nm->mkNode(EQUAL, args[0].eqNode(rzero), e.eqNode(rone));
   }
index 78a6a3899005ff04c3e7470d1a91f1d4485a2512..7981c4fbaf1c93036440472e88b79dcf55be5364 100644 (file)
@@ -227,7 +227,7 @@ Node OperatorElim::eliminateOperators(Node node,
           rw, "nonlinearDiv", "the result of a non-linear div term");
       Node lem = nm->mkNode(IMPLIES,
                             den.eqNode(mkZero(den.getType())).negate(),
-                            nm->mkNode(MULT, den, v).eqNode(num));
+                            mkEquality(nm->mkNode(MULT, den, v), num));
       lems.push_back(mkSkolemLemma(lem, v));
       return v;
       break;
@@ -440,7 +440,15 @@ Node OperatorElim::getArithSkolemApp(Node n, SkolemFunId id)
   Node skolem = getArithSkolem(id);
   if (usePartialFunction(id))
   {
-    skolem = NodeManager::currentNM()->mkNode(APPLY_UF, skolem, n);
+    NodeManager* nm = NodeManager::currentNM();
+    Assert(skolem.getType().isFunction()
+           && skolem.getType().getNumChildren() == 2);
+    TypeNode argType = skolem.getType()[0];
+    if (!argType.isInteger() && n.getType().isInteger())
+    {
+      n = nm->mkNode(TO_REAL, n);
+    }
+    skolem = nm->mkNode(APPLY_UF, skolem, n);
   }
   return skolem;
 }
index 36e4f5a7cd32af824a78a19434ebd6b6a864a817..47c6d3548c5e3ed4e85956551f281ef42559b302 100644 (file)
@@ -46,6 +46,7 @@ void ArithInstantiator::reset(CegInstantiator* ci,
                               Node pv,
                               CegInstEffort effort)
 {
+  Assert(pv.getType() == d_type);
   d_vts_sym[0] = d_vtc->getVtsInfinity(d_type, false, false);
   d_vts_sym[1] = d_vtc->getVtsDelta(false, false);
   for (unsigned i = 0; i < 2; i++)
@@ -905,13 +906,12 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci,
       int ires_use =
           (msum[pv].isNull() || msum[pv].getConst<Rational>().sgn() == 1) ? 1
                                                                           : -1;
-      val = nm->mkNode(ires_use == -1 ? ADD : SUB,
+      val = nm->mkNode(TO_INTEGER, nm->mkNode(ires_use == -1 ? ADD : SUB,
                        nm->mkNode(ires_use == -1 ? SUB : ADD, val, realPart),
-                       nm->mkNode(TO_INTEGER, realPart));
+                       nm->mkNode(TO_INTEGER, realPart)));
       Trace("cegqi-arith-debug")
           << "result (pre-rewrite) : " << val << std::endl;
       val = rewrite(val);
-      val = val.getKind() == TO_REAL ? val[0] : val;
       // could round up for upper bounds here
       Trace("cegqi-arith-debug") << "result : " << val << std::endl;
       Assert(val.getType().isInteger());
@@ -923,6 +923,11 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci,
   }
   vts_coeff_inf = vts_coeff[0];
   vts_coeff_delta = vts_coeff[1];
+  if (!pv.getType().isInteger() && val.getType().isInteger())
+  {
+    val = nm->mkNode(TO_REAL, val);
+  }
+  Assert(pv.getType() == val.getType());
   Trace("cegqi-arith-debug")
       << "Return " << veq_c << " * " << pv << " " << atom.getKind() << " "
       << val << ", vts = (" << vts_coeff_inf << ", " << vts_coeff_delta << ")"
index f767ca352c13ff58c6dd48211f950cd405a6648d..6c48f8a14f732947fa97684bf68f068d0dca6cae 100644 (file)
@@ -614,6 +614,7 @@ Node SygusSampler::getRandomValue(TypeNode tn)
       }
       ret = d_env.getRewriter()->rewrite(ret);
       Assert(ret.isConst());
+      Assert(ret.getType()==tn);
       return ret;
     }
   }
@@ -627,12 +628,9 @@ Node SygusSampler::getRandomValue(TypeNode tn)
       Rational rr = r.getConst<Rational>();
       if (rr.sgn() == 0)
       {
-        return s;
-      }
-      else
-      {
-        return nm->mkConstReal(sr / rr);
+        return nm->mkConstReal(s.getConst<Rational>());
       }
+      return nm->mkConstReal(sr / rr);
     }
   }
   // default: use type enumerator
index 743815957c6c36268dac9efbcaf8dbd14e72deb9..93b187cd25cbd31feaa7292926117906ae428f1c 100644 (file)
@@ -178,6 +178,8 @@ Node SubstitutionMap::internalSubstitute(TNode t,
 
 void SubstitutionMap::addSubstitution(TNode x, TNode t, bool invalidateCache)
 {
+  // don't check type equal here, since this utility may be used in conversions
+  // that change the types of terms
   Trace("substitution") << "SubstitutionMap::addSubstitution(" << x << ", " << t << ")" << endl;
   Assert(d_substitutions.find(x) == d_substitutions.end());
 
index 735249d5de4393ea41a18f2188b7f7e650797e74..5c86268a349b249c7dc69aa22ff7be8ca5acec4f 100644 (file)
@@ -32,7 +32,7 @@ int main()
 
   Sort real = slv.getRealSort();
   Term x = slv.mkConst(real, "x");
-  Term four = slv.mkInteger(4);
+  Term four = slv.mkReal(4);
   Term xEqFour = slv.mkTerm(Kind::EQUAL, {x, four});
   slv.assertFormula(xEqFour);
   std::cout << slv.checkSat() << std::endl;
@@ -43,7 +43,8 @@ int main()
   Sort indexType = slv.getIntegerSort();
   Sort arrayType = slv.mkArraySort(indexType, elementType);
   Term array = slv.mkConst(arrayType, "array");
-  Term arrayAtFour = slv.mkTerm(Kind::SELECT, {array, four});
+  Term fourInt = slv.mkInteger(4);
+  Term arrayAtFour = slv.mkTerm(Kind::SELECT, {array, fourInt});
   Term ten = slv.mkInteger(10);
   Term arrayAtFour_eq_ten = slv.mkTerm(Kind::EQUAL, {arrayAtFour, ten});
   slv.assertFormula(arrayAtFour_eq_ten);
index 7946f49b7765f82ee816f5373c88d335bf92c294..dca6efb3ab91d72d4b4f63e1cbf65b29d5b85419 100644 (file)
@@ -26,7 +26,7 @@ slv.setOption("incremental", "true")
 
 real = slv.getRealSort()
 x = slv.mkConst(real, "x")
-four = slv.mkInteger(4)
+four = slv.mkReal(4)
 xEqFour = slv.mkTerm(Kind.EQUAL, x, four)
 slv.assertFormula(xEqFour)
 print(slv.checkSat())
@@ -37,7 +37,8 @@ elementType = slv.getIntegerSort()
 indexType = slv.getIntegerSort()
 arrayType = slv.mkArraySort(indexType, elementType)
 array = slv.mkConst(arrayType, "array")
-arrayAtFour = slv.mkTerm(Kind.SELECT, array, four)
+fourInt = slv.mkInteger(4)
+arrayAtFour = slv.mkTerm(Kind.SELECT, array, fourInt)
 ten = slv.mkInteger(10)
 arrayAtFour_eq_ten = slv.mkTerm(Kind.EQUAL, arrayAtFour, ten)
 slv.assertFormula(arrayAtFour_eq_ten)
index 4a9489cf6854cb05da89988c5a0831a30923fdcd..7901da1c03fd9cb52a5469bbabc73f68e59d6281 100644 (file)
@@ -523,6 +523,7 @@ set(regress_0_tests
   regress0/datatypes/datatype2.cvc.smt2
   regress0/datatypes/datatype3.cvc.smt2
   regress0/datatypes/datatype4.cvc.smt2
+  regress0/datatypes/dd.pair-real-bool-const-conf.smt2
   regress0/datatypes/dt-2.6.smt2
   regress0/datatypes/dt-different-params.smt2
   regress0/datatypes/dt-match-pat-param-2.6.smt2
diff --git a/test/regress/cli/regress0/datatypes/dd.pair-real-bool-const-conf.smt2 b/test/regress/cli/regress0/datatypes/dd.pair-real-bool-const-conf.smt2
new file mode 100644 (file)
index 0000000..b181cbf
--- /dev/null
@@ -0,0 +1,6 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-datatypes ((P 0)) (((k (f Real)))))
+(declare-const r P)
+(assert (= 0.0 (f r)))
+(check-sat)
index 258d80a88bd4d15fcbd28344ac0e11ec6185d863..32549a831a346edcb1e5e741752864af0cc7e421 100644 (file)
@@ -5,5 +5,5 @@
 (declare-const x Real)
 (declare-const y Real)
 (declare-const z Real)
-(assert (and (>= x 0) (< y 7)))
-(get-abduct A (>= y 5))
+(assert (and (>= x 0.0) (< y 7.0)))
+(get-abduct A (>= y 5.0))
index dab284c1121e34016b866a2f6ecfe392218affd9..c0c5cdd01523d8f64c5638b0ff55f0b342f718e8 100644 (file)
@@ -2,5 +2,5 @@
 (set-info :status sat)
 (declare-fun a () Real)
 (declare-fun b (Real Real) Real)
-(assert (> (b a 0) (b (- a) 1)))
+(assert (> (b a 0.0) (b (- a) 1.0)))
 (check-sat)