From: Andrew Reynolds Date: Wed, 28 Oct 2020 17:35:43 +0000 (-0500) Subject: Add rewrites for div/mod in the arithmetic rewriter (#5352) X-Git-Tag: cvc5-1.0.0~2654 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=b0dd5a3adc67d72a08ca9d8d3de208840a1001a3;p=cvc5.git Add rewrites for div/mod in the arithmetic rewriter (#5352) This adds some basic rewrites for integer div/mod in the rewriter. This is in preparation for improved preprocessing and rewriting for NIA problems with heavy use of div/mod. --- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a9bc8dc7e..0c6cd8f50 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -388,6 +388,8 @@ libcvc4_add_sources( theory/arith/proof_checker.cpp theory/arith/proof_checker.h theory/arith/proof_macros.h + theory/arith/rewrites.cpp + theory/arith/rewrites.h theory/arith/simplex.cpp theory/arith/simplex.h theory/arith/simplex_update.cpp diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 1cd6398dc..dc91d678e 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -119,8 +119,7 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ case kind::ARCCOTANGENT: case kind::SQRT: return preRewriteTranscendental(t); case kind::INTS_DIVISION: - case kind::INTS_MODULUS: - return RewriteResponse(REWRITE_DONE, t); + case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true); case kind::INTS_DIVISION_TOTAL: case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t,true); @@ -183,8 +182,7 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ case kind::ARCCOTANGENT: case kind::SQRT: return postRewriteTranscendental(t); case kind::INTS_DIVISION: - case kind::INTS_MODULUS: - return RewriteResponse(REWRITE_DONE, t); + case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false); case kind::INTS_DIVISION_TOTAL: case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false); @@ -763,42 +761,68 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ } } -RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){ +RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre) +{ + NodeManager* nm = NodeManager::currentNM(); Kind k = t.getKind(); - // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL || - // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL); + Node zero = nm->mkConst(Rational(0)); + if (k == kind::INTS_MODULUS) + { + if (t[1].isConst() && !t[1].getConst().isZero()) + { + // can immediately replace by INTS_MODULUS_TOTAL + Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]); + return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST); + } + } + if (k == kind::INTS_DIVISION) + { + if (t[1].isConst() && !t[1].getConst().isZero()) + { + // can immediately replace by INTS_DIVISION_TOTAL + Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]); + return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST); + } + } + return RewriteResponse(REWRITE_DONE, t); +} - //Leaving the function as before (INTS_MODULUS can be handled), - // but restricting its use here +RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre) +{ + if (pre) + { + // do not rewrite at prewrite. + return RewriteResponse(REWRITE_DONE, t); + } + NodeManager* nm = NodeManager::currentNM(); + Kind k = t.getKind(); Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL); - TNode n = t[0], d = t[1]; + TNode n = t[0]; + TNode d = t[1]; bool dIsConstant = d.getKind() == kind::CONST_RATIONAL; if(dIsConstant && d.getConst().isZero()){ - if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){ - return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); - }else{ - // Do nothing for k == INTS_MODULUS - return RewriteResponse(REWRITE_DONE, t); - } + // (div x 0) ---> 0 or (mod x 0) ---> 0 + return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO); }else if(dIsConstant && d.getConst().isOne()){ - if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){ - return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); - }else{ - Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL); - return RewriteResponse(REWRITE_AGAIN, n); + if (k == kind::INTS_MODULUS_TOTAL) + { + // (mod x 1) --> 0 + return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE); } + Assert(k == kind::INTS_DIVISION_TOTAL); + // (div x 1) --> x + return returnRewrite(t, n, Rewrite::DIV_BY_ONE); } else if (dIsConstant && d.getConst().sgn() < 0) { // pull negation - // (div x (- c)) ---> (- (div x c)) - // (mod x (- c)) ---> (mod x c) - NodeManager* nm = NodeManager::currentNM(); + // (div x (- c)) ---> (- (div x c)) + // (mod x (- c)) ---> (mod x c) Node nn = nm->mkNode(k, t[0], nm->mkConst(-t[1].getConst())); Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL) ? nm->mkNode(kind::UMINUS, nn) : nn; - return RewriteResponse(REWRITE_AGAIN_FULL, ret); + return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN); } else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL) { @@ -812,13 +836,66 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){ Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di); + // constant evaluation + // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3 Node resultNode = mkRationalNode(Rational(result)); - return RewriteResponse(REWRITE_DONE, resultNode); + return returnRewrite(t, resultNode, Rewrite::CONST_EVAL); + } + if (k == kind::INTS_MODULUS_TOTAL) + { + // Note these rewrites do not need to account for modulus by zero as being + // a UF, which is handled by the reduction of INTS_MODULUS. + Kind k0 = t[0].getKind(); + if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1]) + { + // (mod (mod x c) c) --> (mod x c) + return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD); + } + else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS) + { + // can drop all + std::vector newChildren; + bool childChanged = false; + for (const Node& tc : t[0]) + { + if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1]) + { + newChildren.push_back(tc[0]); + childChanged = true; + continue; + } + newChildren.push_back(tc); + } + if (childChanged) + { + // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where + // op is one of { NONLINEAR_MULT, MULT, PLUS }. + Node ret = nm->mkNode(k0, newChildren); + ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]); + return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD); + } + } } else { - return RewriteResponse(REWRITE_DONE, t); + Assert(k == kind::INTS_DIVISION_TOTAL); + // Note these rewrites do not need to account for division by zero as being + // a UF, which is handled by the reduction of INTS_DIVISION. + if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1]) + { + // (div (mod x c) c) --> 0 + Node ret = mkRationalNode(0); + return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD); + } } + return RewriteResponse(REWRITE_DONE, t); +} + +RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r) +{ + Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by " + << r << std::endl; + return RewriteResponse(REWRITE_AGAIN_FULL, ret); } }/* CVC4::theory::arith namespace */ diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index 458a032bc..8f35acd1c 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -20,6 +20,7 @@ #ifndef CVC4__THEORY__ARITH__ARITH_REWRITER_H #define CVC4__THEORY__ARITH__ARITH_REWRITER_H +#include "theory/arith/rewrites.h" #include "theory/theory.h" #include "theory/theory_rewriter.h" @@ -45,6 +46,7 @@ class ArithRewriter : public TheoryRewriter static RewriteResponse rewriteMinus(TNode t, bool pre); static RewriteResponse rewriteUMinus(TNode t, bool pre); static RewriteResponse rewriteDiv(TNode t, bool pre); + static RewriteResponse rewriteIntsDivMod(TNode t, bool pre); static RewriteResponse rewriteIntsDivModTotal(TNode t, bool pre); static RewriteResponse preRewritePlus(TNode t); @@ -66,7 +68,8 @@ class ArithRewriter : public TheoryRewriter static inline bool isTerm(TNode n) { return !isAtom(n); } - + /** return rewrite */ + static RewriteResponse returnRewrite(TNode t, Node ret, Rewrite r); }; /* class ArithRewriter */ }/* CVC4::theory::arith namespace */ diff --git a/src/theory/arith/rewrites.cpp b/src/theory/arith/rewrites.cpp new file mode 100644 index 000000000..870d009d2 --- /dev/null +++ b/src/theory/arith/rewrites.cpp @@ -0,0 +1,50 @@ +/********************* */ +/*! \file rewrites.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of inference information utility. + **/ + +#include "theory/arith/rewrites.h" + +#include + +namespace CVC4 { +namespace theory { +namespace arith { + +const char* toString(Rewrite r) +{ + switch (r) + { + case Rewrite::NONE: return "NONE"; + case Rewrite::CONST_EVAL: return "CONST_EVAL"; + case Rewrite::MOD_TOTAL_BY_CONST: return "MOD_TOTAL_BY_CONST"; + case Rewrite::DIV_TOTAL_BY_CONST: return "DIV_TOTAL_BY_CONST"; + case Rewrite::DIV_MOD_BY_ZERO: return "DIV_MOD_BY_ZERO"; + case Rewrite::MOD_BY_ONE: return "MOD_BY_ONE"; + case Rewrite::DIV_BY_ONE: return "DIV_BY_ONE"; + case Rewrite::DIV_MOD_PULL_NEG_DEN: return "DIV_MOD_PULL_NEG_DEN"; + case Rewrite::MOD_OVER_MOD: return "MOD_OVER_MOD"; + case Rewrite::MOD_CHILD_MOD: return "MOD_CHILD_MOD"; + case Rewrite::DIV_OVER_MOD: return "DIV_OVER_MOD"; + default: return "?"; + } +} + +std::ostream& operator<<(std::ostream& out, Rewrite r) +{ + out << toString(r); + return out; +} + +} // namespace arith +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/arith/rewrites.h b/src/theory/arith/rewrites.h new file mode 100644 index 000000000..48913e8cf --- /dev/null +++ b/src/theory/arith/rewrites.h @@ -0,0 +1,82 @@ +/********************* */ +/*! \file rewrites.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Type for rewrites for arithmetic. + **/ + +#include "cvc4_private.h" + +#ifndef CVC4__THEORY__ARITH__REWRITES_H +#define CVC4__THEORY__ARITH__REWRITES_H + +#include + +namespace CVC4 { +namespace theory { +namespace arith { + +/** + * Types of rewrites used by arithmetic + */ +enum class Rewrite : uint32_t +{ + NONE, + // constant evaluation + CONST_EVAL, + // (mod x c) replaced by total (mod x c) if c != 0 + MOD_TOTAL_BY_CONST, + // (div x c) replaced by total (div x c) if c != 0 + DIV_TOTAL_BY_CONST, + // Total versions choose arbitrary values for 0 denominator: + // (div x 0) ---> 0 + // (mod x 0) ---> 0 + DIV_MOD_BY_ZERO, + // (mod x 1) --> 0 + MOD_BY_ONE, + // (div x 1) --> x + DIV_BY_ONE, + // (div x (- c)) ---> (- (div x c)) + // (mod x (- c)) ---> (mod x c) + DIV_MOD_PULL_NEG_DEN, + // (mod (mod x c) c) --> (mod x c) + MOD_OVER_MOD, + // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where + // op is one of { NONLINEAR_MULT, MULT, PLUS }. + MOD_CHILD_MOD, + // (div (mod x c) c) --> 0 + DIV_OVER_MOD +}; + +/** + * Converts an rewrite to a string. Note: This function is also used in + * `safe_print()`. Changing this functions name or signature will result in + * `safe_print()` printing "" instead of the proper strings for + * the enum values. + * + * @param r The rewrite + * @return The name of the rewrite + */ +const char* toString(Rewrite r); + +/** + * Writes an rewrite name to a stream. + * + * @param out The stream to write to + * @param r The rewrite to write to the stream + * @return The stream + */ +std::ostream& operator<<(std::ostream& out, Rewrite r); + +} // namespace arith +} // namespace theory +} // namespace CVC4 + +#endif /* CVC4__THEORY__ARITH__REWRITES_H */