From 42e503ba7d13c054f0b755a7fbda76abd3506f4b Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 6 May 2022 21:22:12 -0500 Subject: [PATCH] More preparation for strict type rules (#8733) This is work towards making equalities and substitutions between terms of equal types. --- .../passes/unconstrained_simplifier.cpp | 3 ++- src/theory/arith/arith_msum.cpp | 16 +++++++++--- src/theory/arith/arith_utilities.cpp | 19 +++++++++++++- src/theory/arith/arith_utilities.h | 9 ++++++- .../arith/linear/theory_arith_private.cpp | 5 ++++ .../arith/nl/ext/monomial_bounds_check.cpp | 17 ++++++++++--- src/theory/arith/nl/ext/monomial_check.cpp | 5 ++-- src/theory/arith/nl/nl_model.cpp | 25 +++++++++++++------ .../nl/transcendental/exponential_solver.cpp | 2 +- .../arith/nl/transcendental/proof_checker.cpp | 2 +- src/theory/arith/operator_elim.cpp | 12 +++++++-- .../cegqi/ceg_arith_instantiator.cpp | 11 +++++--- src/theory/quantifiers/sygus_sampler.cpp | 8 +++--- src/theory/substitutions.cpp | 2 ++ test/api/cpp/reset_assertions.cpp | 5 ++-- test/api/python/reset_assertions.py | 5 ++-- test/regress/cli/CMakeLists.txt | 1 + .../dd.pair-real-bool-const-conf.smt2 | 6 +++++ .../regress1/abduction/abd-real-const.smt2 | 4 +-- test/regress/cli/regress1/ho/issue4758.smt2 | 2 +- 20 files changed, 120 insertions(+), 39 deletions(-) create mode 100644 test/regress/cli/regress0/datatypes/dd.pair-real-bool-const-conf.smt2 diff --git a/src/preprocessing/passes/unconstrained_simplifier.cpp b/src/preprocessing/passes/unconstrained_simplifier.cpp index 4a9ba46a2..3f2a9d9ed 100644 --- a/src/preprocessing/passes/unconstrained_simplifier.cpp +++ b/src/preprocessing/passes/unconstrained_simplifier.cpp @@ -530,7 +530,8 @@ void UnconstrainedSimplifier::processUnconstrained() else { // TODO(#2377): could build ITE here - Node test = other.eqNode(nm->mkConstReal(Rational(0))); + Node test = other.eqNode( + nm->mkConstRealOrInt(other.getType(), Rational(0))); if (rewrite(test) != nm->mkConst(false)) { break; diff --git a/src/theory/arith/arith_msum.cpp b/src/theory/arith/arith_msum.cpp index 6d3499611..c3b18ac6b 100644 --- a/src/theory/arith/arith_msum.cpp +++ b/src/theory/arith/arith_msum.cpp @@ -219,12 +219,13 @@ int ArithMSum::isolate( int ires = isolate(v, msum, veq_c, val, k); if (ires != 0) { + NodeManager* nm = NodeManager::currentNM(); Node vc = v; if (!veq_c.isNull()) { if (doCoeff) { - vc = NodeManager::currentNM()->mkNode(MULT, veq_c, vc); + vc = nm->mkNode(MULT, veq_c, vc); } else { @@ -232,8 +233,17 @@ int ArithMSum::isolate( } } bool inOrder = ires == 1; - veq = NodeManager::currentNM()->mkNode( - k, inOrder ? vc : val, inOrder ? val : vc); + // ensure type is correct for equality + if (k == EQUAL) + { + if (!vc.getType().isInteger() && val.getType().isInteger()) + { + val = nm->mkNode(TO_REAL, val); + } + // note that conversely this utility will never use a real value as + // the solution for an integer, thus the types should match now + } + veq = nm->mkNode(k, inOrder ? vc : val, inOrder ? val : vc); } return ires; } diff --git a/src/theory/arith/arith_utilities.cpp b/src/theory/arith/arith_utilities.cpp index 00cbae056..76ca33ce1 100644 --- a/src/theory/arith/arith_utilities.cpp +++ b/src/theory/arith/arith_utilities.cpp @@ -351,7 +351,7 @@ Node multConstants(const Node& c1, const Node& c2) tn, Rational(c1.getConst() * c2.getConst())); } -Node mkEquality(Node a, Node b) +Node mkEquality(const Node& a, const Node& b) { NodeManager* nm = NodeManager::currentNM(); Assert(a.getType().isRealOrInt()); @@ -366,6 +366,23 @@ Node mkEquality(Node a, Node b) return nm->mkNode(EQUAL, diff, mkZero(diff.getType())); } +std::pair mkSameType(const Node& a, const Node& b) +{ + TypeNode at = a.getType(); + TypeNode bt = b.getType(); + if (at == bt) + { + return {a, b}; + } + NodeManager* nm = NodeManager::currentNM(); + if (at.isInteger() && bt.isReal()) + { + return {nm->mkNode(kind::TO_REAL, a), b}; + } + Assert(at.isReal() && bt.isInteger()); + return {a, nm->mkNode(kind::TO_REAL, b)}; +} + } // namespace arith } // namespace theory } // namespace cvc5::internal diff --git a/src/theory/arith/arith_utilities.h b/src/theory/arith/arith_utilities.h index 584e1125a..92151d688 100644 --- a/src/theory/arith/arith_utilities.h +++ b/src/theory/arith/arith_utilities.h @@ -343,7 +343,14 @@ Node multConstants(const Node& c1, const Node& c2); * types, where zero has the same type as (- a b). * Use this utility to ensure an equality is properly typed. */ -Node mkEquality(Node a, Node b); +Node mkEquality(const Node& a, const Node& b); + +/** + * Ensures that the returned pair has equal type, where a and b have + * real or integer type. We add TO_REAL if not. + */ +std::pair mkSameType(const Node& a, const Node& b); + } // namespace arith } // namespace theory diff --git a/src/theory/arith/linear/theory_arith_private.cpp b/src/theory/arith/linear/theory_arith_private.cpp index aafeb43cb..29bec9c28 100644 --- a/src/theory/arith/linear/theory_arith_private.cpp +++ b/src/theory/arith/linear/theory_arith_private.cpp @@ -989,6 +989,11 @@ Theory::PPAssertStatus TheoryArithPrivate::ppAssert( // substitution is integral Trace("simplify") << "TheoryArithPrivate::solve(): substitution " << minVar << " |-> " << elim << endl; + if (elim.getType().isInteger() && !minVar.getType().isInteger()) + { + elim = NodeManager::currentNM()->mkNode(kind::TO_REAL, elim); + } + Assert(elim.getType() == minVar.getType()); outSubstitutions.addSubstitutionSolved(minVar, elim, tin); return Theory::PP_ASSERT_STATUS_SOLVED; } diff --git a/src/theory/arith/nl/ext/monomial_bounds_check.cpp b/src/theory/arith/nl/ext/monomial_bounds_check.cpp index 2d9ef5b8d..5b79c59d9 100644 --- a/src/theory/arith/nl/ext/monomial_bounds_check.cpp +++ b/src/theory/arith/nl/ext/monomial_bounds_check.cpp @@ -294,19 +294,28 @@ void MonomialBoundsCheck::checkBounds(const std::vector& asserts, << " ...coefficient " << mult << " is zero." << std::endl; continue; } + Node lhsTgt = t; + Node rhsTgt = rhs; + // if we are making an equality below, we require making it + // well-typed so that lhs/rhs have the same type. We use the + // mkSameType utility to do this + if (type == kind::EQUAL) + { + std::tie(lhsTgt, rhsTgt) = mkSameType(lhsTgt, rhsTgt); + } Trace("nl-ext-bound-debug") << " from " << x << " * " << mult << " = " << y << " and " << t << " " << type << " " << rhs << ", infer : " << std::endl; Kind infer_type = mmv_sign == -1 ? reverseRelationKind(type) : type; - Node infer_lhs = nm->mkNode(Kind::MULT, mult, t); - Node infer_rhs = nm->mkNode(Kind::MULT, mult, rhs); + Node infer_lhs = nm->mkNode(Kind::MULT, mult, lhsTgt); + Node infer_rhs = nm->mkNode(Kind::MULT, mult, rhsTgt); Node infer = nm->mkNode(infer_type, infer_lhs, infer_rhs); Trace("nl-ext-bound-debug") << " " << infer << std::endl; Node infer_mv = d_data->d_model.computeAbstractModelValue(rewrite(infer)); Trace("nl-ext-bound-debug") << " ...infer model value is " << infer_mv << std::endl; - if (infer_mv == d_data->d_false) + if (infer_mv.isConst() && !infer_mv.getConst()) { Node exp = nm->mkNode( Kind::AND, @@ -324,7 +333,7 @@ void MonomialBoundsCheck::checkBounds(const std::vector& asserts, if (d_data->isProofEnabled()) { proof = d_data->getProof(); - Node simpleeq = nm->mkNode(type, t, rhs); + Node simpleeq = nm->mkNode(type, lhsTgt, rhsTgt); // this is iblem, but uses (type t rhs) instead of the original // variant (which is identical under rewriting) // we first infer the "clean" version of the lemma and then diff --git a/src/theory/arith/nl/ext/monomial_check.cpp b/src/theory/arith/nl/ext/monomial_check.cpp index 8f2b23f44..479b77130 100644 --- a/src/theory/arith/nl/ext/monomial_check.cpp +++ b/src/theory/arith/nl/ext/monomial_check.cpp @@ -331,7 +331,7 @@ int MonomialCheck::compareSign( if (mvaoa.getConst().sgn() != 0) { Node prem = av.eqNode(zero); - Node conc = oa.eqNode(zero); + Node conc = oa.eqNode(mkZero(oa.getType())); Node lemma = prem.impNode(conc); CDProof* proof = nullptr; if (d_data->isProofEnabled()) @@ -420,10 +420,9 @@ bool MonomialCheck::compareMonomial( if (status == 2) { // must state that all variables are non-zero - Node zero = mkZero(oa.getType()); for (const Node& v : vla) { - exp.push_back(v.eqNode(zero).negate()); + exp.push_back(v.eqNode(mkZero(v.getType())).negate()); } } Node clem = nm->mkNode( diff --git a/src/theory/arith/nl/nl_model.cpp b/src/theory/arith/nl/nl_model.cpp index b7c09244c..4add4e798 100644 --- a/src/theory/arith/nl/nl_model.cpp +++ b/src/theory/arith/nl/nl_model.cpp @@ -1013,6 +1013,7 @@ void NlModel::printModelValue(const char* c, Node n, unsigned prec) const void NlModel::getModelValueRepair(std::map& arithModel) { + NodeManager* nm = NodeManager::currentNM(); Trace("nl-model") << "NlModel::getModelValueRepair:" << std::endl; // If we extended the model with entries x -> 0 for unconstrained values, // we first update the map to the extended one. @@ -1037,9 +1038,11 @@ void NlModel::getModelValueRepair(std::map& arithModel) } else { - // overwrite - arithModel[v] = l; - Trace("nl-model") << v << " exact approximation is " << l << std::endl; + // overwrite, ensure the type is correct + Assert(l.isConst()); + Node ll = nm->mkConstRealOrInt(v.getType(), l.getConst()); + arithModel[v] = ll; + Trace("nl-model") << v << " exact approximation is " << ll << std::endl; } } // Also record the exact values we used. An exact value can be seen as a @@ -1048,10 +1051,18 @@ void NlModel::getModelValueRepair(std::map& arithModel) // is eliminated. for (size_t i = 0; i < d_substitutions.size(); ++i) { - // overwrite - arithModel[d_substitutions.d_vars[i]] = d_substitutions.d_subs[i]; - Trace("nl-model") << d_substitutions.d_vars[i] << " solved is " - << d_substitutions.d_subs[i] << std::endl; + // overwrite, ensure the type is correct + Node v = d_substitutions.d_vars[i]; + Node s = d_substitutions.d_subs[i]; + Node ss = s; + // If its a rational constant, ensure it has the proper type now. It + // also may be a RAN, in which case v should be a real. + if (s.isConst()) + { + ss = nm->mkConstRealOrInt(v.getType(), s.getConst()); + } + arithModel[v] = ss; + Trace("nl-model") << v << " solved is " << ss << std::endl; } // multiplication terms should not be given values; their values are diff --git a/src/theory/arith/nl/transcendental/exponential_solver.cpp b/src/theory/arith/nl/transcendental/exponential_solver.cpp index e80b12641..0a8d71d46 100644 --- a/src/theory/arith/nl/transcendental/exponential_solver.cpp +++ b/src/theory/arith/nl/transcendental/exponential_solver.cpp @@ -98,7 +98,7 @@ void ExponentialSolver::checkInitialRefine() } { // must use real one/zero in equalities - Node rzero = nm->mkConstReal(Rational(0)); + Node rzero = mkZero(t[0].getType()); Node rone = nm->mkConstReal(Rational(1)); // exp at zero: (t = 0.0) <=> (exp(t) = 1.0) Node lem = diff --git a/src/theory/arith/nl/transcendental/proof_checker.cpp b/src/theory/arith/nl/transcendental/proof_checker.cpp index 5f6c166b8..a9c080796 100644 --- a/src/theory/arith/nl/transcendental/proof_checker.cpp +++ b/src/theory/arith/nl/transcendental/proof_checker.cpp @@ -136,7 +136,7 @@ Node TranscendentalProofRuleChecker::checkInternal( Assert(children.empty()); Assert(args.size() == 1); Node e = nm->mkNode(Kind::EXPONENTIAL, args[0]); - Node rzero = nm->mkConstReal(Rational(0)); + Node rzero = nm->mkConstRealOrInt(args[0].getType(), Rational(0)); Node rone = nm->mkConstReal(Rational(1)); return nm->mkNode(EQUAL, args[0].eqNode(rzero), e.eqNode(rone)); } diff --git a/src/theory/arith/operator_elim.cpp b/src/theory/arith/operator_elim.cpp index 78a6a3899..7981c4fba 100644 --- a/src/theory/arith/operator_elim.cpp +++ b/src/theory/arith/operator_elim.cpp @@ -227,7 +227,7 @@ Node OperatorElim::eliminateOperators(Node node, rw, "nonlinearDiv", "the result of a non-linear div term"); Node lem = nm->mkNode(IMPLIES, den.eqNode(mkZero(den.getType())).negate(), - nm->mkNode(MULT, den, v).eqNode(num)); + mkEquality(nm->mkNode(MULT, den, v), num)); lems.push_back(mkSkolemLemma(lem, v)); return v; break; @@ -440,7 +440,15 @@ Node OperatorElim::getArithSkolemApp(Node n, SkolemFunId id) Node skolem = getArithSkolem(id); if (usePartialFunction(id)) { - skolem = NodeManager::currentNM()->mkNode(APPLY_UF, skolem, n); + NodeManager* nm = NodeManager::currentNM(); + Assert(skolem.getType().isFunction() + && skolem.getType().getNumChildren() == 2); + TypeNode argType = skolem.getType()[0]; + if (!argType.isInteger() && n.getType().isInteger()) + { + n = nm->mkNode(TO_REAL, n); + } + skolem = nm->mkNode(APPLY_UF, skolem, n); } return skolem; } diff --git a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp index 36e4f5a7c..47c6d3548 100644 --- a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp +++ b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp @@ -46,6 +46,7 @@ void ArithInstantiator::reset(CegInstantiator* ci, Node pv, CegInstEffort effort) { + Assert(pv.getType() == d_type); d_vts_sym[0] = d_vtc->getVtsInfinity(d_type, false, false); d_vts_sym[1] = d_vtc->getVtsDelta(false, false); for (unsigned i = 0; i < 2; i++) @@ -905,13 +906,12 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci, int ires_use = (msum[pv].isNull() || msum[pv].getConst().sgn() == 1) ? 1 : -1; - val = nm->mkNode(ires_use == -1 ? ADD : SUB, + val = nm->mkNode(TO_INTEGER, nm->mkNode(ires_use == -1 ? ADD : SUB, nm->mkNode(ires_use == -1 ? SUB : ADD, val, realPart), - nm->mkNode(TO_INTEGER, realPart)); + nm->mkNode(TO_INTEGER, realPart))); Trace("cegqi-arith-debug") << "result (pre-rewrite) : " << val << std::endl; val = rewrite(val); - val = val.getKind() == TO_REAL ? val[0] : val; // could round up for upper bounds here Trace("cegqi-arith-debug") << "result : " << val << std::endl; Assert(val.getType().isInteger()); @@ -923,6 +923,11 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci, } vts_coeff_inf = vts_coeff[0]; vts_coeff_delta = vts_coeff[1]; + if (!pv.getType().isInteger() && val.getType().isInteger()) + { + val = nm->mkNode(TO_REAL, val); + } + Assert(pv.getType() == val.getType()); Trace("cegqi-arith-debug") << "Return " << veq_c << " * " << pv << " " << atom.getKind() << " " << val << ", vts = (" << vts_coeff_inf << ", " << vts_coeff_delta << ")" diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp index f767ca352..6c48f8a14 100644 --- a/src/theory/quantifiers/sygus_sampler.cpp +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -614,6 +614,7 @@ Node SygusSampler::getRandomValue(TypeNode tn) } ret = d_env.getRewriter()->rewrite(ret); Assert(ret.isConst()); + Assert(ret.getType()==tn); return ret; } } @@ -627,12 +628,9 @@ Node SygusSampler::getRandomValue(TypeNode tn) Rational rr = r.getConst(); if (rr.sgn() == 0) { - return s; - } - else - { - return nm->mkConstReal(sr / rr); + return nm->mkConstReal(s.getConst()); } + return nm->mkConstReal(sr / rr); } } // default: use type enumerator diff --git a/src/theory/substitutions.cpp b/src/theory/substitutions.cpp index 743815957..93b187cd2 100644 --- a/src/theory/substitutions.cpp +++ b/src/theory/substitutions.cpp @@ -178,6 +178,8 @@ Node SubstitutionMap::internalSubstitute(TNode t, void SubstitutionMap::addSubstitution(TNode x, TNode t, bool invalidateCache) { + // don't check type equal here, since this utility may be used in conversions + // that change the types of terms Trace("substitution") << "SubstitutionMap::addSubstitution(" << x << ", " << t << ")" << endl; Assert(d_substitutions.find(x) == d_substitutions.end()); diff --git a/test/api/cpp/reset_assertions.cpp b/test/api/cpp/reset_assertions.cpp index 735249d5d..5c86268a3 100644 --- a/test/api/cpp/reset_assertions.cpp +++ b/test/api/cpp/reset_assertions.cpp @@ -32,7 +32,7 @@ int main() Sort real = slv.getRealSort(); Term x = slv.mkConst(real, "x"); - Term four = slv.mkInteger(4); + Term four = slv.mkReal(4); Term xEqFour = slv.mkTerm(Kind::EQUAL, {x, four}); slv.assertFormula(xEqFour); std::cout << slv.checkSat() << std::endl; @@ -43,7 +43,8 @@ int main() Sort indexType = slv.getIntegerSort(); Sort arrayType = slv.mkArraySort(indexType, elementType); Term array = slv.mkConst(arrayType, "array"); - Term arrayAtFour = slv.mkTerm(Kind::SELECT, {array, four}); + Term fourInt = slv.mkInteger(4); + Term arrayAtFour = slv.mkTerm(Kind::SELECT, {array, fourInt}); Term ten = slv.mkInteger(10); Term arrayAtFour_eq_ten = slv.mkTerm(Kind::EQUAL, {arrayAtFour, ten}); slv.assertFormula(arrayAtFour_eq_ten); diff --git a/test/api/python/reset_assertions.py b/test/api/python/reset_assertions.py index 7946f49b7..dca6efb3a 100644 --- a/test/api/python/reset_assertions.py +++ b/test/api/python/reset_assertions.py @@ -26,7 +26,7 @@ slv.setOption("incremental", "true") real = slv.getRealSort() x = slv.mkConst(real, "x") -four = slv.mkInteger(4) +four = slv.mkReal(4) xEqFour = slv.mkTerm(Kind.EQUAL, x, four) slv.assertFormula(xEqFour) print(slv.checkSat()) @@ -37,7 +37,8 @@ elementType = slv.getIntegerSort() indexType = slv.getIntegerSort() arrayType = slv.mkArraySort(indexType, elementType) array = slv.mkConst(arrayType, "array") -arrayAtFour = slv.mkTerm(Kind.SELECT, array, four) +fourInt = slv.mkInteger(4) +arrayAtFour = slv.mkTerm(Kind.SELECT, array, fourInt) ten = slv.mkInteger(10) arrayAtFour_eq_ten = slv.mkTerm(Kind.EQUAL, arrayAtFour, ten) slv.assertFormula(arrayAtFour_eq_ten) diff --git a/test/regress/cli/CMakeLists.txt b/test/regress/cli/CMakeLists.txt index 4a9489cf6..7901da1c0 100644 --- a/test/regress/cli/CMakeLists.txt +++ b/test/regress/cli/CMakeLists.txt @@ -523,6 +523,7 @@ set(regress_0_tests regress0/datatypes/datatype2.cvc.smt2 regress0/datatypes/datatype3.cvc.smt2 regress0/datatypes/datatype4.cvc.smt2 + regress0/datatypes/dd.pair-real-bool-const-conf.smt2 regress0/datatypes/dt-2.6.smt2 regress0/datatypes/dt-different-params.smt2 regress0/datatypes/dt-match-pat-param-2.6.smt2 diff --git a/test/regress/cli/regress0/datatypes/dd.pair-real-bool-const-conf.smt2 b/test/regress/cli/regress0/datatypes/dd.pair-real-bool-const-conf.smt2 new file mode 100644 index 000000000..b181cbffb --- /dev/null +++ b/test/regress/cli/regress0/datatypes/dd.pair-real-bool-const-conf.smt2 @@ -0,0 +1,6 @@ +(set-logic ALL) +(set-info :status sat) +(declare-datatypes ((P 0)) (((k (f Real))))) +(declare-const r P) +(assert (= 0.0 (f r))) +(check-sat) diff --git a/test/regress/cli/regress1/abduction/abd-real-const.smt2 b/test/regress/cli/regress1/abduction/abd-real-const.smt2 index 258d80a88..32549a831 100644 --- a/test/regress/cli/regress1/abduction/abd-real-const.smt2 +++ b/test/regress/cli/regress1/abduction/abd-real-const.smt2 @@ -5,5 +5,5 @@ (declare-const x Real) (declare-const y Real) (declare-const z Real) -(assert (and (>= x 0) (< y 7))) -(get-abduct A (>= y 5)) +(assert (and (>= x 0.0) (< y 7.0))) +(get-abduct A (>= y 5.0)) diff --git a/test/regress/cli/regress1/ho/issue4758.smt2 b/test/regress/cli/regress1/ho/issue4758.smt2 index dab284c11..c0c5cdd01 100644 --- a/test/regress/cli/regress1/ho/issue4758.smt2 +++ b/test/regress/cli/regress1/ho/issue4758.smt2 @@ -2,5 +2,5 @@ (set-info :status sat) (declare-fun a () Real) (declare-fun b (Real Real) Real) -(assert (> (b a 0) (b (- a) 1))) +(assert (> (b a 0.0) (b (- a) 1.0))) (check-sat) -- 2.30.2