From: Andrew Reynolds Date: Mon, 2 May 2022 19:11:22 +0000 (-0500) Subject: Further refactoring in preparation for CONST_INTEGER (#8687) X-Git-Tag: cvc5-1.0.1~192 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=c02a35e96ed55cebe5a8ec18159301926153ba56;p=cvc5.git Further refactoring in preparation for CONST_INTEGER (#8687) Miscellaneous refactorings from trying to enable CONST_INTEGER. --- diff --git a/src/expr/node_manager_template.cpp b/src/expr/node_manager_template.cpp index 24e1f01f3..c27679b4f 100644 --- a/src/expr/node_manager_template.cpp +++ b/src/expr/node_manager_template.cpp @@ -1314,6 +1314,15 @@ Node NodeManager::mkConstInt(const Rational& r) return mkConst(kind::CONST_RATIONAL, r); } +Node NodeManager::mkConstRealOrInt(const Rational& r) +{ + if (r.isIntegral()) + { + return mkConstInt(r); + } + return mkConstReal(r); +} + Node NodeManager::mkConstRealOrInt(const TypeNode& tn, const Rational& r) { Assert(tn.isRealOrInt()) << "Expected real or int for mkConstRealOrInt, got " @@ -1329,7 +1338,8 @@ Node NodeManager::mkRealAlgebraicNumber(const RealAlgebraicNumber& ran) { if (ran.isRational()) { - return mkConstReal(ran.toRational()); + // may generate an integer it is it integral + return mkConstRealOrInt(ran.toRational()); } // Creating this node may refine the ran to the point where isRational returns // true @@ -1341,7 +1351,8 @@ Node NodeManager::mkRealAlgebraicNumber(const RealAlgebraicNumber& ran) const RealAlgebraicNumber& cur = inner.getConst(); if (cur.isRational()) { - return mkConstReal(cur.toRational()); + // may generate an integer it is it integral + return mkConstRealOrInt(cur.toRational()); } if (cur == ran) break; inner = mkConst(Kind::REAL_ALGEBRAIC_NUMBER_OP, cur); diff --git a/src/expr/node_manager_template.h b/src/expr/node_manager_template.h index 678728c78..e10f448b4 100644 --- a/src/expr/node_manager_template.h +++ b/src/expr/node_manager_template.h @@ -686,6 +686,12 @@ class NodeManager */ Node mkConstInt(const Rational& r); + /** + * Make constant real or int, which calls one of the above methods based + * on whether r is integral. + */ + Node mkConstRealOrInt(const Rational& r); + /** * Make constant real or int, which calls one of the above methods based * on the type tn. diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 128b2e84c..804e476fb 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -899,6 +899,12 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl; NodeManager* nm = NodeManager::currentNM(); + if (t[0].getKind() == TO_REAL) + { + // always strip TO_REAL from argument. + Node ret = nm->mkNode(t.getKind(), t[0][0]); + return RewriteResponse(REWRITE_AGAIN, ret); + } switch (t.getKind()) { case kind::EXPONENTIAL: @@ -1019,7 +1025,6 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { new_arg = nm->mkNode(kind::ADD, new_arg, rem); } - new_arg = ensureReal(new_arg); // sin( 2*n*PI + x ) = sin( x ) return RewriteResponse(REWRITE_AGAIN_FULL, nm->mkNode(kind::SINE, new_arg)); @@ -1049,8 +1054,8 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) if (r_abs.getDenominator() == two) { Assert(r_abs.getNumerator() == one); - return RewriteResponse( - REWRITE_DONE, ensureReal(nm->mkConstReal(Rational(r.sgn())))); + return RewriteResponse(REWRITE_DONE, + nm->mkConstReal(Rational(r.sgn()))); } else if (r_abs.getDenominator() == six) { diff --git a/src/theory/arith/linear/infer_bounds.cpp b/src/theory/arith/linear/infer_bounds.cpp index ec2843aa2..8666921f5 100644 --- a/src/theory/arith/linear/infer_bounds.cpp +++ b/src/theory/arith/linear/infer_bounds.cpp @@ -151,7 +151,7 @@ Node InferBoundsResult::getTerm() const { return d_term; } Node InferBoundsResult::getLiteral() const{ const Rational& q = getValue().getNoninfinitesimalPart(); NodeManager* nm = NodeManager::currentNM(); - Node qnode = nm->mkConst(CONST_RATIONAL, q); + Node qnode = nm->mkConstReal(q); Kind k; if(d_upperBound){ diff --git a/src/theory/arith/linear/normal_form.cpp b/src/theory/arith/linear/normal_form.cpp index 81bb23833..ecbf10b1b 100644 --- a/src/theory/arith/linear/normal_form.cpp +++ b/src/theory/arith/linear/normal_form.cpp @@ -220,13 +220,16 @@ VarList VarList::operator*(const VarList& other) const { } bool Monomial::isMember(TNode n){ - if(n.getKind() == kind::CONST_RATIONAL) { + Kind k = n.getKind(); + if (k == kind::CONST_RATIONAL || k == kind::CONST_INTEGER) + { return true; - } else if(multStructured(n)) { + } + else if (multStructured(n)) + { return VarList::isMember(n[1]); - } else { - return VarList::isMember(n); } + return VarList::isMember(n); } Monomial Monomial::mkMonomial(const Constant& c, const VarList& vl) { @@ -249,13 +252,16 @@ Monomial Monomial::mkMonomial(const VarList& vl) { } Monomial Monomial::parseMonomial(Node n) { - if(n.getKind() == kind::CONST_RATIONAL) { + Kind k = n.getKind(); + if (k == kind::CONST_RATIONAL || k == kind::CONST_INTEGER) + { return Monomial(Constant(n)); - } else if(multStructured(n)) { + } + else if (multStructured(n)) + { return Monomial::mkMonomial(Constant(n[0]),VarList::parseVarList(n[1])); - } else { - return Monomial(VarList::parseVarList(n)); } + return Monomial(VarList::parseVarList(n)); } Monomial Monomial::operator*(const Rational& q) const { if(q.isZero()){ @@ -699,7 +705,10 @@ SumPair SumPair::mkSumPair(const Polynomial& p){ } } -Comparison::Comparison(TNode n) : NodeWrapper(n) { Assert(isNormalForm()); } +Comparison::Comparison(TNode n) : NodeWrapper(n) +{ + Assert(isNormalForm()) << "Bad comparison normal form: " << n; +} SumPair Comparison::toSumPair() const { Kind cmpKind = comparisonKind(); @@ -719,8 +728,8 @@ SumPair Comparison::toSumPair() const { return SumPair(-p, c); } } - case kind::EQUAL: - case kind::DISTINCT: + case kind::EQUAL: + case kind::DISTINCT: { Polynomial left = getLeft(); Polynomial right = getRight(); @@ -758,8 +767,8 @@ Polynomial Comparison::normalizedVariablePart() const { return -p; } } - case kind::EQUAL: - case kind::DISTINCT: + case kind::EQUAL: + case kind::DISTINCT: { Polynomial left = getLeft(); Polynomial right = getRight(); @@ -798,8 +807,8 @@ DeltaRational Comparison::normalizedDeltaRational() const { return DeltaRational(-q, -delta); } } - case kind::EQUAL: - case kind::DISTINCT: + case kind::EQUAL: + case kind::DISTINCT: { Polynomial right = getRight(); Monomial firstRight = right.getHead(); @@ -914,19 +923,20 @@ Node Comparison::toNode(Kind k, const Polynomial& l, const Polynomial& r) { return toNode(kind::GEQ, r, l).notNode(); case kind::LT: return toNode(kind::GT, r, l).notNode(); - case kind::DISTINCT: - return toNode(kind::EQUAL, r, l).notNode(); + case kind::DISTINCT: return toNode(kind::EQUAL, r, l).notNode(); default: Unreachable(); } } bool Comparison::rightIsConstant() const { + Kind k; if(getNode().getKind() == kind::NOT){ - return getNode()[0][1].getKind() == kind::CONST_RATIONAL; + k = getNode()[0][1].getKind(); }else{ - return getNode()[1].getKind() == kind::CONST_RATIONAL; + k = getNode()[1].getKind(); } + return k == kind::CONST_RATIONAL || k == kind::CONST_INTEGER; } size_t Comparison::getComplexity() const{ @@ -1004,8 +1014,7 @@ bool Comparison::isNormalForm() const { return isNormalGT(); case kind::GEQ: return isNormalGEQ(); - case kind::EQUAL: - return isNormalEquality(); + case kind::EQUAL: return isNormalEquality(); case kind::LT: return isNormalLT(); case kind::LEQ: @@ -1304,7 +1313,9 @@ Node Comparison::mkIntEquality(const Polynomial& p){ Comparison Comparison::mkComparison(Kind k, const Polynomial& l, const Polynomial& r){ //Make this special case fast for sharing! - if((k == kind::EQUAL || k == kind::DISTINCT) && l.isVarList() && r.isVarList()){ + if ((k == kind::EQUAL || k == kind::DISTINCT) && l.isVarList() + && r.isVarList()) + { VarList vLeft = l.asVarList(); VarList vRight = r.asVarList(); @@ -1312,7 +1323,8 @@ Comparison Comparison::mkComparison(Kind k, const Polynomial& l, const Polynomia // return true for equalities and false for disequalities return Comparison(k == kind::EQUAL); }else{ - Node eqNode = vLeft < vRight ? toNode( kind::EQUAL, l, r) : toNode( kind::EQUAL, r, l); + Node eqNode = vLeft < vRight ? toNode(kind::EQUAL, l, r) + : toNode(kind::EQUAL, r, l); Node forK = (k == kind::DISTINCT) ? eqNode.notNode() : eqNode; return Comparison(forK); } @@ -1327,10 +1339,10 @@ Comparison Comparison::mkComparison(Kind k, const Polynomial& l, const Polynomia Node result = Node::null(); bool isInteger = diff.allIntegralVariables(); switch(k){ - case kind::EQUAL: - result = isInteger ? mkIntEquality(diff) : mkRatEquality(diff); - break; - case kind::DISTINCT: + case kind::EQUAL: + result = isInteger ? mkIntEquality(diff) : mkRatEquality(diff); + break; + case kind::DISTINCT: { Node eq = isInteger ? mkIntEquality(diff) : mkRatEquality(diff); result = eq.notNode(); @@ -1377,8 +1389,7 @@ Kind Comparison::comparisonKind(TNode literal){ case kind::CONST_BOOLEAN: case kind::GT: case kind::GEQ: - case kind::EQUAL: - return literal.getKind(); + case kind::EQUAL: return literal.getKind(); case kind::NOT: { TNode negatedAtom = literal[0]; @@ -1387,8 +1398,7 @@ Kind Comparison::comparisonKind(TNode literal){ return kind::LEQ; case kind::GEQ: //(not (GEQ x c)) <=> (LT x c) return kind::LT; - case kind::EQUAL: - return kind::DISTINCT; + case kind::EQUAL: return kind::DISTINCT; default: return kind::UNDEFINED_KIND; } diff --git a/src/theory/arith/linear/normal_form.h b/src/theory/arith/linear/normal_form.h index 9656e2876..c9f8eb72b 100644 --- a/src/theory/arith/linear/normal_form.h +++ b/src/theory/arith/linear/normal_form.h @@ -231,6 +231,7 @@ public: Kind k = n.getKind(); switch (k) { + case kind::CONST_INTEGER: case kind::CONST_RATIONAL: return false; case kind::INTS_DIVISION: case kind::INTS_MODULUS: @@ -347,13 +348,18 @@ class Constant : public NodeWrapper { public: Constant(Node n) : NodeWrapper(n) { Assert(isMember(getNode())); } - static bool isMember(Node n) { return n.getKind() == kind::CONST_RATIONAL; } + static bool isMember(Node n) + { + Kind k = n.getKind(); + return k == kind::CONST_RATIONAL || k == kind::CONST_INTEGER; + } bool isNormalForm() { return isMember(getNode()); } static Constant mkConstant(Node n) { - Assert(n.getKind() == kind::CONST_RATIONAL); + Assert(n.getKind() == kind::CONST_RATIONAL + || n.getKind() == kind::CONST_INTEGER); return Constant(n); } @@ -633,9 +639,8 @@ private: } static bool multStructured(Node n) { - return n.getKind() == kind::MULT && - n[0].getKind() == kind::CONST_RATIONAL && - n.getNumChildren() == 2; + return n.getKind() == kind::MULT && n[0].isConst() + && n.getNumChildren() == 2; } Monomial(const Constant& c): @@ -794,7 +799,7 @@ private: bool d_singleton; Polynomial(TNode n) : NodeWrapper(n), d_singleton(Monomial::isMember(n)) { - Assert(isMember(getNode())); + Assert(isMember(getNode())) << "Bad polynomial member " << n; } static Node makePlusNode(const std::vector& m) { diff --git a/src/theory/arith/rewriter/node_utils.h b/src/theory/arith/rewriter/node_utils.h index bd6bff961..54f8a1696 100644 --- a/src/theory/arith/rewriter/node_utils.h +++ b/src/theory/arith/rewriter/node_utils.h @@ -91,15 +91,7 @@ inline Node mkConst(const Integer& value) { return NodeManager::currentNM()->mkConstInt(value); } -/** Create an integer or rational constant node */ -inline Node mkConst(const Rational& value) -{ - if (value.isIntegral()) - { - return NodeManager::currentNM()->mkConstInt(value); - } - return NodeManager::currentNM()->mkConstReal(value); -} + /** Create a real algebraic number node */ inline Node mkConst(const RealAlgebraicNumber& value) { diff --git a/src/theory/arith/theory_arith_type_rules.cpp b/src/theory/arith/theory_arith_type_rules.cpp index eaf3315cb..31dede3a9 100644 --- a/src/theory/arith/theory_arith_type_rules.cpp +++ b/src/theory/arith/theory_arith_type_rules.cpp @@ -30,10 +30,7 @@ TypeNode ArithConstantTypeRule::computeType(NodeManager* nodeManager, { return nodeManager->integerType(); } - else - { - return nodeManager->realType(); - } + return nodeManager->realType(); } TypeNode ArithRealAlgebraicNumberOpTypeRule::computeType( diff --git a/src/theory/quantifiers/instantiate.cpp b/src/theory/quantifiers/instantiate.cpp index 23abba94a..bc0b2ba97 100644 --- a/src/theory/quantifiers/instantiate.cpp +++ b/src/theory/quantifiers/instantiate.cpp @@ -152,7 +152,7 @@ bool Instantiate::addInstantiation(Node q, << terms[i] << std::endl; bad_inst = true; } - else if (!terms[i].getType().isSubtypeOf(q[0][i].getType())) + else if (terms[i].getType() != q[0][i].getType()) { Trace("inst") << "***& inst bad type : " << terms[i] << " " << terms[i].getType() << "/" << q[0][i].getType() @@ -761,7 +761,7 @@ Node Instantiate::ensureType(Node n, TypeNode tn) Trace("inst-add-debug2") << "Ensure " << n << " : " << tn << std::endl; TypeNode ntn = n.getType(); Assert(ntn.isComparableTo(tn)); - if (ntn.isSubtypeOf(tn)) + if (ntn == tn) { return n; } @@ -769,6 +769,10 @@ Node Instantiate::ensureType(Node n, TypeNode tn) { return NodeManager::currentNM()->mkNode(TO_INTEGER, n); } + else if (tn.isReal()) + { + return NodeManager::currentNM()->mkNode(TO_REAL, n); + } return Node::null(); }