From: Andrew Reynolds Date: Fri, 17 Dec 2021 03:47:43 +0000 (-0600) Subject: Eliminate more uses of CONST_RATIONAL (#7816) X-Git-Tag: cvc5-1.0.0~648 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=3973cfaa8763068a635f9091367b7642f322cbd9;p=cvc5.git Eliminate more uses of CONST_RATIONAL (#7816) --- diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 4e1a8aae8..1fca42634 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -1013,21 +1013,22 @@ api::Term Smt2::applyParseOp(ParseOp& p, std::vector& args) // integer constants. We must ensure numerator and denominator are // constant and the denominator is non-zero. A similar issue happens for // negative integers and reals, with unary minus. + // NOTE this should be applied more eagerly when UMINUS/DIVISION is + // constructed. bool isNeg = false; if (constVal.getKind() == api::UMINUS) { isNeg = true; constVal = constVal[0]; } - if (constVal.getKind() == api::DIVISION - && constVal[0].getKind() == api::CONST_RATIONAL - && constVal[1].getKind() == api::CONST_RATIONAL) + if (constVal.getKind() == api::DIVISION && isConstInt(constVal[0]) + && isConstInt(constVal[1])) { std::stringstream sdiv; sdiv << (isNeg ? "-" : "") << constVal[0] << "/" << constVal[1]; constVal = d_solver->mkReal(sdiv.str()); } - else if (constVal.getKind() == api::CONST_RATIONAL && isNeg) + else if (isConstInt(constVal) && isNeg) { std::stringstream sneg; sneg << "-" << constVal; @@ -1229,7 +1230,7 @@ void Smt2::notifyNamedExpression(api::Term& expr, std::string name) setLastNamedTerm(expr, name); } -api::Term Smt2::mkAnd(const std::vector& es) +api::Term Smt2::mkAnd(const std::vector& es) const { if (es.size() == 0) { @@ -1239,10 +1240,15 @@ api::Term Smt2::mkAnd(const std::vector& es) { return es[0]; } - else - { - return d_solver->mkTerm(api::AND, es); - } + return d_solver->mkTerm(api::AND, es); +} + +bool Smt2::isConstInt(const api::Term& t) +{ + api::Kind k = t.getKind(); + // !!! Note when arithmetic subtyping is eliminated, this will update to + // CONST_INTEGER. + return k == api::CONST_RATIONAL; } } // namespace parser diff --git a/src/parser/smt2/smt2.h b/src/parser/smt2/smt2.h index 58a20cb27..6df62d787 100644 --- a/src/parser/smt2/smt2.h +++ b/src/parser/smt2/smt2.h @@ -422,7 +422,11 @@ class Smt2 : public Parser * @return True if `es` is empty, `e` if `es` consists of a single element * `e`, the conjunction of expressions otherwise. */ - api::Term mkAnd(const std::vector& es); + api::Term mkAnd(const std::vector& es) const; + /** + * Is term t a constant integer? + */ + static bool isConstInt(const api::Term& t); }; /* class Smt2 */ } // namespace parser diff --git a/src/preprocessing/passes/real_to_int.cpp b/src/preprocessing/passes/real_to_int.cpp index d2cde7b46..ef9077770 100644 --- a/src/preprocessing/passes/real_to_int.cpp +++ b/src/preprocessing/passes/real_to_int.cpp @@ -98,9 +98,9 @@ Node RealToInt::realToIntInternal(TNode n, NodeMap& cache, std::vector& va Node s; if (c.isNull()) { - c = cc.isNull() ? NodeManager::currentNM()->mkConst( - CONST_RATIONAL, Rational(1)) - : cc; + c = cc.isNull() + ? NodeManager::currentNM()->mkConstInt(Rational(1)) + : cc; } else { diff --git a/src/preprocessing/passes/unconstrained_simplifier.cpp b/src/preprocessing/passes/unconstrained_simplifier.cpp index 027be232b..7a58fc231 100644 --- a/src/preprocessing/passes/unconstrained_simplifier.cpp +++ b/src/preprocessing/passes/unconstrained_simplifier.cpp @@ -530,8 +530,7 @@ void UnconstrainedSimplifier::processUnconstrained() else { // TODO(#2377): could build ITE here - Node test = - other.eqNode(nm->mkConst(CONST_RATIONAL, Rational(0))); + Node test = other.eqNode(nm->mkConstReal(Rational(0))); if (rewrite(test) != nm->mkConst(false)) { break; diff --git a/src/theory/arith/arith_poly_norm.h b/src/theory/arith/arith_poly_norm.h index 9c3cbcf95..fafa94ee3 100644 --- a/src/theory/arith/arith_poly_norm.h +++ b/src/theory/arith/arith_poly_norm.h @@ -40,7 +40,7 @@ class PolyNorm */ void addMonomial(TNode x, const Rational& c, bool isNeg = false); /** - * Multiply this polynomial by the monomial x*c, where c is a CONST_RATIONAL. + * Multiply this polynomial by the monomial x*c, where c is a constant. * If x is null, then x*c is treated as c. */ void multiplyMonomial(TNode x, const Rational& c); diff --git a/src/theory/arith/arith_utilities.cpp b/src/theory/arith/arith_utilities.cpp index a9fd97079..6f43cfc1b 100644 --- a/src/theory/arith/arith_utilities.cpp +++ b/src/theory/arith/arith_utilities.cpp @@ -319,6 +319,22 @@ Node negateProofLiteral(TNode n) } } +Node multConstants(const Node& c1, const Node& c2) +{ + Assert(!c1.isNull() && c1.isConst()); + Assert(!c2.isNull() && c2.isConst()); + NodeManager* nm = NodeManager::currentNM(); + // real type if either has type real + TypeNode tn = c1.getType(); + if (tn.isInteger()) + { + tn = c2.getType(); + } + Assert(tn.isRealOrInt()); + return nm->mkConstRealOrInt( + tn, Rational(c1.getConst() * c2.getConst())); +} + } // namespace arith } // namespace theory } // namespace cvc5 diff --git a/src/theory/arith/arith_utilities.h b/src/theory/arith/arith_utilities.h index b926af2e0..027f7a65a 100644 --- a/src/theory/arith/arith_utilities.h +++ b/src/theory/arith/arith_utilities.h @@ -332,6 +332,12 @@ Rational greatestIntLessThan(const Rational&); /** Negates a node in arithmetic proof normal form. */ Node negateProofLiteral(TNode n); +/** + * Return the result of multiplying constant integer or real nodes c1 and c2. + * The returned type is real if either have type real. + */ +Node multConstants(const Node& c1, const Node& c2); + } // namespace arith } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/bag_reduction.cpp b/src/theory/bags/bag_reduction.cpp index 3e6544882..a6895e4b1 100644 --- a/src/theory/bags/bag_reduction.cpp +++ b/src/theory/bags/bag_reduction.cpp @@ -134,8 +134,8 @@ Node BagReduction::reduceCardOperator(Node node, std::vector& asserts) NodeManager* nm = NodeManager::currentNM(); SkolemManager* sm = nm->getSkolemManager(); Node A = node[0]; - Node zero = nm->mkConst(CONST_RATIONAL, Rational(0)); - Node one = nm->mkConst(CONST_RATIONAL, Rational(1)); + Node zero = nm->mkConstInt(Rational(0)); + Node one = nm->mkConstInt(Rational(1)); // types TypeNode bagType = A.getType(); TypeNode elementType = A.getType().getBagElementType(); diff --git a/src/theory/fp/theory_fp.cpp b/src/theory/fp/theory_fp.cpp index 972fac5a3..bcbd9e297 100644 --- a/src/theory/fp/theory_fp.cpp +++ b/src/theory/fp/theory_fp.cpp @@ -364,7 +364,7 @@ bool TheoryFp::refineAbstraction(TheoryModel *m, TNode abstract, TNode concrete) Node realValueOfAbstract = rewrite(nm->mkNode(kind::FLOATINGPOINT_TO_REAL_TOTAL, abstractValue, - nm->mkConst(CONST_RATIONAL, Rational(0U)))); + nm->mkConstReal(Rational(0U)))); Node bg = nm->mkNode( kind::IMPLIES, @@ -570,8 +570,7 @@ void TheoryFp::registerTerm(TNode node) Node z = nm->mkNode( kind::IMPLIES, nm->mkNode(kind::FLOATINGPOINT_ISZ, node[0]), - nm->mkNode( - kind::EQUAL, node, nm->mkConst(CONST_RATIONAL, Rational(0U)))); + nm->mkNode(kind::EQUAL, node, nm->mkConstReal(Rational(0U)))); handleLemma(z, InferenceId::FP_REGISTER_TERM); return; @@ -592,8 +591,7 @@ void TheoryFp::registerTerm(TNode node) Node z = nm->mkNode( kind::IMPLIES, - nm->mkNode( - kind::EQUAL, node[1], nm->mkConst(CONST_RATIONAL, Rational(0U))), + nm->mkNode(kind::EQUAL, node[1], nm->mkConstReal(Rational(0U))), nm->mkNode(kind::EQUAL, node, nm->mkConst(FloatingPoint::makeZero( diff --git a/src/theory/fp/theory_fp_rewriter.cpp b/src/theory/fp/theory_fp_rewriter.cpp index 32c3cff41..779d02ab3 100644 --- a/src/theory/fp/theory_fp_rewriter.cpp +++ b/src/theory/fp/theory_fp_rewriter.cpp @@ -915,7 +915,7 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) FloatingPoint::PartialRational res(arg.convertToRational()); if (res.second) { - Node lit = NodeManager::currentNM()->mkConst(CONST_RATIONAL, res.first); + Node lit = NodeManager::currentNM()->mkConstReal(res.first); return RewriteResponse(REWRITE_DONE, lit); } else { // Can't constant fold the underspecified case @@ -998,14 +998,14 @@ RewriteResponse maxTotal(TNode node, bool isPreRewrite) Rational partialValue(node[1].getConst()); Rational folded(arg.convertToRationalTotal(partialValue)); - Node lit = NodeManager::currentNM()->mkConst(CONST_RATIONAL, folded); + Node lit = NodeManager::currentNM()->mkConstReal(folded); return RewriteResponse(REWRITE_DONE, lit); } else { FloatingPoint::PartialRational res(arg.convertToRational()); if (res.second) { - Node lit = NodeManager::currentNM()->mkConst(CONST_RATIONAL, res.first); + Node lit = NodeManager::currentNM()->mkConstReal(res.first); return RewriteResponse(REWRITE_DONE, lit); } else { // Can't constant fold the underspecified case diff --git a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp index 56debbbac..ecf2d9a48 100644 --- a/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp +++ b/src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp @@ -36,8 +36,8 @@ namespace quantifiers { ArithInstantiator::ArithInstantiator(Env& env, TypeNode tn, VtsTermCache* vtc) : Instantiator(env, tn), d_vtc(vtc) { - d_zero = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(0)); - d_one = NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(1)); + d_zero = NodeManager::currentNM()->mkConstRealOrInt(tn, Rational(0)); + d_one = NodeManager::currentNM()->mkConstRealOrInt(tn, Rational(1)); } void ArithInstantiator::reset(CegInstantiator* ci, @@ -185,8 +185,7 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci, uval = nm->mkNode( PLUS, val, - nm->mkConst(CONST_RATIONAL, - Rational(isUpperBoundCTT(uires) ? 1 : -1))); + nm->mkConstInt(Rational(isUpperBoundCTT(uires) ? 1 : -1))); uval = rewrite(uval); } else @@ -253,11 +252,10 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci, if (d_type.isInteger()) { uires = is_upper ? CEG_TT_LOWER : CEG_TT_UPPER; - uval = - nm->mkNode(PLUS, - val, - nm->mkConst(CONST_RATIONAL, - Rational(isUpperBoundCTT(uires) ? 1 : -1))); + uval = nm->mkNode( + PLUS, + val, + nm->mkConstInt(Rational(isUpperBoundCTT(uires) ? 1 : -1))); uval = rewrite(uval); } else @@ -278,8 +276,8 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci, { if (options().quantifiers.cegqiModel) { - Node delta_coeff = nm->mkConst( - CONST_RATIONAL, Rational(isUpperBoundCTT(uires) ? 1 : -1)); + Node delta_coeff = nm->mkConstRealOrInt( + d_type, Rational(isUpperBoundCTT(uires) ? 1 : -1)); if (vts_coeff_delta.isNull()) { vts_coeff_delta = delta_coeff; @@ -455,9 +453,8 @@ bool ArithInstantiator::processAssertions(CegInstantiator* ci, Assert(d_mbp_coeff[rr][j].isConst()); value[t] = nm->mkNode( MULT, - nm->mkConst( - CONST_RATIONAL, - Rational(1) / d_mbp_coeff[rr][j].getConst()), + nm->mkConstReal(Rational(1) + / d_mbp_coeff[rr][j].getConst()), value[t]); value[t] = rewrite(value[t]); } @@ -611,10 +608,9 @@ bool ArithInstantiator::processAssertions(CegInstantiator* ci, } else { - val = - nm->mkNode(MULT, - nm->mkNode(PLUS, vals[0], vals[1]), - nm->mkConst(CONST_RATIONAL, Rational(1) / Rational(2))); + val = nm->mkNode(MULT, + nm->mkNode(PLUS, vals[0], vals[1]), + nm->mkConstReal(Rational(1) / Rational(2))); val = rewrite(val); } } @@ -809,7 +805,7 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci, vts_coeff[t] = itminf->second; if (vts_coeff[t].isNull()) { - vts_coeff[t] = nm->mkConst(CONST_RATIONAL, Rational(1)); + vts_coeff[t] = nm->mkConstRealOrInt(d_type, Rational(1)); } // negate if coefficient on variable is positive std::map::iterator itv = msum.find(pv); @@ -826,8 +822,8 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci, { vts_coeff[t] = nm->mkNode( MULT, - nm->mkConst(CONST_RATIONAL, - Rational(-1) / itv->second.getConst()), + nm->mkConstReal(Rational(-1) + / itv->second.getConst()), vts_coeff[t]); vts_coeff[t] = rewrite(vts_coeff[t]); } @@ -887,7 +883,7 @@ CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci, } } // multiply everything by this coefficient - Node rcoeff = nm->mkConst(CONST_RATIONAL, Rational(coeff)); + Node rcoeff = nm->mkConstInt(Rational(coeff)); std::vector real_part; for (std::map::iterator it = msum.begin(); it != msum.end(); ++it) diff --git a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp index 81ae18f4f..45ac899e1 100644 --- a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp +++ b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp @@ -20,6 +20,7 @@ #include "expr/node_algorithm.h" #include "options/quantifiers_options.h" #include "theory/arith/arith_msum.h" +#include "theory/arith/arith_utilities.h" #include "theory/quantifiers/cegqi/ceg_arith_instantiator.h" #include "theory/quantifiers/cegqi/ceg_bv_instantiator.h" #include "theory/quantifiers/cegqi/ceg_dt_instantiator.h" @@ -137,10 +138,7 @@ void TermProperties::composeProperty(TermProperties& p) } else { - NodeManager* nm = NodeManager::currentNM(); - d_coeff = nm->mkConst(CONST_RATIONAL, - Rational(d_coeff.getConst() - * p.d_coeff.getConst())); + d_coeff = arith::multConstants(d_coeff, p.d_coeff); } } @@ -163,12 +161,7 @@ void SolvedForm::push_back(Node pv, Node n, TermProperties& pv_prop) } else { - Assert(new_theta.isConst()); - Assert(pv_prop.d_coeff.isConst()); - NodeManager* nm = NodeManager::currentNM(); - new_theta = nm->mkConst(CONST_RATIONAL, - Rational(new_theta.getConst() - * pv_prop.d_coeff.getConst())); + new_theta = arith::multConstants(new_theta, pv_prop.d_coeff); } d_theta.push_back(new_theta); } diff --git a/src/theory/quantifiers/ematching/inst_match_generator.cpp b/src/theory/quantifiers/ematching/inst_match_generator.cpp index e3dd246a7..ab4bbc91b 100644 --- a/src/theory/quantifiers/ematching/inst_match_generator.cpp +++ b/src/theory/quantifiers/ematching/inst_match_generator.cpp @@ -364,8 +364,8 @@ int InstMatchGenerator::getMatch(Node f, Node t, InstMatch& m) { if (pat.getKind() == GT) { - t_match = - nm->mkNode(MINUS, t, nm->mkConst(CONST_RATIONAL, Rational(1))); + t_match = nm->mkNode( + MINUS, t, nm->mkConstRealOrInt(t.getType(), Rational(1))); }else{ t_match = t; } @@ -374,20 +374,21 @@ int InstMatchGenerator::getMatch(Node f, Node t, InstMatch& m) { if (pat.getKind() == EQUAL) { - if (t.getType().isBoolean()) + TypeNode tn = t.getType(); + if (tn.isBoolean()) { t_match = nm->mkConst(!d_qstate.areEqual(nm->mkConst(true), t)); } else { - Assert(t.getType().isRealOrInt()); - t_match = - nm->mkNode(PLUS, t, nm->mkConst(CONST_RATIONAL, Rational(1))); + Assert(tn.isRealOrInt()); + t_match = nm->mkNode(PLUS, t, nm->mkConstRealOrInt(tn, Rational(1))); } } else if (pat.getKind() == GEQ) { - t_match = nm->mkNode(PLUS, t, nm->mkConst(CONST_RATIONAL, Rational(1))); + t_match = + nm->mkNode(PLUS, t, nm->mkConstRealOrInt(t.getType(), Rational(1))); } else if (pat.getKind() == GT) { diff --git a/src/theory/quantifiers/ematching/relational_match_generator.cpp b/src/theory/quantifiers/ematching/relational_match_generator.cpp index 5cf9079e8..6ec3334cd 100644 --- a/src/theory/quantifiers/ematching/relational_match_generator.cpp +++ b/src/theory/quantifiers/ematching/relational_match_generator.cpp @@ -97,7 +97,7 @@ int RelationalMatchGenerator::getNextMatch(Node q, InstMatch& m) s = nm->mkNode( PLUS, s, - nm->mkConst(CONST_RATIONAL, Rational(d_rel == GEQ ? -1 : 1))); + nm->mkConstRealOrInt(s.getType(), Rational(d_rel == GEQ ? -1 : 1))); } d_counter++; Trace("relational-match-gen")