case kind::ARCCOTANGENT:
case kind::SQRT: return preRewriteTranscendental(t);
case kind::INTS_DIVISION:
- case kind::INTS_MODULUS:
- return RewriteResponse(REWRITE_DONE, t);
+ case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
case kind::INTS_DIVISION_TOTAL:
case kind::INTS_MODULUS_TOTAL:
return rewriteIntsDivModTotal(t,true);
case kind::ARCCOTANGENT:
case kind::SQRT: return postRewriteTranscendental(t);
case kind::INTS_DIVISION:
- case kind::INTS_MODULUS:
- return RewriteResponse(REWRITE_DONE, t);
+ case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
case kind::INTS_DIVISION_TOTAL:
case kind::INTS_MODULUS_TOTAL:
return rewriteIntsDivModTotal(t, false);
}
}
-RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
+RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
+{
+ NodeManager* nm = NodeManager::currentNM();
Kind k = t.getKind();
- // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
- // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+ Node zero = nm->mkConst(Rational(0));
+ if (k == kind::INTS_MODULUS)
+ {
+ if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
+ {
+ // can immediately replace by INTS_MODULUS_TOTAL
+ Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
+ return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
+ }
+ }
+ if (k == kind::INTS_DIVISION)
+ {
+ if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
+ {
+ // can immediately replace by INTS_DIVISION_TOTAL
+ Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
+ return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
+ }
+ }
+ return RewriteResponse(REWRITE_DONE, t);
+}
- //Leaving the function as before (INTS_MODULUS can be handled),
- // but restricting its use here
+RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
+{
+ if (pre)
+ {
+ // do not rewrite at prewrite.
+ return RewriteResponse(REWRITE_DONE, t);
+ }
+ NodeManager* nm = NodeManager::currentNM();
+ Kind k = t.getKind();
Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
- TNode n = t[0], d = t[1];
+ TNode n = t[0];
+ TNode d = t[1];
bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
if(dIsConstant && d.getConst<Rational>().isZero()){
- if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }else{
- // Do nothing for k == INTS_MODULUS
- return RewriteResponse(REWRITE_DONE, t);
- }
+ // (div x 0) ---> 0 or (mod x 0) ---> 0
+ return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
}else if(dIsConstant && d.getConst<Rational>().isOne()){
- if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
- return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
- }else{
- Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
- return RewriteResponse(REWRITE_AGAIN, n);
+ if (k == kind::INTS_MODULUS_TOTAL)
+ {
+ // (mod x 1) --> 0
+ return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE);
}
+ Assert(k == kind::INTS_DIVISION_TOTAL);
+ // (div x 1) --> x
+ return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
}
else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
{
// pull negation
- // (div x (- c)) ---> (- (div x c))
- // (mod x (- c)) ---> (mod x c)
- NodeManager* nm = NodeManager::currentNM();
+ // (div x (- c)) ---> (- (div x c))
+ // (mod x (- c)) ---> (mod x c)
Node nn = nm->mkNode(k, t[0], nm->mkConst(-t[1].getConst<Rational>()));
Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
? nm->mkNode(kind::UMINUS, nn)
: nn;
- return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+ return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
}
else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL)
{
Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
+ // constant evaluation
+ // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
Node resultNode = mkRationalNode(Rational(result));
- return RewriteResponse(REWRITE_DONE, resultNode);
+ return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
+ }
+ if (k == kind::INTS_MODULUS_TOTAL)
+ {
+ // Note these rewrites do not need to account for modulus by zero as being
+ // a UF, which is handled by the reduction of INTS_MODULUS.
+ Kind k0 = t[0].getKind();
+ if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
+ {
+ // (mod (mod x c) c) --> (mod x c)
+ return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
+ }
+ else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS)
+ {
+ // can drop all
+ std::vector<Node> newChildren;
+ bool childChanged = false;
+ for (const Node& tc : t[0])
+ {
+ if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
+ {
+ newChildren.push_back(tc[0]);
+ childChanged = true;
+ continue;
+ }
+ newChildren.push_back(tc);
+ }
+ if (childChanged)
+ {
+ // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
+ // op is one of { NONLINEAR_MULT, MULT, PLUS }.
+ Node ret = nm->mkNode(k0, newChildren);
+ ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
+ return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
+ }
+ }
}
else
{
- return RewriteResponse(REWRITE_DONE, t);
+ Assert(k == kind::INTS_DIVISION_TOTAL);
+ // Note these rewrites do not need to account for division by zero as being
+ // a UF, which is handled by the reduction of INTS_DIVISION.
+ if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
+ {
+ // (div (mod x c) c) --> 0
+ Node ret = mkRationalNode(0);
+ return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
+ }
}
+ return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
+{
+ Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
+ << r << std::endl;
+ return RewriteResponse(REWRITE_AGAIN_FULL, ret);
}
}/* CVC4::theory::arith namespace */
--- /dev/null
+/********************* */
+/*! \file rewrites.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Type for rewrites for arithmetic.
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__THEORY__ARITH__REWRITES_H
+#define CVC4__THEORY__ARITH__REWRITES_H
+
+#include <iosfwd>
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+
+/**
+ * Types of rewrites used by arithmetic
+ */
+enum class Rewrite : uint32_t
+{
+ NONE,
+ // constant evaluation
+ CONST_EVAL,
+ // (mod x c) replaced by total (mod x c) if c != 0
+ MOD_TOTAL_BY_CONST,
+ // (div x c) replaced by total (div x c) if c != 0
+ DIV_TOTAL_BY_CONST,
+ // Total versions choose arbitrary values for 0 denominator:
+ // (div x 0) ---> 0
+ // (mod x 0) ---> 0
+ DIV_MOD_BY_ZERO,
+ // (mod x 1) --> 0
+ MOD_BY_ONE,
+ // (div x 1) --> x
+ DIV_BY_ONE,
+ // (div x (- c)) ---> (- (div x c))
+ // (mod x (- c)) ---> (mod x c)
+ DIV_MOD_PULL_NEG_DEN,
+ // (mod (mod x c) c) --> (mod x c)
+ MOD_OVER_MOD,
+ // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
+ // op is one of { NONLINEAR_MULT, MULT, PLUS }.
+ MOD_CHILD_MOD,
+ // (div (mod x c) c) --> 0
+ DIV_OVER_MOD
+};
+
+/**
+ * Converts an rewrite to a string. Note: This function is also used in
+ * `safe_print()`. Changing this functions name or signature will result in
+ * `safe_print()` printing "<unsupported>" instead of the proper strings for
+ * the enum values.
+ *
+ * @param r The rewrite
+ * @return The name of the rewrite
+ */
+const char* toString(Rewrite r);
+
+/**
+ * Writes an rewrite name to a stream.
+ *
+ * @param out The stream to write to
+ * @param r The rewrite to write to the stream
+ * @return The stream
+ */
+std::ostream& operator<<(std::ostream& out, Rewrite r);
+
+} // namespace arith
+} // namespace theory
+} // namespace CVC4
+
+#endif /* CVC4__THEORY__ARITH__REWRITES_H */