-/********************* */
-/*! \file arith_rewriter.cpp
- ** \verbatim
- ** Top contributors (to current version):
- ** Tim King, Andrew Reynolds, Morgan Deters
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
- ** All rights reserved. See the file COPYING in the top-level source
- ** directory for licensing information.\endverbatim
- **
- ** \brief [[ Add one-line brief description here ]]
- **
- ** [[ Add lengthier description here ]]
- ** \todo document this file
- **/
+/******************************************************************************
+ * Top contributors (to current version):
+ * Andrew Reynolds, Tim King, Morgan Deters
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved. See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * [[ Add one-line brief description here ]]
+ *
+ * [[ Add lengthier description here ]]
+ * \todo document this file
+ */
+#include "theory/arith/arith_rewriter.h"
+
+#include <optional>
#include <set>
+#include <sstream>
#include <stack>
#include <vector>
+#include "expr/algorithm/flatten.h"
#include "smt/logic_exception.h"
#include "theory/arith/arith_msum.h"
-#include "theory/arith/arith_rewriter.h"
#include "theory/arith/arith_utilities.h"
#include "theory/arith/normal_form.h"
+#include "theory/arith/operator_elim.h"
+#include "theory/arith/rewriter/addition.h"
+#include "theory/arith/rewriter/node_utils.h"
+#include "theory/arith/rewriter/ordering.h"
#include "theory/theory.h"
+#include "util/bitvector.h"
+#include "util/divisible.h"
+#include "util/iand.h"
+#include "util/real_algebraic_number.h"
+
+using namespace cvc5::kind;
-namespace CVC4 {
+namespace cvc5 {
namespace theory {
namespace arith {
+namespace {
+
+template <typename L, typename R>
+bool evaluateRelation(Kind rel, const L& l, const R& r)
+{
+ switch (rel)
+ {
+ case Kind::LT: return l < r;
+ case Kind::LEQ: return l <= r;
+ case Kind::EQUAL: return l == r;
+ case Kind::GEQ: return l >= r;
+ case Kind::GT: return l > r;
+ default: Unreachable(); return false;
+ }
+}
+
+/**
+ * Check whether the parent has a child that is a constant zero.
+ * If so, return this child. Otherwise, return std::nullopt.
+ */
+template <typename Iterable>
+std::optional<TNode> getZeroChild(const Iterable& parent)
+{
+ for (const auto& node : parent)
+ {
+ if (node.isConst() && node.template getConst<Rational>().isZero())
+ {
+ return node;
+ }
+ }
+ return std::nullopt;
+}
+
+} // namespace
+
+ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
+
+RewriteResponse ArithRewriter::preRewrite(TNode t)
+{
+ Trace("arith-rewriter") << "preRewrite(" << t << ")" << std::endl;
+ if (isAtom(t))
+ {
+ auto res = preRewriteAtom(t);
+ Trace("arith-rewriter")
+ << res.d_status << " -> " << res.d_node << std::endl;
+ return res;
+ }
+ auto res = preRewriteTerm(t);
+ Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
+ return res;
+}
+
+RewriteResponse ArithRewriter::postRewrite(TNode t)
+{
+ Trace("arith-rewriter") << "postRewrite(" << t << ")" << std::endl;
+ if (isAtom(t))
+ {
+ auto res = postRewriteAtom(t);
+ Trace("arith-rewriter")
+ << res.d_status << " -> " << res.d_node << std::endl;
+ return res;
+ }
+ auto res = postRewriteTerm(t);
+ Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
+ return res;
+}
+
+RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
+{
+ Assert(isAtom(atom));
+
+ NodeManager* nm = NodeManager::currentNM();
+
+ if (isRelationOperator(atom.getKind()) && atom[0] == atom[1])
+ {
+ switch (atom.getKind())
+ {
+ case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
+ case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+ case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+ case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+ case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
+ default:;
+ }
+ }
+
+ switch (atom.getKind())
+ {
+ case Kind::GT:
+ return RewriteResponse(REWRITE_DONE,
+ nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode());
+ case Kind::LT:
+ return RewriteResponse(REWRITE_DONE,
+ nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode());
+ case Kind::IS_INTEGER:
+ if (atom[0].getType().isInteger())
+ {
+ return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+ }
+ break;
+ case Kind::DIVISIBLE:
+ if (atom.getOperator().getConst<Divisible>().k.isOne())
+ {
+ return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+ }
+ break;
+ default:;
+ }
+
+ return RewriteResponse(REWRITE_DONE, atom);
+}
+
+RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
+{
+ Assert(isAtom(atom));
+ if (atom.getKind() == kind::IS_INTEGER)
+ {
+ return rewriteExtIntegerOp(atom);
+ }
+ else if (atom.getKind() == kind::DIVISIBLE)
+ {
+ if (atom[0].isConst())
+ {
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(bool(
+ (atom[0].getConst<Rational>()
+ / atom.getOperator().getConst<Divisible>().k)
+ .isIntegral())));
+ }
+ if (atom.getOperator().getConst<Divisible>().k.isOne())
+ {
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConst(true));
+ }
+ NodeManager* nm = NodeManager::currentNM();
+ return RewriteResponse(
+ REWRITE_AGAIN,
+ nm->mkNode(kind::EQUAL,
+ nm->mkNode(kind::INTS_MODULUS_TOTAL,
+ atom[0],
+ nm->mkConstInt(Rational(
+ atom.getOperator().getConst<Divisible>().k))),
+ nm->mkConstInt(Rational(0))));
+ }
+
+ // left |><| right
+ TNode left = atom[0];
+ TNode right = atom[1];
+
+ auto* nm = NodeManager::currentNM();
+ if (left.isConst())
+ {
+ const Rational& l = left.getConst<Rational>();
+ if (right.isConst())
+ {
+ const Rational& r = right.getConst<Rational>();
+ return RewriteResponse(
+ REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+ }
+ else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ {
+ const RealAlgebraicNumber& r =
+ right.getOperator().getConst<RealAlgebraicNumber>();
+ return RewriteResponse(
+ REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+ }
+ }
+ else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ {
+ const RealAlgebraicNumber& l =
+ left.getOperator().getConst<RealAlgebraicNumber>();
+ if (right.isConst())
+ {
+ const Rational& r = right.getConst<Rational>();
+ return RewriteResponse(
+ REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+ }
+ else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ {
+ const RealAlgebraicNumber& r =
+ right.getOperator().getConst<RealAlgebraicNumber>();
+ return RewriteResponse(
+ REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+ }
+ }
+
+ Polynomial pleft = Polynomial::parsePolynomial(left);
+ Polynomial pright = Polynomial::parsePolynomial(right);
+
+ Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
+ Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
+
+ Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
+ Assert(cmp.isNormalForm());
+ return RewriteResponse(REWRITE_DONE, cmp.getNode());
+}
+
bool ArithRewriter::isAtom(TNode n) {
Kind k = n.getKind();
return arith::isRelationOperator(k) || k == kind::IS_INTEGER
RewriteResponse ArithRewriter::rewriteConstant(TNode t){
Assert(t.isConst());
- Assert(t.getKind() == kind::CONST_RATIONAL);
+ Assert(t.getKind() == CONST_RATIONAL || t.getKind() == CONST_INTEGER);
return RewriteResponse(REWRITE_DONE, t);
}
-RewriteResponse ArithRewriter::rewriteVariable(TNode t){
- Assert(t.isVar());
+RewriteResponse ArithRewriter::rewriteRAN(TNode t)
+{
+ Assert(t.getKind() == REAL_ALGEBRAIC_NUMBER);
+
+ const RealAlgebraicNumber& r =
+ t.getOperator().getConst<RealAlgebraicNumber>();
+ if (r.isRational())
+ {
+ return RewriteResponse(
+ REWRITE_DONE, NodeManager::currentNM()->mkConstReal(r.toRational()));
+ }
return RewriteResponse(REWRITE_DONE, t);
}
-RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
- Assert(t.getKind()== kind::MINUS);
+RewriteResponse ArithRewriter::rewriteVariable(TNode t)
+{
+ Assert(t.isVar());
- if(pre){
- if(t[0] == t[1]){
- Rational zero(0);
- Node zeroNode = mkRationalNode(zero);
- return RewriteResponse(REWRITE_DONE, zeroNode);
- }else{
- Node noMinus = makeSubtractionNode(t[0],t[1]);
- return RewriteResponse(REWRITE_DONE, noMinus);
- }
- }else{
- Polynomial minuend = Polynomial::parsePolynomial(t[0]);
- Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
- Polynomial diff = minuend - subtrahend;
- return RewriteResponse(REWRITE_DONE, diff.getNode());
- }
+ return RewriteResponse(REWRITE_DONE, t);
}
-RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
- Assert(t.getKind()== kind::UMINUS);
+RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
+{
+ Assert(t.getKind() == kind::NEG);
- if(t[0].getKind() == kind::CONST_RATIONAL){
+ if (t[0].isConst())
+ {
Rational neg = -(t[0].getConst<Rational>());
- return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(neg));
+ }
+ if (rewriter::isRAN(t[0]))
+ {
+ return RewriteResponse(REWRITE_DONE,
+ rewriter::mkConst(-rewriter::getRAN(t[0])));
}
- Node noUminus = makeUnaryMinusNode(t[0]);
- if(pre)
+ auto* nm = NodeManager::currentNM();
+ Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[0]);
+ if (pre)
return RewriteResponse(REWRITE_DONE, noUminus);
else
return RewriteResponse(REWRITE_AGAIN, noUminus);
}
+RewriteResponse ArithRewriter::rewriteSub(TNode t)
+{
+ Assert(t.getKind() == kind::SUB);
+ Assert(t.getNumChildren() == 2);
+
+ if (t[0] == t[1])
+ {
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+ }
+ auto* nm = NodeManager::currentNM();
+ return RewriteResponse(
+ REWRITE_AGAIN_FULL,
+ nm->mkNode(Kind::ADD,
+ t[0],
+ nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[1])));
+}
+
RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
if(t.isConst()){
return rewriteConstant(t);
return rewriteVariable(t);
}else{
switch(Kind k = t.getKind()){
- case kind::MINUS:
- return rewriteMinus(t, true);
- case kind::UMINUS:
- return rewriteUMinus(t, true);
- case kind::DIVISION:
- case kind::DIVISION_TOTAL:
- return rewriteDiv(t,true);
- case kind::PLUS:
- return preRewritePlus(t);
- case kind::MULT:
- case kind::NONLINEAR_MULT:
- return preRewriteMult(t);
- case kind::EXPONENTIAL:
- case kind::SINE:
- case kind::COSINE:
- case kind::TANGENT:
- case kind::COSECANT:
- case kind::SECANT:
- case kind::COTANGENT:
- case kind::ARCSINE:
- case kind::ARCCOSINE:
- case kind::ARCTANGENT:
- case kind::ARCCOSECANT:
- case kind::ARCSECANT:
- case kind::ARCCOTANGENT:
- case kind::SQRT: return preRewriteTranscendental(t);
- case kind::INTS_DIVISION:
- case kind::INTS_MODULUS:
- return RewriteResponse(REWRITE_DONE, t);
- case kind::INTS_DIVISION_TOTAL:
- case kind::INTS_MODULUS_TOTAL:
- return rewriteIntsDivModTotal(t,true);
- case kind::ABS:
- if(t[0].isConst()) {
- const Rational& rat = t[0].getConst<Rational>();
- if(rat >= 0) {
- return RewriteResponse(REWRITE_DONE, t[0]);
- } else {
- return RewriteResponse(REWRITE_DONE,
- NodeManager::currentNM()->mkConst(-rat));
- }
- }
- return RewriteResponse(REWRITE_DONE, t);
- case kind::IS_INTEGER:
- case kind::TO_INTEGER:
- return RewriteResponse(REWRITE_DONE, t);
- case kind::TO_REAL:
- return RewriteResponse(REWRITE_DONE, t[0]);
- case kind::POW:
- return RewriteResponse(REWRITE_DONE, t);
- case kind::PI:
- return RewriteResponse(REWRITE_DONE, t);
- default:
- Unhandled(k);
+ case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
+ case kind::SUB: return rewriteSub(t);
+ case kind::NEG: return rewriteNeg(t, true);
+ case kind::DIVISION:
+ case kind::DIVISION_TOTAL: return rewriteDiv(t, true);
+ case kind::ADD: return preRewritePlus(t);
+ case kind::MULT:
+ case kind::NONLINEAR_MULT: return preRewriteMult(t);
+ case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
+ case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
+ case kind::EXPONENTIAL:
+ case kind::SINE:
+ case kind::COSINE:
+ case kind::TANGENT:
+ case kind::COSECANT:
+ case kind::SECANT:
+ case kind::COTANGENT:
+ case kind::ARCSINE:
+ case kind::ARCCOSINE:
+ case kind::ARCTANGENT:
+ case kind::ARCCOSECANT:
+ case kind::ARCSECANT:
+ case kind::ARCCOTANGENT:
+ case kind::SQRT: return preRewriteTranscendental(t);
+ case kind::INTS_DIVISION:
+ case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
+ case kind::INTS_DIVISION_TOTAL:
+ case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);
+ case kind::ABS: return rewriteAbs(t);
+ case kind::IS_INTEGER:
+ case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);
+ case kind::TO_REAL:
+ case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
+ case kind::POW: return RewriteResponse(REWRITE_DONE, t);
+ case kind::PI: return RewriteResponse(REWRITE_DONE, t);
+ default: Unhandled() << k;
}
}
}
}else if(t.isVar()){
return rewriteVariable(t);
}else{
+ Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl;
switch(t.getKind()){
- case kind::MINUS:
- return rewriteMinus(t, false);
- case kind::UMINUS:
- return rewriteUMinus(t, false);
- case kind::DIVISION:
- case kind::DIVISION_TOTAL:
- return rewriteDiv(t, false);
- case kind::PLUS:
- return postRewritePlus(t);
- case kind::MULT:
- case kind::NONLINEAR_MULT:
- return postRewriteMult(t);
- case kind::EXPONENTIAL:
- case kind::SINE:
- case kind::COSINE:
- case kind::TANGENT:
- case kind::COSECANT:
- case kind::SECANT:
- case kind::COTANGENT:
- case kind::ARCSINE:
- case kind::ARCCOSINE:
- case kind::ARCTANGENT:
- case kind::ARCCOSECANT:
- case kind::ARCSECANT:
- case kind::ARCCOTANGENT:
- case kind::SQRT: return postRewriteTranscendental(t);
- case kind::INTS_DIVISION:
- case kind::INTS_MODULUS:
- return RewriteResponse(REWRITE_DONE, t);
- case kind::INTS_DIVISION_TOTAL:
- case kind::INTS_MODULUS_TOTAL:
- return rewriteIntsDivModTotal(t, false);
- case kind::ABS:
- if(t[0].isConst()) {
- const Rational& rat = t[0].getConst<Rational>();
- if(rat >= 0) {
- return RewriteResponse(REWRITE_DONE, t[0]);
- } else {
- return RewriteResponse(REWRITE_DONE,
- NodeManager::currentNM()->mkConst(-rat));
- }
- }
- case kind::TO_REAL:
- return RewriteResponse(REWRITE_DONE, t[0]);
- case kind::TO_INTEGER:
- if(t[0].isConst()) {
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
- }
- if(t[0].getType().isInteger()) {
- return RewriteResponse(REWRITE_DONE, t[0]);
- }
- //Unimplemented("TO_INTEGER, nonconstant");
- //return rewriteToInteger(t);
- return RewriteResponse(REWRITE_DONE, t);
- case kind::IS_INTEGER:
- if(t[0].isConst()) {
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
- }
- if(t[0].getType().isInteger()) {
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
- }
- //Unimplemented("IS_INTEGER, nonconstant");
- //return rewriteIsInteger(t);
- return RewriteResponse(REWRITE_DONE, t);
- case kind::POW:
+ case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
+ case kind::SUB: return rewriteSub(t);
+ case kind::NEG: return rewriteNeg(t, false);
+ case kind::DIVISION:
+ case kind::DIVISION_TOTAL: return rewriteDiv(t, false);
+ case kind::ADD: return postRewritePlus(t);
+ case kind::MULT:
+ case kind::NONLINEAR_MULT: return postRewriteMult(t);
+ case kind::IAND: return postRewriteIAnd(t);
+ case kind::POW2: return postRewritePow2(t);
+ case kind::EXPONENTIAL:
+ case kind::SINE:
+ case kind::COSINE:
+ case kind::TANGENT:
+ case kind::COSECANT:
+ case kind::SECANT:
+ case kind::COTANGENT:
+ case kind::ARCSINE:
+ case kind::ARCCOSINE:
+ case kind::ARCTANGENT:
+ case kind::ARCCOSECANT:
+ case kind::ARCSECANT:
+ case kind::ARCCOTANGENT:
+ case kind::SQRT: return postRewriteTranscendental(t);
+ case kind::INTS_DIVISION:
+ case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
+ case kind::INTS_DIVISION_TOTAL:
+ case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
+ case kind::ABS: return rewriteAbs(t);
+ case kind::TO_REAL:
+ case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
+ case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
+ case kind::POW:
{
- if(t[1].getKind() == kind::CONST_RATIONAL){
+ if (t[1].isConst())
+ {
const Rational& exp = t[1].getConst<Rational>();
TNode base = t[0];
if(exp.sgn() == 0){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
+ return RewriteResponse(REWRITE_DONE,
+ NodeManager::currentNM()->mkConstRealOrInt(
+ t.getType(), Rational(1)));
}else if(exp.sgn() > 0 && exp.isIntegral()){
- CVC4::Rational r(INT_MAX);
- if( exp<r ){
+ cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
+ if (exp <= r)
+ {
unsigned num = exp.getNumerator().toUnsignedInt();
if( num==1 ){
return RewriteResponse(REWRITE_AGAIN, base);
}else{
- NodeBuilder<> nb(kind::MULT);
+ NodeBuilder nb(kind::MULT);
for(unsigned i=0; i < num; ++i){
nb << base;
}
}
}
}
+ else if (t[0].isConst()
+ && t[0].getConst<Rational>().getNumerator().toUnsignedInt()
+ == 2)
+ {
+ return RewriteResponse(
+ REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1]));
+ }
// Todo improve the exception thrown
std::stringstream ss;
- ss << "The POW(^) operator can only be used with a natural number ";
- ss << "in the exponent. Exception occurred in:" << std::endl;
+ ss << "The exponent of the POW(^) operator can only be a positive "
+ "integral constant below "
+ << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
+ ss << "Exception occurred in:" << std::endl;
ss << " " << t;
throw LogicException(ss.str());
}
}
-RewriteResponse ArithRewriter::preRewriteMult(TNode t){
- Assert(t.getKind()== kind::MULT || t.getKind()== kind::NONLINEAR_MULT);
+RewriteResponse ArithRewriter::preRewritePlus(TNode t){
+ Assert(t.getKind() == kind::ADD);
+ return RewriteResponse(REWRITE_DONE, expr::algorithm::flatten(t));
+}
+
+RewriteResponse ArithRewriter::postRewritePlus(TNode t){
+ Assert(t.getKind() == kind::ADD);
+ Assert(t.getNumChildren() > 1);
- if(t.getNumChildren() == 2){
- if(t[0].getKind() == kind::CONST_RATIONAL
- && t[0].getConst<Rational>().isOne()){
- return RewriteResponse(REWRITE_DONE, t[1]);
- }
- if(t[1].getKind() == kind::CONST_RATIONAL
- && t[1].getConst<Rational>().isOne()){
- return RewriteResponse(REWRITE_DONE, t[0]);
+ {
+ Node flat = expr::algorithm::flatten(t);
+ if (flat != t)
+ {
+ return RewriteResponse(REWRITE_AGAIN, flat);
}
}
- // Rewrite multiplications with a 0 argument and to 0
- for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
- if((*i).getKind() == kind::CONST_RATIONAL) {
- if((*i).getConst<Rational>().isZero()) {
- TNode zero = (*i);
- return RewriteResponse(REWRITE_DONE, zero);
+ Rational rational;
+ RealAlgebraicNumber ran;
+ std::vector<Monomial> monomials;
+ std::vector<Polynomial> polynomials;
+
+ for (const auto& child : t)
+ {
+ if (child.isConst())
+ {
+ if (child.getConst<Rational>().isZero())
+ {
+ continue;
}
+ rational += child.getConst<Rational>();
+ }
+ else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ {
+ ran += child.getOperator().getConst<RealAlgebraicNumber>();
+ }
+ else if (Monomial::isMember(child))
+ {
+ monomials.emplace_back(Monomial::parseMonomial(child));
+ }
+ else
+ {
+ polynomials.emplace_back(Polynomial::parsePolynomial(child));
}
}
- return RewriteResponse(REWRITE_DONE, t);
-}
-static bool canFlatten(Kind k, TNode t){
- for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
- TNode child = *i;
- if(child.getKind() == k){
- return true;
- }
+ if(!monomials.empty()){
+ Monomial::sort(monomials);
+ Monomial::combineAdjacentMonomials(monomials);
+ polynomials.emplace_back(Polynomial::mkPolynomial(monomials));
+ }
+ if (!rational.isZero())
+ {
+ polynomials.emplace_back(
+ Polynomial::mkPolynomial(Constant::mkConstant(rational)));
}
- return false;
-}
-static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
- if(t.getKind() == k){
- for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
- TNode child = *i;
- if(child.getKind() == k){
- flatten(pb, k, child);
- }else{
- pb.push_back(child);
- }
+ Polynomial poly = Polynomial::sumPolynomials(polynomials);
+
+ if (isZero(ran))
+ {
+ return RewriteResponse(REWRITE_DONE, poly.getNode());
+ }
+ if (poly.containsConstant())
+ {
+ ran += RealAlgebraicNumber(poly.getHead().getConstant().getValue());
+ if (!poly.isConstant())
+ {
+ poly = poly.getTail();
}
- }else{
- pb.push_back(t);
}
-}
-static Node flatten(Kind k, TNode t){
- std::vector<TNode> pb;
- flatten(pb, k, t);
- Assert(pb.size() >= 2);
- return NodeManager::currentNM()->mkNode(k, pb);
+ auto* nm = NodeManager::currentNM();
+ if (poly.isConstant())
+ {
+ return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
+ }
+ return RewriteResponse(
+ REWRITE_DONE,
+ nm->mkNode(Kind::ADD, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
}
-RewriteResponse ArithRewriter::preRewritePlus(TNode t){
- Assert(t.getKind()== kind::PLUS);
+RewriteResponse ArithRewriter::preRewriteMult(TNode node)
+{
+ Assert(node.getKind() == kind::MULT
+ || node.getKind() == kind::NONLINEAR_MULT);
- if(canFlatten(kind::PLUS, t)){
- return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
- }else{
- return RewriteResponse(REWRITE_DONE, t);
+ if (auto res = rewriter::getZeroChild(node); res)
+ {
+ return RewriteResponse(REWRITE_DONE, *res);
}
+ return RewriteResponse(REWRITE_DONE, node);
}
-RewriteResponse ArithRewriter::postRewritePlus(TNode t){
- Assert(t.getKind()== kind::PLUS);
+RewriteResponse ArithRewriter::postRewriteMult(TNode t){
+ Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
+ Assert(t.getNumChildren() >= 2);
- std::vector<Monomial> monomials;
- std::vector<Polynomial> polynomials;
+ std::vector<TNode> children;
+ expr::algorithm::flatten(t, children, Kind::MULT, Kind::NONLINEAR_MULT);
- for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
- TNode curr = *i;
- if(Monomial::isMember(curr)){
- monomials.push_back(Monomial::parseMonomial(curr));
- }else{
- polynomials.push_back(Polynomial::parsePolynomial(curr));
- }
+ if (auto res = rewriter::getZeroChild(children); res)
+ {
+ return RewriteResponse(REWRITE_DONE, *res);
}
- if(!monomials.empty()){
- Monomial::sort(monomials);
- Monomial::combineAdjacentMonomials(monomials);
- polynomials.push_back(Polynomial::mkPolynomial(monomials));
+ // Distribute over addition
+ if (std::any_of(children.begin(), children.end(), [](TNode child) {
+ return child.getKind() == Kind::ADD;
+ }))
+ {
+ return RewriteResponse(REWRITE_DONE,
+ rewriter::distributeMultiplication(children));
}
- Polynomial res = Polynomial::sumPolynomials(polynomials);
+ RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
+ std::vector<Node> leafs;
- return RewriteResponse(REWRITE_DONE, res.getNode());
+ for (const auto& child : children)
+ {
+ if (child.isConst())
+ {
+ if (child.getConst<Rational>().isZero())
+ {
+ return RewriteResponse(REWRITE_DONE, child);
+ }
+ ran *= child.getConst<Rational>();
+ }
+ else if (rewriter::isRAN(child))
+ {
+ ran *= rewriter::getRAN(child);
+ }
+ else
+ {
+ leafs.emplace_back(child);
+ }
+ }
+
+ return RewriteResponse(REWRITE_DONE,
+ rewriter::mkMultTerm(ran, std::move(leafs)));
}
-RewriteResponse ArithRewriter::postRewriteMult(TNode t){
- Assert(t.getKind()== kind::MULT || t.getKind()==kind::NONLINEAR_MULT);
+RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
+{
+ Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
+ Assert(t.getNumChildren() == 2);
- Polynomial res = Polynomial::mkOne();
+ Node left = t[0];
+ Node right = t[1];
+ if (right.isConst())
+ {
+ NodeManager* nm = NodeManager::currentNM();
+ const Rational& den = right.getConst<Rational>();
+
+ if (den.isZero())
+ {
+ if (t.getKind() == kind::DIVISION_TOTAL)
+ {
+ return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
+ }
+ else
+ {
+ // This is unsupported, but this is not a good place to complain
+ return RewriteResponse(REWRITE_DONE, t);
+ }
+ }
+ Assert(den != Rational(0));
- for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
- Node curr = *i;
- Polynomial currPoly = Polynomial::parsePolynomial(curr);
+ if (left.isConst())
+ {
+ const Rational& num = left.getConst<Rational>();
+ return RewriteResponse(REWRITE_DONE, nm->mkConstReal(num / den));
+ }
+ if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ {
+ const RealAlgebraicNumber& num =
+ left.getOperator().getConst<RealAlgebraicNumber>();
+ return RewriteResponse(
+ REWRITE_DONE,
+ nm->mkRealAlgebraicNumber(num / RealAlgebraicNumber(den)));
+ }
- res = res * currPoly;
+ Node result = nm->mkConstReal(den.inverse());
+ Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
+ if (pre)
+ {
+ return RewriteResponse(REWRITE_DONE, mult);
+ }
+ else
+ {
+ return RewriteResponse(REWRITE_AGAIN, mult);
+ }
}
+ if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ {
+ NodeManager* nm = NodeManager::currentNM();
+ const RealAlgebraicNumber& den =
+ right.getOperator().getConst<RealAlgebraicNumber>();
+ if (left.isConst())
+ {
+ const Rational& num = left.getConst<Rational>();
+ return RewriteResponse(
+ REWRITE_DONE,
+ nm->mkRealAlgebraicNumber(RealAlgebraicNumber(num) / den));
+ }
+ if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ {
+ const RealAlgebraicNumber& num =
+ left.getOperator().getConst<RealAlgebraicNumber>();
+ return RewriteResponse(REWRITE_DONE,
+ nm->mkRealAlgebraicNumber(num / den));
+ }
- return RewriteResponse(REWRITE_DONE, res.getNode());
+ Node result = nm->mkRealAlgebraicNumber(inverse(den));
+ Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
+ if (pre)
+ {
+ return RewriteResponse(REWRITE_DONE, mult);
+ }
+ else
+ {
+ return RewriteResponse(REWRITE_AGAIN, mult);
+ }
+ }
+ return RewriteResponse(REWRITE_DONE, t);
}
+RewriteResponse ArithRewriter::rewriteAbs(TNode t)
+{
+ Assert(t.getKind() == Kind::ABS);
+ Assert(t.getNumChildren() == 1);
-RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
+ if (t[0].isConst())
+ {
+ const Rational& rat = t[0].getConst<Rational>();
+ if (rat >= 0)
+ {
+ return RewriteResponse(REWRITE_DONE, t[0]);
+ }
+ return RewriteResponse(
+ REWRITE_DONE,
+ NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
+ }
+ if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ {
+ const RealAlgebraicNumber& ran =
+ t[0].getOperator().getConst<RealAlgebraicNumber>();
+ if (ran >= RealAlgebraicNumber())
+ {
+ return RewriteResponse(REWRITE_DONE, t[0]);
+ }
+ return RewriteResponse(
+ REWRITE_DONE, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran));
+ }
return RewriteResponse(REWRITE_DONE, t);
}
-RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
- Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
+RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
+{
NodeManager* nm = NodeManager::currentNM();
- switch( t.getKind() ){
- case kind::EXPONENTIAL: {
- if(t[0].getKind() == kind::CONST_RATIONAL){
- Node one = nm->mkConst(Rational(1));
- if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
- return RewriteResponse(
- REWRITE_AGAIN,
- nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
- }else{
- return RewriteResponse(REWRITE_DONE, t);
+ Kind k = t.getKind();
+ if (k == kind::INTS_MODULUS)
+ {
+ if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
+ {
+ // can immediately replace by INTS_MODULUS_TOTAL
+ Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
+ return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
+ }
+ }
+ if (k == kind::INTS_DIVISION)
+ {
+ if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
+ {
+ // can immediately replace by INTS_DIVISION_TOTAL
+ Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
+ return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
+ }
+ }
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t)
+{
+ Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER);
+ bool isPred = t.getKind() == kind::IS_INTEGER;
+ NodeManager* nm = NodeManager::currentNM();
+ if (t[0].isConst())
+ {
+ Node ret;
+ if (isPred)
+ {
+ ret = nm->mkConst(t[0].getConst<Rational>().isIntegral());
+ }
+ else
+ {
+ ret = nm->mkConstInt(Rational(t[0].getConst<Rational>().floor()));
+ }
+ return returnRewrite(t, ret, Rewrite::INT_EXT_CONST);
+ }
+ if (t[0].getType().isInteger())
+ {
+ Node ret = isPred ? nm->mkConst(true) : Node(t[0]);
+ return returnRewrite(t, ret, Rewrite::INT_EXT_INT);
+ }
+ if (t[0].getKind() == kind::PI)
+ {
+ Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3));
+ return returnRewrite(t, ret, Rewrite::INT_EXT_PI);
+ }
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
+{
+ if (pre)
+ {
+ // do not rewrite at prewrite.
+ return RewriteResponse(REWRITE_DONE, t);
+ }
+ NodeManager* nm = NodeManager::currentNM();
+ Kind k = t.getKind();
+ Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
+ TNode n = t[0];
+ TNode d = t[1];
+ bool dIsConstant = d.isConst();
+ if (dIsConstant && d.getConst<Rational>().isZero())
+ {
+ // (div x 0) ---> 0 or (mod x 0) ---> 0
+ return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO);
+ }
+ else if (dIsConstant && d.getConst<Rational>().isOne())
+ {
+ if (k == kind::INTS_MODULUS_TOTAL)
+ {
+ // (mod x 1) --> 0
+ return returnRewrite(t, nm->mkConstInt(0), Rewrite::MOD_BY_ONE);
+ }
+ Assert(k == kind::INTS_DIVISION_TOTAL);
+ // (div x 1) --> x
+ return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
+ }
+ else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
+ {
+ // pull negation
+ // (div x (- c)) ---> (- (div x c))
+ // (mod x (- c)) ---> (mod x c)
+ Node nn = nm->mkNode(k, t[0], nm->mkConstInt(-t[1].getConst<Rational>()));
+ Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
+ ? nm->mkNode(kind::NEG, nn)
+ : nn;
+ return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
+ }
+ else if (dIsConstant && n.isConst())
+ {
+ Assert(d.getConst<Rational>().isIntegral());
+ Assert(n.getConst<Rational>().isIntegral());
+ Assert(!d.getConst<Rational>().isZero());
+ Integer di = d.getConst<Rational>().getNumerator();
+ Integer ni = n.getConst<Rational>().getNumerator();
+
+ bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+
+ Integer result = isDiv ? ni.euclidianDivideQuotient(di)
+ : ni.euclidianDivideRemainder(di);
+
+ // constant evaluation
+ // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
+ Node resultNode = nm->mkConstInt(Rational(result));
+ return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
+ }
+ if (k == kind::INTS_MODULUS_TOTAL)
+ {
+ // Note these rewrites do not need to account for modulus by zero as being
+ // a UF, which is handled by the reduction of INTS_MODULUS.
+ Kind k0 = t[0].getKind();
+ if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
+ {
+ // (mod (mod x c) c) --> (mod x c)
+ return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
+ }
+ else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::ADD)
+ {
+ // can drop all
+ std::vector<Node> newChildren;
+ bool childChanged = false;
+ for (const Node& tc : t[0])
+ {
+ if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
+ {
+ newChildren.push_back(tc[0]);
+ childChanged = true;
+ continue;
+ }
+ newChildren.push_back(tc);
}
- }else if(t[0].getKind() == kind::PLUS ){
- std::vector<Node> product;
- for( unsigned i=0; i<t[0].getNumChildren(); i++ ){
- product.push_back(nm->mkNode(kind::EXPONENTIAL, t[0][i]));
+ if (childChanged)
+ {
+ // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
+ // op is one of { NONLINEAR_MULT, MULT, ADD }.
+ Node ret = nm->mkNode(k0, newChildren);
+ ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
+ return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
}
- return RewriteResponse(REWRITE_AGAIN, nm->mkNode(kind::MULT, product));
}
}
- break;
- case kind::SINE:
- if(t[0].getKind() == kind::CONST_RATIONAL){
- const Rational& rat = t[0].getConst<Rational>();
- if(rat.sgn() == 0){
- return RewriteResponse(REWRITE_DONE, nm->mkConst(Rational(0)));
+ else
+ {
+ Assert(k == kind::INTS_DIVISION_TOTAL);
+ // Note these rewrites do not need to account for division by zero as being
+ // a UF, which is handled by the reduction of INTS_DIVISION.
+ if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
+ {
+ // (div (mod x c) c) --> 0
+ Node ret = nm->mkConstInt(0);
+ return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
+ }
+ }
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
+{
+ Assert(t.getKind() == kind::IAND);
+ size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
+ NodeManager* nm = NodeManager::currentNM();
+ // if constant, we eliminate
+ if (t[0].isConst() && t[1].isConst())
+ {
+ Node iToBvop = nm->mkConst(IntToBitVector(bsize));
+ Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
+ Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
+ Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
+ Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
+ return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+ }
+ else if (t[0] > t[1])
+ {
+ // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
+ Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
+ return RewriteResponse(REWRITE_AGAIN, ret);
+ }
+ else if (t[0] == t[1])
+ {
+ // ((_ iand k) x x) ---> (mod x 2^k)
+ Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
+ Node ret = nm->mkNode(kind::INTS_MODULUS, t[0], twok);
+ return RewriteResponse(REWRITE_AGAIN, ret);
+ }
+ // simplifications involving constants
+ for (unsigned i = 0; i < 2; i++)
+ {
+ if (!t[i].isConst())
+ {
+ continue;
+ }
+ if (t[i].getConst<Rational>().sgn() == 0)
+ {
+ // ((_ iand k) 0 y) ---> 0
+ return RewriteResponse(REWRITE_DONE, t[i]);
+ }
+ if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
+ {
+ // ((_ iand k) 111...1 y) ---> (mod y 2^k)
+ Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
+ Node ret = nm->mkNode(kind::INTS_MODULUS, t[1 - i], twok);
+ return RewriteResponse(REWRITE_AGAIN, ret);
+ }
+ }
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::postRewritePow2(TNode t)
+{
+ Assert(t.getKind() == kind::POW2);
+ NodeManager* nm = NodeManager::currentNM();
+ // if constant, we eliminate
+ if (t[0].isConst())
+ {
+ // pow2 is only supported for integers
+ Assert(t[0].getType().isInteger());
+ Integer i = t[0].getConst<Rational>().getNumerator();
+ if (i < 0)
+ {
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+ }
+ // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
+ Node two = rewriter::mkConst(Integer(2));
+ Node ret = nm->mkNode(kind::POW, two, t[0]);
+ return RewriteResponse(REWRITE_AGAIN, ret);
+ }
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t)
+{
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t)
+{
+ Trace("arith-tf-rewrite")
+ << "Rewrite transcendental function : " << t << std::endl;
+ NodeManager* nm = NodeManager::currentNM();
+ switch (t.getKind())
+ {
+ case kind::EXPONENTIAL:
+ {
+ if (t[0].isConst())
+ {
+ Node one = rewriter::mkConst(Integer(1));
+ if (t[0].getConst<Rational>().sgn() >= 0 && t[0].getType().isInteger()
+ && t[0] != one)
+ {
+ return RewriteResponse(
+ REWRITE_AGAIN,
+ nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
+ }
+ else
+ {
+ return RewriteResponse(REWRITE_DONE, t);
+ }
}
- else if (rat.sgn() == -1)
+ else if (t[0].getKind() == kind::ADD)
{
- Node ret =
- nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, nm->mkConst(-rat)));
- return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+ std::vector<Node> product;
+ for (const Node tc : t[0])
+ {
+ product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
+ }
+ // We need to do a full rewrite here, since we can get exponentials of
+ // constants, e.g. when we are rewriting exp(2 + x)
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::MULT, product));
}
- }else{
- // get the factor of PI in the argument
- Node pi_factor;
- Node pi;
- Node rem;
- std::map<Node, Node> msum;
- if (ArithMSum::getMonomialSum(t[0], msum))
+ }
+ break;
+ case kind::SINE:
+ if (t[0].isConst())
{
- pi = mkPi();
- std::map<Node, Node>::iterator itm = msum.find(pi);
- if (itm != msum.end())
+ const Rational& rat = t[0].getConst<Rational>();
+ if (rat.sgn() == 0)
{
- if (itm->second.isNull())
- {
- pi_factor = mkRationalNode(Rational(1));
- }
- else
- {
- pi_factor = itm->second;
- }
- msum.erase(pi);
- if (!msum.empty())
- {
- rem = ArithMSum::mkNode(msum);
- }
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+ }
+ else if (rat.sgn() == -1)
+ {
+ Node ret = nm->mkNode(
+ kind::NEG, nm->mkNode(kind::SINE, rewriter::mkConst(-rat)));
+ return RewriteResponse(REWRITE_AGAIN_FULL, ret);
}
}
- else
+ else if ((t[0].getKind() == MULT || t[0].getKind() == NONLINEAR_MULT)
+ && t[0][0].isConst() && t[0][0].getConst<Rational>().sgn() == -1)
{
- Assert(false);
+ // sin(-n*x) ---> -sin(n*x)
+ std::vector<Node> mchildren(t[0].begin(), t[0].end());
+ mchildren[0] = nm->mkConstReal(-t[0][0].getConst<Rational>());
+ Node ret = nm->mkNode(
+ kind::NEG,
+ nm->mkNode(kind::SINE, nm->mkNode(t[0].getKind(), mchildren)));
+ return RewriteResponse(REWRITE_AGAIN_FULL, ret);
}
-
- // if there is a factor of PI
- if( !pi_factor.isNull() ){
- Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
- Rational r = pi_factor.getConst<Rational>();
- Rational r_abs = r.abs();
- Rational rone = Rational(1);
- Node ntwo = mkRationalNode(Rational(2));
- if (r_abs > rone)
+ else
+ {
+ // get the factor of PI in the argument
+ Node pi_factor;
+ Node pi;
+ Node rem;
+ std::map<Node, Node> msum;
+ if (ArithMSum::getMonomialSum(t[0], msum))
{
- //add/substract 2*pi beyond scope
- Node ra_div_two = nm->mkNode(
- kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo);
- Node new_pi_factor;
- if( r.sgn()==1 ){
- new_pi_factor =
- nm->mkNode(kind::MINUS,
- pi_factor,
- nm->mkNode(kind::MULT, ntwo, ra_div_two));
- }else{
- Assert( r.sgn()==-1 );
- new_pi_factor =
- nm->mkNode(kind::PLUS,
- pi_factor,
- nm->mkNode(kind::MULT, ntwo, ra_div_two));
- }
- Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
- if (!rem.isNull())
+ pi = mkPi();
+ std::map<Node, Node>::iterator itm = msum.find(pi);
+ if (itm != msum.end())
{
- new_arg = nm->mkNode(kind::PLUS, new_arg, rem);
+ if (itm->second.isNull())
+ {
+ pi_factor = rewriter::mkConst(Integer(1));
+ }
+ else
+ {
+ pi_factor = itm->second;
+ }
+ msum.erase(pi);
+ if (!msum.empty())
+ {
+ rem = ArithMSum::mkNode(t[0].getType(), msum);
+ }
}
- // sin( 2*n*PI + x ) = sin( x )
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::SINE, new_arg));
}
- else if (r_abs == rone)
+ else
{
- // sin( PI + x ) = -sin( x )
- if (rem.isNull())
- {
- return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
- }
- else
- {
- return RewriteResponse(
- REWRITE_AGAIN_FULL,
- nm->mkNode(kind::UMINUS, nm->mkNode(kind::SINE, rem)));
- }
+ Assert(false);
}
- else if (rem.isNull())
+
+ // if there is a factor of PI
+ if (!pi_factor.isNull())
{
- // other rational cases based on Niven's theorem
- // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
- Integer one = Integer(1);
- Integer two = Integer(2);
- Integer six = Integer(6);
- if (r_abs.getDenominator() == two)
+ Trace("arith-tf-rewrite-debug")
+ << "Process pi factor = " << pi_factor << std::endl;
+ Rational r = pi_factor.getConst<Rational>();
+ Rational r_abs = r.abs();
+ Rational rone = Rational(1);
+ Rational rtwo = Rational(2);
+ if (r_abs > rone)
{
- Assert(r_abs.getNumerator() == one);
- return RewriteResponse(REWRITE_DONE,
- mkRationalNode(Rational(r.sgn())));
+ // add/substract 2*pi beyond scope
+ Rational ra_div_two = (r_abs + rone) / rtwo;
+ Node new_pi_factor;
+ if (r.sgn() == 1)
+ {
+ new_pi_factor = nm->mkConstReal(r - rtwo * ra_div_two.floor());
+ }
+ else
+ {
+ Assert(r.sgn() == -1);
+ new_pi_factor = nm->mkConstReal(r + rtwo * ra_div_two.floor());
+ }
+ Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
+ if (!rem.isNull())
+ {
+ new_arg = nm->mkNode(kind::ADD, new_arg, rem);
+ }
+ // sin( 2*n*PI + x ) = sin( x )
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::SINE, new_arg));
}
- else if (r_abs.getDenominator() == six)
+ else if (r_abs == rone)
{
- Integer five = Integer(5);
- if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
+ // sin( PI + x ) = -sin( x )
+ if (rem.isNull())
+ {
+ return RewriteResponse(REWRITE_DONE,
+ nm->mkConstReal(Rational(0)));
+ }
+ else
{
return RewriteResponse(
- REWRITE_DONE,
- mkRationalNode(Rational(r.sgn()) / Rational(2)));
+ REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::NEG, nm->mkNode(kind::SINE, rem)));
+ }
+ }
+ else if (rem.isNull())
+ {
+ // other rational cases based on Niven's theorem
+ // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
+ Integer one = Integer(1);
+ Integer two = Integer(2);
+ Integer six = Integer(6);
+ if (r_abs.getDenominator() == two)
+ {
+ Assert(r_abs.getNumerator() == one);
+ return RewriteResponse(REWRITE_DONE,
+ nm->mkConstReal(Rational(r.sgn())));
+ }
+ else if (r_abs.getDenominator() == six)
+ {
+ Integer five = Integer(5);
+ if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
+ {
+ return RewriteResponse(
+ REWRITE_DONE,
+ nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
+ }
}
}
}
}
+ break;
+ case kind::COSINE:
+ {
+ return RewriteResponse(
+ REWRITE_AGAIN_FULL,
+ nm->mkNode(
+ kind::SINE,
+ nm->mkNode(kind::SUB,
+ nm->mkNode(kind::MULT,
+ nm->mkConstReal(Rational(1) / Rational(2)),
+ mkPi()),
+ t[0])));
}
break;
- case kind::COSINE: {
- return RewriteResponse(
- REWRITE_AGAIN_FULL,
- nm->mkNode(kind::SINE,
- nm->mkNode(kind::MINUS,
- nm->mkNode(kind::MULT,
- nm->mkConst(Rational(1) / Rational(2)),
- mkPi()),
- t[0])));
- }
- break;
- case kind::TANGENT:
- {
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::DIVISION,
- nm->mkNode(kind::SINE, t[0]),
- nm->mkNode(kind::COSINE, t[0])));
- }
- break;
- case kind::COSECANT:
- {
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::DIVISION,
- mkRationalNode(Rational(1)),
- nm->mkNode(kind::SINE, t[0])));
- }
- break;
- case kind::SECANT:
- {
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::DIVISION,
- mkRationalNode(Rational(1)),
- nm->mkNode(kind::COSINE, t[0])));
- }
- break;
- case kind::COTANGENT:
- {
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::DIVISION,
- nm->mkNode(kind::COSINE, t[0]),
- nm->mkNode(kind::SINE, t[0])));
- }
- break;
- default:
- break;
- }
- return RewriteResponse(REWRITE_DONE, t);
-}
-
-RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
- if(atom.getKind() == kind::IS_INTEGER) {
- if(atom[0].isConst()) {
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
- }
- if(atom[0].getType().isInteger()) {
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
- }
- // not supported, but this isn't the right place to complain
- return RewriteResponse(REWRITE_DONE, atom);
- } else if(atom.getKind() == kind::DIVISIBLE) {
- if(atom[0].isConst()) {
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
- }
- if(atom.getOperator().getConst<Divisible>().k.isOne()) {
- return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
- }
- return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
- }
-
- // left |><| right
- TNode left = atom[0];
- TNode right = atom[1];
-
- Polynomial pleft = Polynomial::parsePolynomial(left);
- Polynomial pright = Polynomial::parsePolynomial(right);
-
- Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
- Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
-
- Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
- Assert(cmp.isNormalForm());
- return RewriteResponse(REWRITE_DONE, cmp.getNode());
-}
-
-RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
- Assert(isAtom(atom));
-
- NodeManager* currNM = NodeManager::currentNM();
-
- if(atom.getKind() == kind::EQUAL) {
- if(atom[0] == atom[1]) {
- return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
- }
- }else if(atom.getKind() == kind::GT){
- Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
- return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
- }else if(atom.getKind() == kind::LT){
- Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
- return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
- }else if(atom.getKind() == kind::IS_INTEGER){
- if(atom[0].getType().isInteger()){
- return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
+ case kind::TANGENT:
+ {
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::DIVISION,
+ nm->mkNode(kind::SINE, t[0]),
+ nm->mkNode(kind::COSINE, t[0])));
}
- }else if(atom.getKind() == kind::DIVISIBLE){
- if(atom.getOperator().getConst<Divisible>().k.isOne()){
- return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
+ break;
+ case kind::COSECANT:
+ {
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::DIVISION,
+ nm->mkConstReal(Rational(1)),
+ nm->mkNode(kind::SINE, t[0])));
}
- }
-
- return RewriteResponse(REWRITE_DONE, atom);
-}
-
-RewriteResponse ArithRewriter::postRewrite(TNode t){
- if(isTerm(t)){
- RewriteResponse response = postRewriteTerm(t);
- if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
- Polynomial::parsePolynomial(response.node);
+ break;
+ case kind::SECANT:
+ {
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::DIVISION,
+ nm->mkConstReal(Rational(1)),
+ nm->mkNode(kind::COSINE, t[0])));
}
- return response;
- }else if(isAtom(t)){
- RewriteResponse response = postRewriteAtom(t);
- if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
- Comparison::parseNormalForm(response.node);
+ break;
+ case kind::COTANGENT:
+ {
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::DIVISION,
+ nm->mkNode(kind::COSINE, t[0]),
+ nm->mkNode(kind::SINE, t[0])));
}
- return response;
- }else{
- Unreachable();
- }
-}
-
-RewriteResponse ArithRewriter::preRewrite(TNode t){
- if(isTerm(t)){
- return preRewriteTerm(t);
- }else if(isAtom(t)){
- return preRewriteAtom(t);
- }else{
- Unreachable();
+ break;
+ default: break;
}
+ return RewriteResponse(REWRITE_DONE, t);
}
-Node ArithRewriter::makeUnaryMinusNode(TNode n){
- Rational qNegOne(-1);
- return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
-}
-
-Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
- Node negR = makeUnaryMinusNode(r);
- Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
-
- return diff;
-}
-
-RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
- Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION);
-
- Node left = t[0];
- Node right = t[1];
- if(right.getKind() == kind::CONST_RATIONAL){
- const Rational& den = right.getConst<Rational>();
-
- if(den.isZero()){
- if(t.getKind() == kind::DIVISION_TOTAL){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }else{
- // This is unsupported, but this is not a good place to complain
- return RewriteResponse(REWRITE_DONE, t);
- }
- }
- Assert(den != Rational(0));
-
- if(left.getKind() == kind::CONST_RATIONAL){
- const Rational& num = left.getConst<Rational>();
- Rational div = num / den;
- Node result = mkRationalNode(div);
- return RewriteResponse(REWRITE_DONE, result);
- }
-
- Rational div = den.inverse();
-
- Node result = mkRationalNode(div);
-
- Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
- if(pre){
- return RewriteResponse(REWRITE_DONE, mult);
- }else{
- return RewriteResponse(REWRITE_AGAIN, mult);
- }
- }else{
- return RewriteResponse(REWRITE_DONE, t);
- }
+TrustNode ArithRewriter::expandDefinition(Node node)
+{
+ // call eliminate operators, to eliminate partial operators only
+ std::vector<SkolemLemma> lems;
+ TrustNode ret = d_opElim.eliminate(node, lems, true);
+ Assert(lems.empty());
+ return ret;
}
-RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
- Kind k = t.getKind();
- // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
- // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
-
- //Leaving the function as before (INTS_MODULUS can be handled),
- // but restricting its use here
- Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
- TNode n = t[0], d = t[1];
- bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
- if(dIsConstant && d.getConst<Rational>().isZero()){
- if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }else{
- // Do nothing for k == INTS_MODULUS
- return RewriteResponse(REWRITE_DONE, t);
- }
- }else if(dIsConstant && d.getConst<Rational>().isOne()){
- if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }else{
- Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
- return RewriteResponse(REWRITE_AGAIN, n);
- }
- }else if(dIsConstant && d.getConst<Rational>().isNegativeOne()){
- if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }else{
- Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
- return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::UMINUS, n));
- }
- }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
- Assert(d.getConst<Rational>().isIntegral());
- Assert(n.getConst<Rational>().isIntegral());
- Assert(!d.getConst<Rational>().isZero());
- Integer di = d.getConst<Rational>().getNumerator();
- Integer ni = n.getConst<Rational>().getNumerator();
-
- bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
-
- Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
-
- Node resultNode = mkRationalNode(Rational(result));
- return RewriteResponse(REWRITE_DONE, resultNode);
- }else{
- return RewriteResponse(REWRITE_DONE, t);
- }
+RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
+{
+ Trace("arith-rewriter") << "ArithRewriter : " << t << " == " << ret << " by "
+ << r << std::endl;
+ return RewriteResponse(REWRITE_AGAIN_FULL, ret);
}
-}/* CVC4::theory::arith namespace */
-}/* CVC4::theory namespace */
-}/* CVC4 namespace */
+} // namespace arith
+} // namespace theory
+} // namespace cvc5