#include "theory/arith/arith_utilities.h"
#include "theory/arith/normal_form.h"
#include "theory/arith/operator_elim.h"
+#include "theory/arith/rewriter/addition.h"
+#include "theory/arith/rewriter/node_utils.h"
+#include "theory/arith/rewriter/ordering.h"
#include "theory/theory.h"
#include "util/bitvector.h"
#include "util/divisible.h"
namespace {
-/**
- * Implements an ordering on arithmetic leaf nodes, excluding rationals. As this
- * comparator is meant to be used on children of Kind::NONLINEAR_MULT, we expect
- * rationals to be handled separately. Furthermore, we expect there to be only a
- * single real algebraic number.
- * It broadly categorizes leaf nodes into real algebraic numbers, integers,
- * variables, and the rest. The ordering is built as follows:
- * - real algebraic numbers come first
- * - real terms come before integer terms
- * - variables come before non-variable terms
- * - finally, fall back to node ordering
- */
-struct LeafNodeComparator
-{
- /** Implements operator<(a, b) as described above */
- bool operator()(TNode a, TNode b)
- {
- if (a == b) return false;
-
- bool aIsRAN = a.getKind() == Kind::REAL_ALGEBRAIC_NUMBER;
- bool bIsRAN = b.getKind() == Kind::REAL_ALGEBRAIC_NUMBER;
- if (aIsRAN != bIsRAN) return aIsRAN;
- Assert(!aIsRAN && !bIsRAN) << "real algebraic numbers should be combined";
-
- bool aIsInt = a.getType().isInteger();
- bool bIsInt = b.getType().isInteger();
- if (aIsInt != bIsInt) return !aIsInt;
-
- bool aIsVar = a.isVar();
- bool bIsVar = b.isVar();
- if (aIsVar != bIsVar) return aIsVar;
-
- return a < b;
- }
-};
-
-/**
- * Implements an ordering on arithmetic nonlinear multiplications. As we assume
- * rationals to be handled separately, we only consider Kind::NONLINEAR_MULT as
- * multiplication terms. For individual factors of the product, we rely on the
- * ordering from LeafNodeComparator. Furthermore, we expect products to be
- * sorted according to LeafNodeComparator. The ordering is built as follows:
- * - single factors come first (everything that is not NONLINEAR_MULT)
- * - multiplications with less factors come first
- * - multiplications are compared lexicographically
- */
-struct ProductNodeComparator
-{
- /** Implements operator<(a, b) as described above */
- bool operator()(TNode a, TNode b)
- {
- if (a == b) return false;
-
- Assert(a.getKind() != Kind::MULT);
- Assert(b.getKind() != Kind::MULT);
-
- bool aIsMult = a.getKind() == Kind::NONLINEAR_MULT;
- bool bIsMult = b.getKind() == Kind::NONLINEAR_MULT;
- if (aIsMult != bIsMult) return !aIsMult;
-
- if (!aIsMult)
- {
- return LeafNodeComparator()(a, b);
- }
-
- size_t aLen = a.getNumChildren();
- size_t bLen = b.getNumChildren();
- if (aLen != bLen) return aLen < bLen;
-
- for (size_t i = 0; i < aLen; ++i)
- {
- if (a[i] != b[i])
- {
- return LeafNodeComparator()(a[i], b[i]);
- }
- }
- Unreachable() << "Nodes are different, but have the same content";
- return false;
- }
-};
-
-
template <typename L, typename R>
bool evaluateRelation(Kind rel, const L& l, const R& r)
{
return RewriteResponse(REWRITE_DONE, t);
}
-RewriteResponse ArithRewriter::rewriteVariable(TNode t){
+RewriteResponse ArithRewriter::rewriteVariable(TNode t)
+{
Assert(t.isVar());
return RewriteResponse(REWRITE_DONE, t);
}
-RewriteResponse ArithRewriter::rewriteSub(TNode t)
-{
- Assert(t.getKind() == kind::SUB);
- Assert(t.getNumChildren() == 2);
-
- auto* nm = NodeManager::currentNM();
-
- if (t[0] == t[1])
- {
- return RewriteResponse(REWRITE_DONE,
- nm->mkConstRealOrInt(t.getType(), Rational(0)));
- }
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(Kind::ADD, t[0], makeUnaryMinusNode(t[1])));
-}
-
RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
{
Assert(t.getKind() == kind::NEG);
if (t[0].isConst())
{
Rational neg = -(t[0].getConst<Rational>());
- NodeManager* nm = NodeManager::currentNM();
- return RewriteResponse(REWRITE_DONE,
- nm->mkConstRealOrInt(t[0].getType(), neg));
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(neg));
}
- if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ if (rewriter::isRAN(t[0]))
{
- const RealAlgebraicNumber& r =
- t[0].getOperator().getConst<RealAlgebraicNumber>();
- NodeManager* nm = NodeManager::currentNM();
- return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(-r));
+ return RewriteResponse(REWRITE_DONE,
+ rewriter::mkConst(-rewriter::getRAN(t[0])));
}
- Node noUminus = makeUnaryMinusNode(t[0]);
- if(pre)
+ auto* nm = NodeManager::currentNM();
+ Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[0]);
+ if (pre)
return RewriteResponse(REWRITE_DONE, noUminus);
else
return RewriteResponse(REWRITE_AGAIN, noUminus);
}
+RewriteResponse ArithRewriter::rewriteSub(TNode t)
+{
+ Assert(t.getKind() == kind::SUB);
+ Assert(t.getNumChildren() == 2);
+
+ if (t[0] == t[1])
+ {
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+ }
+ auto* nm = NodeManager::currentNM();
+ return RewriteResponse(
+ REWRITE_AGAIN_FULL,
+ nm->mkNode(Kind::ADD,
+ t[0],
+ nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[1])));
+}
+
RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
if(t.isConst()){
return rewriteConstant(t);
Assert(node.getKind() == kind::MULT
|| node.getKind() == kind::NONLINEAR_MULT);
- auto res = getZeroChild(node);
- if (res)
+ if (auto res = rewriter::getZeroChild(node); res)
{
return RewriteResponse(REWRITE_DONE, *res);
}
Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
Assert(t.getNumChildren() >= 2);
- if (auto res = getZeroChild(t); res)
+ std::vector<TNode> children;
+ expr::algorithm::flatten(t, children, Kind::MULT, Kind::NONLINEAR_MULT);
+
+ if (auto res = rewriter::getZeroChild(children); res)
{
return RewriteResponse(REWRITE_DONE, *res);
}
- Rational rational = Rational(1);
+ // Distribute over addition
+ if (std::any_of(children.begin(), children.end(), [](TNode child) {
+ return child.getKind() == Kind::ADD;
+ }))
+ {
+ return RewriteResponse(REWRITE_DONE,
+ rewriter::distributeMultiplication(children));
+ }
+
RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
- Polynomial poly = Polynomial::mkOne();
+ std::vector<Node> leafs;
- for (const auto& child : t)
+ for (const auto& child : children)
{
if (child.isConst())
{
{
return RewriteResponse(REWRITE_DONE, child);
}
- rational *= child.getConst<Rational>();
+ ran *= child.getConst<Rational>();
}
- else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ else if (rewriter::isRAN(child))
{
- ran *= child.getOperator().getConst<RealAlgebraicNumber>();
+ ran *= rewriter::getRAN(child);
}
else
{
- poly = poly * Polynomial::parsePolynomial(child);
+ leafs.emplace_back(child);
}
}
- if (!rational.isOne())
- {
- poly = poly * rational;
- }
- if (isOne(ran))
- {
- return RewriteResponse(REWRITE_DONE, poly.getNode());
- }
- auto* nm = NodeManager::currentNM();
- if (poly.isConstant())
- {
- ran *= RealAlgebraicNumber(poly.getHead().getConstant().getValue());
- return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
- }
- return RewriteResponse(
- REWRITE_DONE,
- nm->mkNode(
- Kind::MULT, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
+ return RewriteResponse(REWRITE_DONE,
+ rewriter::mkMultTerm(ran, std::move(leafs)));
}
-RewriteResponse ArithRewriter::postRewritePow2(TNode t)
+RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
{
- Assert(t.getKind() == kind::POW2);
- NodeManager* nm = NodeManager::currentNM();
- // if constant, we eliminate
- if (t[0].isConst())
- {
- // pow2 is only supported for integers
- Assert(t[0].getType().isInteger());
- Integer i = t[0].getConst<Rational>().getNumerator();
- if (i < 0)
- {
- return RewriteResponse(REWRITE_DONE, nm->mkConstInt(Rational(0)));
- }
- // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
- Node two = nm->mkConstInt(Rational(Integer(2)));
- Node ret = nm->mkNode(kind::POW, two, t[0]);
- return RewriteResponse(REWRITE_AGAIN, ret);
- }
- return RewriteResponse(REWRITE_DONE, t);
-}
-
-RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
-{
- Assert(t.getKind() == kind::IAND);
- size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
- NodeManager* nm = NodeManager::currentNM();
- // if constant, we eliminate
- if (t[0].isConst() && t[1].isConst())
- {
- Node iToBvop = nm->mkConst(IntToBitVector(bsize));
- Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
- Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
- Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
- Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
- return RewriteResponse(REWRITE_AGAIN_FULL, ret);
- }
- else if (t[0] > t[1])
- {
- // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
- Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
- return RewriteResponse(REWRITE_AGAIN, ret);
- }
- else if (t[0] == t[1])
- {
- // ((_ iand k) x x) ---> x
- return RewriteResponse(REWRITE_DONE, t[0]);
- }
- // simplifications involving constants
- for (unsigned i = 0; i < 2; i++)
- {
- if (!t[i].isConst())
- {
- continue;
- }
- if (t[i].getConst<Rational>().sgn() == 0)
- {
- // ((_ iand k) 0 y) ---> 0
- return RewriteResponse(REWRITE_DONE, t[i]);
- }
- if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
- {
- // ((_ iand k) 111...1 y) ---> (mod y 2^k)
- Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
- Node ret = nm->mkNode(kind::INTS_MODULUS, t[1-i], twok);
- return RewriteResponse(REWRITE_AGAIN, ret);
- }
- }
- return RewriteResponse(REWRITE_DONE, t);
-}
-
-RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
- return RewriteResponse(REWRITE_DONE, t);
-}
-
-RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
- Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
- NodeManager* nm = NodeManager::currentNM();
- switch( t.getKind() ){
- case kind::EXPONENTIAL: {
- if (t[0].isConst())
- {
- Node one = nm->mkConstReal(Rational(1));
- if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
- return RewriteResponse(
- REWRITE_AGAIN,
- nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
- }else{
- return RewriteResponse(REWRITE_DONE, t);
- }
- }
- else if (t[0].getKind() == kind::ADD)
- {
- std::vector<Node> product;
- for (const Node tc : t[0])
- {
- product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
- }
- // We need to do a full rewrite here, since we can get exponentials of
- // constants, e.g. when we are rewriting exp(2 + x)
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::MULT, product));
- }
- }
- break;
- case kind::SINE:
- if (t[0].isConst())
- {
- const Rational& rat = t[0].getConst<Rational>();
- if(rat.sgn() == 0){
- return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
- }
- else if (rat.sgn() == -1)
- {
- Node ret = nm->mkNode(kind::NEG,
- nm->mkNode(kind::SINE, nm->mkConstReal(-rat)));
- return RewriteResponse(REWRITE_AGAIN_FULL, ret);
- }
- }else{
- // get the factor of PI in the argument
- Node pi_factor;
- Node pi;
- Node rem;
- std::map<Node, Node> msum;
- if (ArithMSum::getMonomialSum(t[0], msum))
- {
- pi = mkPi();
- std::map<Node, Node>::iterator itm = msum.find(pi);
- if (itm != msum.end())
- {
- if (itm->second.isNull())
- {
- pi_factor = nm->mkConstReal(Rational(1));
- }
- else
- {
- pi_factor = itm->second;
- }
- msum.erase(pi);
- if (!msum.empty())
- {
- rem = ArithMSum::mkNode(t[0].getType(), msum);
- }
- }
- }
- else
- {
- Assert(false);
- }
-
- // if there is a factor of PI
- if( !pi_factor.isNull() ){
- Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
- Rational r = pi_factor.getConst<Rational>();
- Rational r_abs = r.abs();
- Rational rone = Rational(1);
- Rational rtwo = Rational(2);
- if (r_abs > rone)
- {
- //add/substract 2*pi beyond scope
- Rational ra_div_two = (r_abs + rone) / rtwo;
- Node new_pi_factor;
- if (r.sgn() == 1)
- {
- new_pi_factor = nm->mkConstReal(r - rtwo * ra_div_two.floor());
- }
- else
- {
- Assert(r.sgn() == -1);
- new_pi_factor = nm->mkConstReal(r + rtwo * ra_div_two.floor());
- }
- Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
- if (!rem.isNull())
- {
- new_arg = nm->mkNode(kind::ADD, new_arg, rem);
- }
- // sin( 2*n*PI + x ) = sin( x )
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::SINE, new_arg));
- }
- else if (r_abs == rone)
- {
- // sin( PI + x ) = -sin( x )
- if (rem.isNull())
- {
- return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
- }
- else
- {
- return RewriteResponse(
- REWRITE_AGAIN_FULL,
- nm->mkNode(kind::NEG, nm->mkNode(kind::SINE, rem)));
- }
- }
- else if (rem.isNull())
- {
- // other rational cases based on Niven's theorem
- // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
- Integer one = Integer(1);
- Integer two = Integer(2);
- Integer six = Integer(6);
- if (r_abs.getDenominator() == two)
- {
- Assert(r_abs.getNumerator() == one);
- return RewriteResponse(REWRITE_DONE,
- nm->mkConstReal(Rational(r.sgn())));
- }
- else if (r_abs.getDenominator() == six)
- {
- Integer five = Integer(5);
- if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
- {
- return RewriteResponse(
- REWRITE_DONE,
- nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
- }
- }
- }
- }
- }
- break;
- case kind::COSINE: {
- return RewriteResponse(
- REWRITE_AGAIN_FULL,
- nm->mkNode(
- kind::SINE,
- nm->mkNode(kind::SUB,
- nm->mkNode(kind::MULT,
- nm->mkConstReal(Rational(1) / Rational(2)),
- mkPi()),
- t[0])));
- }
- break;
- case kind::TANGENT:
- {
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::DIVISION,
- nm->mkNode(kind::SINE, t[0]),
- nm->mkNode(kind::COSINE, t[0])));
- }
- break;
- case kind::COSECANT:
- {
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::DIVISION,
- nm->mkConstReal(Rational(1)),
- nm->mkNode(kind::SINE, t[0])));
- }
- break;
- case kind::SECANT:
- {
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::DIVISION,
- nm->mkConstReal(Rational(1)),
- nm->mkNode(kind::COSINE, t[0])));
- }
- break;
- case kind::COTANGENT:
- {
- return RewriteResponse(REWRITE_AGAIN_FULL,
- nm->mkNode(kind::DIVISION,
- nm->mkNode(kind::COSINE, t[0]),
- nm->mkNode(kind::SINE, t[0])));
- }
- break;
- default:
- break;
- }
- return RewriteResponse(REWRITE_DONE, t);
-}
-
-Node ArithRewriter::makeUnaryMinusNode(TNode n){
- NodeManager* nm = NodeManager::currentNM();
- Rational qNegOne(-1);
- return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n);
-}
-
-RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
Assert(t.getNumChildren() == 2);
NodeManager* nm = NodeManager::currentNM();
const Rational& den = right.getConst<Rational>();
- if(den.isZero()){
- if(t.getKind() == kind::DIVISION_TOTAL){
+ if (den.isZero())
+ {
+ if (t.getKind() == kind::DIVISION_TOTAL)
+ {
return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
- }else{
+ }
+ else
+ {
// This is unsupported, but this is not a good place to complain
return RewriteResponse(REWRITE_DONE, t);
}
}
Node result = nm->mkRealAlgebraicNumber(inverse(den));
- Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
- if(pre){
+ Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
+ if (pre)
+ {
return RewriteResponse(REWRITE_DONE, mult);
- }else{
+ }
+ else
+ {
return RewriteResponse(REWRITE_AGAIN, mult);
}
}
TNode n = t[0];
TNode d = t[1];
bool dIsConstant = d.isConst();
- if(dIsConstant && d.getConst<Rational>().isZero()){
+ if (dIsConstant && d.getConst<Rational>().isZero())
+ {
// (div x 0) ---> 0 or (mod x 0) ---> 0
return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO);
- }else if(dIsConstant && d.getConst<Rational>().isOne()){
+ }
+ else if (dIsConstant && d.getConst<Rational>().isOne())
+ {
if (k == kind::INTS_MODULUS_TOTAL)
{
// (mod x 1) --> 0
bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
- Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
+ Integer result = isDiv ? ni.euclidianDivideQuotient(di)
+ : ni.euclidianDivideRemainder(di);
// constant evaluation
// (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
return RewriteResponse(REWRITE_DONE, t);
}
+RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
+{
+ Assert(t.getKind() == kind::IAND);
+ size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
+ NodeManager* nm = NodeManager::currentNM();
+ // if constant, we eliminate
+ if (t[0].isConst() && t[1].isConst())
+ {
+ Node iToBvop = nm->mkConst(IntToBitVector(bsize));
+ Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
+ Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
+ Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
+ Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
+ return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+ }
+ else if (t[0] > t[1])
+ {
+ // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
+ Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
+ return RewriteResponse(REWRITE_AGAIN, ret);
+ }
+ else if (t[0] == t[1])
+ {
+ // ((_ iand k) x x) ---> (mod x 2^k)
+ Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
+ Node ret = nm->mkNode(kind::INTS_MODULUS, t[0], twok);
+ return RewriteResponse(REWRITE_AGAIN, ret);
+ }
+ // simplifications involving constants
+ for (unsigned i = 0; i < 2; i++)
+ {
+ if (!t[i].isConst())
+ {
+ continue;
+ }
+ if (t[i].getConst<Rational>().sgn() == 0)
+ {
+ // ((_ iand k) 0 y) ---> 0
+ return RewriteResponse(REWRITE_DONE, t[i]);
+ }
+ if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
+ {
+ // ((_ iand k) 111...1 y) ---> (mod y 2^k)
+ Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
+ Node ret = nm->mkNode(kind::INTS_MODULUS, t[1 - i], twok);
+ return RewriteResponse(REWRITE_AGAIN, ret);
+ }
+ }
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::postRewritePow2(TNode t)
+{
+ Assert(t.getKind() == kind::POW2);
+ NodeManager* nm = NodeManager::currentNM();
+ // if constant, we eliminate
+ if (t[0].isConst())
+ {
+ // pow2 is only supported for integers
+ Assert(t[0].getType().isInteger());
+ Integer i = t[0].getConst<Rational>().getNumerator();
+ if (i < 0)
+ {
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+ }
+ // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
+ Node two = rewriter::mkConst(Integer(2));
+ Node ret = nm->mkNode(kind::POW, two, t[0]);
+ return RewriteResponse(REWRITE_AGAIN, ret);
+ }
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t)
+{
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t)
+{
+ Trace("arith-tf-rewrite")
+ << "Rewrite transcendental function : " << t << std::endl;
+ NodeManager* nm = NodeManager::currentNM();
+ switch (t.getKind())
+ {
+ case kind::EXPONENTIAL:
+ {
+ if (t[0].isConst())
+ {
+ Node one = rewriter::mkConst(Integer(1));
+ if (t[0].getConst<Rational>().sgn() >= 0 && t[0].getType().isInteger()
+ && t[0] != one)
+ {
+ return RewriteResponse(
+ REWRITE_AGAIN,
+ nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
+ }
+ else
+ {
+ return RewriteResponse(REWRITE_DONE, t);
+ }
+ }
+ else if (t[0].getKind() == kind::ADD)
+ {
+ std::vector<Node> product;
+ for (const Node tc : t[0])
+ {
+ product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
+ }
+ // We need to do a full rewrite here, since we can get exponentials of
+ // constants, e.g. when we are rewriting exp(2 + x)
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::MULT, product));
+ }
+ }
+ break;
+ case kind::SINE:
+ if (t[0].isConst())
+ {
+ const Rational& rat = t[0].getConst<Rational>();
+ if (rat.sgn() == 0)
+ {
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+ }
+ else if (rat.sgn() == -1)
+ {
+ Node ret = nm->mkNode(
+ kind::NEG, nm->mkNode(kind::SINE, rewriter::mkConst(-rat)));
+ return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+ }
+ }
+ else if ((t[0].getKind() == MULT || t[0].getKind() == NONLINEAR_MULT)
+ && t[0][0].isConst() && t[0][0].getConst<Rational>().sgn() == -1)
+ {
+ // sin(-n*x) ---> -sin(n*x)
+ std::vector<Node> mchildren(t[0].begin(), t[0].end());
+ mchildren[0] = nm->mkConstReal(-t[0][0].getConst<Rational>());
+ Node ret = nm->mkNode(
+ kind::NEG,
+ nm->mkNode(kind::SINE, nm->mkNode(t[0].getKind(), mchildren)));
+ return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+ }
+ else
+ {
+ // get the factor of PI in the argument
+ Node pi_factor;
+ Node pi;
+ Node rem;
+ std::map<Node, Node> msum;
+ if (ArithMSum::getMonomialSum(t[0], msum))
+ {
+ pi = mkPi();
+ std::map<Node, Node>::iterator itm = msum.find(pi);
+ if (itm != msum.end())
+ {
+ if (itm->second.isNull())
+ {
+ pi_factor = rewriter::mkConst(Integer(1));
+ }
+ else
+ {
+ pi_factor = itm->second;
+ }
+ msum.erase(pi);
+ if (!msum.empty())
+ {
+ rem = ArithMSum::mkNode(t[0].getType(), msum);
+ }
+ }
+ }
+ else
+ {
+ Assert(false);
+ }
+
+ // if there is a factor of PI
+ if (!pi_factor.isNull())
+ {
+ Trace("arith-tf-rewrite-debug")
+ << "Process pi factor = " << pi_factor << std::endl;
+ Rational r = pi_factor.getConst<Rational>();
+ Rational r_abs = r.abs();
+ Rational rone = Rational(1);
+ Rational rtwo = Rational(2);
+ if (r_abs > rone)
+ {
+ // add/substract 2*pi beyond scope
+ Rational ra_div_two = (r_abs + rone) / rtwo;
+ Node new_pi_factor;
+ if (r.sgn() == 1)
+ {
+ new_pi_factor = nm->mkConstReal(r - rtwo * ra_div_two.floor());
+ }
+ else
+ {
+ Assert(r.sgn() == -1);
+ new_pi_factor = nm->mkConstReal(r + rtwo * ra_div_two.floor());
+ }
+ Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
+ if (!rem.isNull())
+ {
+ new_arg = nm->mkNode(kind::ADD, new_arg, rem);
+ }
+ // sin( 2*n*PI + x ) = sin( x )
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::SINE, new_arg));
+ }
+ else if (r_abs == rone)
+ {
+ // sin( PI + x ) = -sin( x )
+ if (rem.isNull())
+ {
+ return RewriteResponse(REWRITE_DONE,
+ nm->mkConstReal(Rational(0)));
+ }
+ else
+ {
+ return RewriteResponse(
+ REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::NEG, nm->mkNode(kind::SINE, rem)));
+ }
+ }
+ else if (rem.isNull())
+ {
+ // other rational cases based on Niven's theorem
+ // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
+ Integer one = Integer(1);
+ Integer two = Integer(2);
+ Integer six = Integer(6);
+ if (r_abs.getDenominator() == two)
+ {
+ Assert(r_abs.getNumerator() == one);
+ return RewriteResponse(REWRITE_DONE,
+ nm->mkConstReal(Rational(r.sgn())));
+ }
+ else if (r_abs.getDenominator() == six)
+ {
+ Integer five = Integer(5);
+ if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
+ {
+ return RewriteResponse(
+ REWRITE_DONE,
+ nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
+ }
+ }
+ }
+ }
+ }
+ break;
+ case kind::COSINE:
+ {
+ return RewriteResponse(
+ REWRITE_AGAIN_FULL,
+ nm->mkNode(
+ kind::SINE,
+ nm->mkNode(kind::SUB,
+ nm->mkNode(kind::MULT,
+ nm->mkConstReal(Rational(1) / Rational(2)),
+ mkPi()),
+ t[0])));
+ }
+ break;
+ case kind::TANGENT:
+ {
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::DIVISION,
+ nm->mkNode(kind::SINE, t[0]),
+ nm->mkNode(kind::COSINE, t[0])));
+ }
+ break;
+ case kind::COSECANT:
+ {
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::DIVISION,
+ nm->mkConstReal(Rational(1)),
+ nm->mkNode(kind::SINE, t[0])));
+ }
+ break;
+ case kind::SECANT:
+ {
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::DIVISION,
+ nm->mkConstReal(Rational(1)),
+ nm->mkNode(kind::COSINE, t[0])));
+ }
+ break;
+ case kind::COTANGENT:
+ {
+ return RewriteResponse(REWRITE_AGAIN_FULL,
+ nm->mkNode(kind::DIVISION,
+ nm->mkNode(kind::COSINE, t[0]),
+ nm->mkNode(kind::SINE, t[0])));
+ }
+ break;
+ default: break;
+ }
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
TrustNode ArithRewriter::expandDefinition(Node node)
{
// call eliminate operators, to eliminate partial operators only
RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
{
- Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
- << r << std::endl;
+ Trace("arith-rewriter") << "ArithRewriter : " << t << " == " << ret << " by "
+ << r << std::endl;
return RewriteResponse(REWRITE_AGAIN_FULL, ret);
}