From 32347c2043d60dc83cd2a5675d3f7796a42022a2 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Mon, 20 Dec 2021 13:10:30 -0600 Subject: [PATCH] Eliminating some uses of const rational in arithmetic (#7846) Note that there are several nested dependencies in arithmetic for constructing constants Constant::mkConstant ---> mkRationalNode ---> mkConst(CONST_RATIONAL, r) This starts to disambiguate these calls. --- .../passes/pseudo_boolean_processor.cpp | 3 +- src/theory/arith/arith_rewriter.cpp | 133 +++++++++--------- src/theory/arith/branch_and_bound.cpp | 12 +- src/theory/arith/dio_solver.cpp | 6 +- .../arith/nl/transcendental/sine_solver.cpp | 5 +- src/theory/arith/operator_elim.cpp | 4 +- 6 files changed, 84 insertions(+), 79 deletions(-) diff --git a/src/preprocessing/passes/pseudo_boolean_processor.cpp b/src/preprocessing/passes/pseudo_boolean_processor.cpp index 0e7ac9c79..eae1d00fd 100644 --- a/src/preprocessing/passes/pseudo_boolean_processor.cpp +++ b/src/preprocessing/passes/pseudo_boolean_processor.cpp @@ -301,7 +301,8 @@ void PseudoBooleanProcessor::learn(Node assertion) Node PseudoBooleanProcessor::mkGeqOne(Node v) { NodeManager* nm = NodeManager::currentNM(); - return nm->mkNode(kind::GEQ, v, mkRationalNode(Rational(1))); + return nm->mkNode( + kind::GEQ, v, nm->mkConstRealOrInt(v.getType(), Rational(1))); } void PseudoBooleanProcessor::learn(const std::vector& assertions) diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index af6f23c1f..0268a9eb1 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -49,7 +49,7 @@ bool ArithRewriter::isAtom(TNode n) { RewriteResponse ArithRewriter::rewriteConstant(TNode t){ Assert(t.isConst()); - Assert(t.getKind() == kind::CONST_RATIONAL); + Assert(t.getKind() == CONST_RATIONAL || t.getKind() == CONST_INTEGER); return RewriteResponse(REWRITE_DONE, t); } @@ -66,7 +66,8 @@ RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){ if(pre){ if(t[0] == t[1]){ Rational zero(0); - Node zeroNode = mkRationalNode(zero); + Node zeroNode = + NodeManager::currentNM()->mkConstRealOrInt(t.getType(), zero); return RewriteResponse(REWRITE_DONE, zeroNode); }else{ Node noMinus = makeSubtractionNode(t[0],t[1]); @@ -83,9 +84,12 @@ RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){ RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){ Assert(t.getKind() == kind::UMINUS); - if(t[0].getKind() == kind::CONST_RATIONAL){ + if (t[0].isConst()) + { Rational neg = -(t[0].getConst()); - return RewriteResponse(REWRITE_DONE, mkRationalNode(neg)); + NodeManager* nm = NodeManager::currentNM(); + return RewriteResponse(REWRITE_DONE, + nm->mkConstRealOrInt(t[0].getType(), neg)); } Node noUminus = makeUnaryMinusNode(t[0]); @@ -142,7 +146,7 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ } else { return RewriteResponse( REWRITE_DONE, - NodeManager::currentNM()->mkConst(CONST_RATIONAL, -rat)); + NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat)); } } return RewriteResponse(REWRITE_DONE, t); @@ -208,7 +212,7 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ } else { return RewriteResponse( REWRITE_DONE, - NodeManager::currentNM()->mkConst(CONST_RATIONAL, -rat)); + NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat)); } } return RewriteResponse(REWRITE_DONE, t); @@ -217,7 +221,8 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ case kind::TO_INTEGER: return rewriteExtIntegerOp(t); case kind::POW: { - if(t[1].getKind() == kind::CONST_RATIONAL){ + if (t[1].isConst()) + { const Rational& exp = t[1].getConst(); TNode base = t[0]; if(exp.sgn() == 0){ @@ -241,8 +246,9 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ } } } - else if (t[0].getKind() == kind::CONST_RATIONAL - && t[0].getConst().getNumerator().toUnsignedInt() == 2) + else if (t[0].isConst() + && t[0].getConst().getNumerator().toUnsignedInt() + == 2) { return RewriteResponse( REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1])); @@ -270,19 +276,20 @@ RewriteResponse ArithRewriter::preRewriteMult(TNode t){ Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT); if(t.getNumChildren() == 2){ - if(t[0].getKind() == kind::CONST_RATIONAL - && t[0].getConst().isOne()){ + if (t[0].isConst() && t[0].getConst().isOne()) + { return RewriteResponse(REWRITE_DONE, t[1]); } - if(t[1].getKind() == kind::CONST_RATIONAL - && t[1].getConst().isOne()){ + if (t[1].isConst() && t[1].getConst().isOne()) + { return RewriteResponse(REWRITE_DONE, t[0]); } } // Rewrite multiplications with a 0 argument and to 0 for(TNode::iterator i = t.begin(); i != t.end(); ++i) { - if((*i).getKind() == kind::CONST_RATIONAL) { + if ((*i).isConst()) + { if((*i).getConst().isZero()) { TNode zero = (*i); return RewriteResponse(REWRITE_DONE, zero); @@ -387,13 +394,10 @@ RewriteResponse ArithRewriter::postRewritePow2(TNode t) Integer i = t[0].getConst().getNumerator(); if (i < 0) { - return RewriteResponse( - REWRITE_DONE, - nm->mkConst(CONST_RATIONAL, Rational(Integer(0), Integer(1)))); + return RewriteResponse(REWRITE_DONE, nm->mkConstInt(Rational(0))); } unsigned long k = i.getUnsignedLong(); - Node ret = - nm->mkConst(CONST_RATIONAL, Rational(Integer(2).pow(k), Integer(1))); + Node ret = nm->mkConstInt(Rational(Integer(2).pow(k))); return RewriteResponse(REWRITE_DONE, ret); } return RewriteResponse(REWRITE_DONE, t); @@ -455,8 +459,9 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { NodeManager* nm = NodeManager::currentNM(); switch( t.getKind() ){ case kind::EXPONENTIAL: { - if(t[0].getKind() == kind::CONST_RATIONAL){ - Node one = nm->mkConst(CONST_RATIONAL, Rational(1)); + if (t[0].isConst()) + { + Node one = nm->mkConstReal(Rational(1)); if(t[0].getConst().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){ return RewriteResponse( REWRITE_AGAIN, @@ -480,17 +485,16 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { } break; case kind::SINE: - if(t[0].getKind() == kind::CONST_RATIONAL){ + if (t[0].isConst()) + { const Rational& rat = t[0].getConst(); if(rat.sgn() == 0){ - return RewriteResponse(REWRITE_DONE, - nm->mkConst(CONST_RATIONAL, Rational(0))); + return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0))); } else if (rat.sgn() == -1) { - Node ret = nm->mkNode( - kind::UMINUS, - nm->mkNode(kind::SINE, nm->mkConst(CONST_RATIONAL, -rat))); + Node ret = nm->mkNode(kind::UMINUS, + nm->mkNode(kind::SINE, nm->mkConstReal(-rat))); return RewriteResponse(REWRITE_AGAIN_FULL, ret); } }else{ @@ -507,7 +511,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { { if (itm->second.isNull()) { - pi_factor = mkRationalNode(Rational(1)); + pi_factor = nm->mkConstReal(Rational(1)); } else { @@ -564,7 +568,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { // sin( PI + x ) = -sin( x ) if (rem.isNull()) { - return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0))); + return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0))); } else { @@ -584,7 +588,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { { Assert(r_abs.getNumerator() == one); return RewriteResponse(REWRITE_DONE, - mkRationalNode(Rational(r.sgn()))); + nm->mkConstReal(Rational(r.sgn()))); } else if (r_abs.getDenominator() == six) { @@ -593,7 +597,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { { return RewriteResponse( REWRITE_DONE, - mkRationalNode(Rational(r.sgn()) / Rational(2))); + nm->mkConstReal(Rational(r.sgn()) / Rational(2))); } } } @@ -603,13 +607,13 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { case kind::COSINE: { return RewriteResponse( REWRITE_AGAIN_FULL, - nm->mkNode(kind::SINE, - nm->mkNode(kind::MINUS, - nm->mkNode(kind::MULT, - nm->mkConst(CONST_RATIONAL, - Rational(1) / Rational(2)), - mkPi()), - t[0]))); + nm->mkNode( + kind::SINE, + nm->mkNode(kind::MINUS, + nm->mkNode(kind::MULT, + nm->mkConstReal(Rational(1) / Rational(2)), + mkPi()), + t[0]))); } break; case kind::TANGENT: @@ -624,7 +628,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { { return RewriteResponse(REWRITE_AGAIN_FULL, nm->mkNode(kind::DIVISION, - mkRationalNode(Rational(1)), + nm->mkConstReal(Rational(1)), nm->mkNode(kind::SINE, t[0]))); } break; @@ -632,7 +636,7 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { { return RewriteResponse(REWRITE_AGAIN_FULL, nm->mkNode(kind::DIVISION, - mkRationalNode(Rational(1)), + nm->mkConstReal(Rational(1)), nm->mkNode(kind::COSINE, t[0]))); } break; @@ -660,17 +664,15 @@ RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){ if(atom.getOperator().getConst().k.isOne()) { return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); } + NodeManager* nm = NodeManager::currentNM(); return RewriteResponse( REWRITE_AGAIN, - NodeManager::currentNM()->mkNode( - kind::EQUAL, - NodeManager::currentNM()->mkNode( - kind::INTS_MODULUS_TOTAL, - atom[0], - NodeManager::currentNM()->mkConst( - CONST_RATIONAL, - Rational(atom.getOperator().getConst().k))), - NodeManager::currentNM()->mkConst(CONST_RATIONAL, Rational(0)))); + nm->mkNode(kind::EQUAL, + nm->mkNode(kind::INTS_MODULUS_TOTAL, + atom[0], + nm->mkConstInt(Rational( + atom.getOperator().getConst().k))), + nm->mkConstInt(Rational(0)))); } // left |><| right @@ -747,8 +749,9 @@ RewriteResponse ArithRewriter::preRewrite(TNode t){ } Node ArithRewriter::makeUnaryMinusNode(TNode n){ + NodeManager* nm = NodeManager::currentNM(); Rational qNegOne(-1); - return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n); + return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n); } Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){ @@ -763,12 +766,14 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ Node left = t[0]; Node right = t[1]; - if(right.getKind() == kind::CONST_RATIONAL){ + if (right.isConst()) + { + NodeManager* nm = NodeManager::currentNM(); const Rational& den = right.getConst(); if(den.isZero()){ if(t.getKind() == kind::DIVISION_TOTAL){ - return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); + return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0)); }else{ // This is unsupported, but this is not a good place to complain return RewriteResponse(REWRITE_DONE, t); @@ -776,16 +781,17 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ } Assert(den != Rational(0)); - if(left.getKind() == kind::CONST_RATIONAL){ + if (left.isConst()) + { const Rational& num = left.getConst(); Rational div = num / den; - Node result = mkRationalNode(div); + Node result = nm->mkConstReal(div); return RewriteResponse(REWRITE_DONE, result); } Rational div = den.inverse(); - Node result = mkRationalNode(div); + Node result = nm->mkConstReal(div); Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result); if(pre){ @@ -793,16 +799,14 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ }else{ return RewriteResponse(REWRITE_AGAIN, mult); } - }else{ - return RewriteResponse(REWRITE_DONE, t); } + return RewriteResponse(REWRITE_DONE, t); } RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre) { NodeManager* nm = NodeManager::currentNM(); Kind k = t.getKind(); - Node zero = nm->mkConst(CONST_RATIONAL, Rational(0)); if (k == kind::INTS_MODULUS) { if (t[1].isConst() && !t[1].getConst().isZero()) @@ -867,10 +871,10 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre) Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL); TNode n = t[0]; TNode d = t[1]; - bool dIsConstant = d.getKind() == kind::CONST_RATIONAL; + bool dIsConstant = d.isConst(); if(dIsConstant && d.getConst().isZero()){ // (div x 0) ---> 0 or (mod x 0) ---> 0 - return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO); + return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO); }else if(dIsConstant && d.getConst().isOne()){ if (k == kind::INTS_MODULUS_TOTAL) { @@ -886,14 +890,13 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre) // pull negation // (div x (- c)) ---> (- (div x c)) // (mod x (- c)) ---> (mod x c) - Node nn = nm->mkNode( - k, t[0], nm->mkConst(CONST_RATIONAL, -t[1].getConst())); + Node nn = nm->mkNode(k, t[0], nm->mkConstInt(-t[1].getConst())); Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL) ? nm->mkNode(kind::UMINUS, nn) : nn; return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN); } - else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL) + else if (dIsConstant && n.isConst()) { Assert(d.getConst().isIntegral()); Assert(n.getConst().isIntegral()); @@ -907,7 +910,7 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre) // constant evaluation // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3 - Node resultNode = mkRationalNode(Rational(result)); + Node resultNode = nm->mkConstInt(Rational(result)); return returnRewrite(t, resultNode, Rewrite::CONST_EVAL); } if (k == kind::INTS_MODULUS_TOTAL) @@ -953,7 +956,7 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre) if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1]) { // (div (mod x c) c) --> 0 - Node ret = mkRationalNode(0); + Node ret = nm->mkConstInt(0); return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD); } } diff --git a/src/theory/arith/branch_and_bound.cpp b/src/theory/arith/branch_and_bound.cpp index eb02339bb..6d9a71722 100644 --- a/src/theory/arith/branch_and_bound.cpp +++ b/src/theory/arith/branch_and_bound.cpp @@ -58,10 +58,10 @@ TrustNode BranchAndBound::branchIntegerVariable(TNode var, Rational value) // Prioritize trying a simple rounding of the real solution first, // it that fails, fall back on original branch and bound strategy. - Node ub = rewrite(nm->mkNode(LEQ, var, mkRationalNode(nearest - 1))); - Node lb = rewrite(nm->mkNode(GEQ, var, mkRationalNode(nearest + 1))); + Node ub = rewrite(nm->mkNode(LEQ, var, nm->mkConstInt(nearest - 1))); + Node lb = rewrite(nm->mkNode(GEQ, var, nm->mkConstInt(nearest + 1))); Node right = nm->mkNode(OR, ub, lb); - Node rawEq = nm->mkNode(EQUAL, var, mkRationalNode(nearest)); + Node rawEq = nm->mkNode(EQUAL, var, nm->mkConstInt(nearest)); Node eq = rewrite(rawEq); // Also preprocess it before we send it out. This is important since // arithmetic may prefer eliminating equalities. @@ -78,8 +78,8 @@ TrustNode BranchAndBound::branchIntegerVariable(TNode var, Rational value) Trace("integers") << "l: " << l << std::endl; if (proofsEnabled()) { - Node less = nm->mkNode(LT, var, mkRationalNode(nearest)); - Node greater = nm->mkNode(GT, var, mkRationalNode(nearest)); + Node less = nm->mkNode(LT, var, nm->mkConstInt(nearest)); + Node greater = nm->mkNode(GT, var, nm->mkConstInt(nearest)); // TODO (project #37): justify. Thread proofs through *ensureLiteral*. Debug("integers::pf") << "less: " << less << std::endl; Debug("integers::pf") << "greater: " << greater << std::endl; @@ -119,7 +119,7 @@ TrustNode BranchAndBound::branchIntegerVariable(TNode var, Rational value) } else { - Node ub = rewrite(nm->mkNode(LEQ, var, mkRationalNode(floor))); + Node ub = rewrite(nm->mkNode(LEQ, var, nm->mkConstInt(floor))); Node lb = ub.notNode(); if (proofsEnabled()) { diff --git a/src/theory/arith/dio_solver.cpp b/src/theory/arith/dio_solver.cpp index 99dcc93ca..af3d8a692 100644 --- a/src/theory/arith/dio_solver.cpp +++ b/src/theory/arith/dio_solver.cpp @@ -820,8 +820,10 @@ void DioSolver::addTrailElementAsLemma(TrailIndex i) { Node DioSolver::trailIndexToEquality(TrailIndex i) const { const SumPair& sp = d_trail[i].d_eq; - Node zero = mkRationalNode(0); - Node eq = (sp.getNode()).eqNode(zero); + Node n = sp.getNode(); + Node zero = + NodeManager::currentNM()->mkConstRealOrInt(n.getType(), Rational(0)); + Node eq = n.eqNode(zero); return eq; } diff --git a/src/theory/arith/nl/transcendental/sine_solver.cpp b/src/theory/arith/nl/transcendental/sine_solver.cpp index 6c1bec647..d574a9572 100644 --- a/src/theory/arith/nl/transcendental/sine_solver.cpp +++ b/src/theory/arith/nl/transcendental/sine_solver.cpp @@ -45,10 +45,9 @@ namespace { */ inline Node mkValidPhase(TNode a, TNode pi) { + NodeManager* nm = NodeManager::currentNM(); return mkBounded( - NodeManager::currentNM()->mkNode(Kind::MULT, mkRationalNode(-1), pi), - a, - pi); + nm->mkNode(Kind::MULT, nm->mkConstReal(Rational(-1)), pi), a, pi); } } // namespace diff --git a/src/theory/arith/operator_elim.cpp b/src/theory/arith/operator_elim.cpp index 99f5621d6..05a83c81c 100644 --- a/src/theory/arith/operator_elim.cpp +++ b/src/theory/arith/operator_elim.cpp @@ -122,8 +122,8 @@ Node OperatorElim::eliminateOperators(Node node, // 0 <= node[0] - toIntSkolem < 1 Node v = bvm->mkBoundVar(node[0], nm->integerType()); - Node one = mkRationalNode(1); - Node zero = mkRationalNode(0); + Node one = nm->mkConstReal(Rational(1)); + Node zero = nm->mkConstReal(Rational(0)); Node diff = nm->mkNode(MINUS, node[0], v); Node lem = mkInRange(diff, zero, one); Node toIntSkolem = -- 2.30.2