Add rewrites for div/mod in the arithmetic rewriter (#5352)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 28 Oct 2020 17:35:43 +0000 (12:35 -0500)
committerGitHub <noreply@github.com>
Wed, 28 Oct 2020 17:35:43 +0000 (12:35 -0500)
This adds some basic rewrites for integer div/mod in the rewriter.

This is in preparation for improved preprocessing and rewriting for NIA problems with heavy use of div/mod.

src/CMakeLists.txt
src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h
src/theory/arith/rewrites.cpp [new file with mode: 0644]
src/theory/arith/rewrites.h [new file with mode: 0644]

index a9bc8dc7e5c41c49f2090efdc509d2d6afff3274..0c6cd8f50e3690b4a24d3a0af0e5d1c575ddb656 100644 (file)
@@ -388,6 +388,8 @@ libcvc4_add_sources(
   theory/arith/proof_checker.cpp
   theory/arith/proof_checker.h
   theory/arith/proof_macros.h
+  theory/arith/rewrites.cpp
+  theory/arith/rewrites.h
   theory/arith/simplex.cpp
   theory/arith/simplex.h
   theory/arith/simplex_update.cpp
index 1cd6398dc0dd3a9e47bbcadcce525e80540f7028..dc91d678e14bc7d46601ff6ef61ab8e068289af1 100644 (file)
@@ -119,8 +119,7 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
     case kind::ARCCOTANGENT:
     case kind::SQRT: return preRewriteTranscendental(t);
     case kind::INTS_DIVISION:
-    case kind::INTS_MODULUS:
-      return RewriteResponse(REWRITE_DONE, t);
+    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
     case kind::INTS_DIVISION_TOTAL:
     case kind::INTS_MODULUS_TOTAL:
       return rewriteIntsDivModTotal(t,true);
@@ -183,8 +182,7 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
     case kind::ARCCOTANGENT:
     case kind::SQRT: return postRewriteTranscendental(t);
     case kind::INTS_DIVISION:
-    case kind::INTS_MODULUS:
-      return RewriteResponse(REWRITE_DONE, t);
+    case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
     case kind::INTS_DIVISION_TOTAL:
     case kind::INTS_MODULUS_TOTAL:
       return rewriteIntsDivModTotal(t, false);
@@ -763,42 +761,68 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
   }
 }
 
-RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
+RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
+{
+  NodeManager* nm = NodeManager::currentNM();
   Kind k = t.getKind();
-  // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
-  //        k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+  Node zero = nm->mkConst(Rational(0));
+  if (k == kind::INTS_MODULUS)
+  {
+    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
+    {
+      // can immediately replace by INTS_MODULUS_TOTAL
+      Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
+      return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
+    }
+  }
+  if (k == kind::INTS_DIVISION)
+  {
+    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
+    {
+      // can immediately replace by INTS_DIVISION_TOTAL
+      Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
+      return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
+    }
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
 
-  //Leaving the function as before (INTS_MODULUS can be handled),
-  // but restricting its use here
+RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
+{
+  if (pre)
+  {
+    // do not rewrite at prewrite.
+    return RewriteResponse(REWRITE_DONE, t);
+  }
+  NodeManager* nm = NodeManager::currentNM();
+  Kind k = t.getKind();
   Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
-  TNode n = t[0], d = t[1];
+  TNode n = t[0];
+  TNode d = t[1];
   bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
   if(dIsConstant && d.getConst<Rational>().isZero()){
-    if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
-      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-    }else{
-      // Do nothing for k == INTS_MODULUS
-      return RewriteResponse(REWRITE_DONE, t);
-    }
+    // (div x 0) ---> 0 or (mod x 0) ---> 0
+    return returnRewrite(t, mkRationalNode(0), Rewrite::DIV_MOD_BY_ZERO);
   }else if(dIsConstant && d.getConst<Rational>().isOne()){
-    if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
-      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-    }else{
-      Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
-      return RewriteResponse(REWRITE_AGAIN, n);
+    if (k == kind::INTS_MODULUS_TOTAL)
+    {
+      // (mod x 1) --> 0
+      return returnRewrite(t, mkRationalNode(0), Rewrite::MOD_BY_ONE);
     }
+    Assert(k == kind::INTS_DIVISION_TOTAL);
+    // (div x 1) --> x
+    return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
   }
   else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
   {
     // pull negation
-    //   (div x (- c)) ---> (- (div x c))
-    //   (mod x (- c)) ---> (mod x c)
-    NodeManager* nm = NodeManager::currentNM();
+    // (div x (- c)) ---> (- (div x c))
+    // (mod x (- c)) ---> (mod x c)
     Node nn = nm->mkNode(k, t[0], nm->mkConst(-t[1].getConst<Rational>()));
     Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
                    ? nm->mkNode(kind::UMINUS, nn)
                    : nn;
-    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+    return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
   }
   else if (dIsConstant && n.getKind() == kind::CONST_RATIONAL)
   {
@@ -812,13 +836,66 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
 
     Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
 
+    // constant evaluation
+    // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
     Node resultNode = mkRationalNode(Rational(result));
-    return RewriteResponse(REWRITE_DONE, resultNode);
+    return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
+  }
+  if (k == kind::INTS_MODULUS_TOTAL)
+  {
+    // Note these rewrites do not need to account for modulus by zero as being
+    // a UF, which is handled by the reduction of INTS_MODULUS.
+    Kind k0 = t[0].getKind();
+    if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
+    {
+      // (mod (mod x c) c) --> (mod x c)
+      return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
+    }
+    else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::PLUS)
+    {
+      // can drop all
+      std::vector<Node> newChildren;
+      bool childChanged = false;
+      for (const Node& tc : t[0])
+      {
+        if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
+        {
+          newChildren.push_back(tc[0]);
+          childChanged = true;
+          continue;
+        }
+        newChildren.push_back(tc);
+      }
+      if (childChanged)
+      {
+        // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
+        // op is one of { NONLINEAR_MULT, MULT, PLUS }.
+        Node ret = nm->mkNode(k0, newChildren);
+        ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
+        return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
+      }
+    }
   }
   else
   {
-    return RewriteResponse(REWRITE_DONE, t);
+    Assert(k == kind::INTS_DIVISION_TOTAL);
+    // Note these rewrites do not need to account for division by zero as being
+    // a UF, which is handled by the reduction of INTS_DIVISION.
+    if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
+    {
+      // (div (mod x c) c) --> 0
+      Node ret = mkRationalNode(0);
+      return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
+    }
   }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
+{
+  Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
+                         << r << std::endl;
+  return RewriteResponse(REWRITE_AGAIN_FULL, ret);
 }
 
 }/* CVC4::theory::arith namespace */
index 458a032bc26801731a89f7abd84c5f56f5d6660d..8f35acd1cf7250c65d12dceb3b6afdef65345c9c 100644 (file)
@@ -20,6 +20,7 @@
 #ifndef CVC4__THEORY__ARITH__ARITH_REWRITER_H
 #define CVC4__THEORY__ARITH__ARITH_REWRITER_H
 
+#include "theory/arith/rewrites.h"
 #include "theory/theory.h"
 #include "theory/theory_rewriter.h"
 
@@ -45,6 +46,7 @@ class ArithRewriter : public TheoryRewriter
   static RewriteResponse rewriteMinus(TNode t, bool pre);
   static RewriteResponse rewriteUMinus(TNode t, bool pre);
   static RewriteResponse rewriteDiv(TNode t, bool pre);
+  static RewriteResponse rewriteIntsDivMod(TNode t, bool pre);
   static RewriteResponse rewriteIntsDivModTotal(TNode t, bool pre);
 
   static RewriteResponse preRewritePlus(TNode t);
@@ -66,7 +68,8 @@ class ArithRewriter : public TheoryRewriter
   static inline bool isTerm(TNode n) {
     return !isAtom(n);
   }
-
+  /** return rewrite */
+  static RewriteResponse returnRewrite(TNode t, Node ret, Rewrite r);
 }; /* class ArithRewriter */
 
 }/* CVC4::theory::arith namespace */
diff --git a/src/theory/arith/rewrites.cpp b/src/theory/arith/rewrites.cpp
new file mode 100644 (file)
index 0000000..870d009
--- /dev/null
@@ -0,0 +1,50 @@
+/*********************                                                        */
+/*! \file rewrites.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Implementation of inference information utility.
+ **/
+
+#include "theory/arith/rewrites.h"
+
+#include <iostream>
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+
+const char* toString(Rewrite r)
+{
+  switch (r)
+  {
+    case Rewrite::NONE: return "NONE";
+    case Rewrite::CONST_EVAL: return "CONST_EVAL";
+    case Rewrite::MOD_TOTAL_BY_CONST: return "MOD_TOTAL_BY_CONST";
+    case Rewrite::DIV_TOTAL_BY_CONST: return "DIV_TOTAL_BY_CONST";
+    case Rewrite::DIV_MOD_BY_ZERO: return "DIV_MOD_BY_ZERO";
+    case Rewrite::MOD_BY_ONE: return "MOD_BY_ONE";
+    case Rewrite::DIV_BY_ONE: return "DIV_BY_ONE";
+    case Rewrite::DIV_MOD_PULL_NEG_DEN: return "DIV_MOD_PULL_NEG_DEN";
+    case Rewrite::MOD_OVER_MOD: return "MOD_OVER_MOD";
+    case Rewrite::MOD_CHILD_MOD: return "MOD_CHILD_MOD";
+    case Rewrite::DIV_OVER_MOD: return "DIV_OVER_MOD";
+    default: return "?";
+  }
+}
+
+std::ostream& operator<<(std::ostream& out, Rewrite r)
+{
+  out << toString(r);
+  return out;
+}
+
+}  // namespace arith
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/arith/rewrites.h b/src/theory/arith/rewrites.h
new file mode 100644 (file)
index 0000000..48913e8
--- /dev/null
@@ -0,0 +1,82 @@
+/*********************                                                        */
+/*! \file rewrites.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Type for rewrites for arithmetic.
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__THEORY__ARITH__REWRITES_H
+#define CVC4__THEORY__ARITH__REWRITES_H
+
+#include <iosfwd>
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+
+/**
+ * Types of rewrites used by arithmetic
+ */
+enum class Rewrite : uint32_t
+{
+  NONE,
+  // constant evaluation
+  CONST_EVAL,
+  // (mod x c) replaced by total (mod x c) if c != 0
+  MOD_TOTAL_BY_CONST,
+  // (div x c) replaced by total (div x c) if c != 0
+  DIV_TOTAL_BY_CONST,
+  // Total versions choose arbitrary values for 0 denominator:
+  // (div x 0) ---> 0
+  // (mod x 0) ---> 0
+  DIV_MOD_BY_ZERO,
+  // (mod x 1) --> 0
+  MOD_BY_ONE,
+  // (div x 1) --> x
+  DIV_BY_ONE,
+  // (div x (- c)) ---> (- (div x c))
+  // (mod x (- c)) ---> (mod x c)
+  DIV_MOD_PULL_NEG_DEN,
+  // (mod (mod x c) c) --> (mod x c)
+  MOD_OVER_MOD,
+  // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
+  // op is one of { NONLINEAR_MULT, MULT, PLUS }.
+  MOD_CHILD_MOD,
+  // (div (mod x c) c) --> 0
+  DIV_OVER_MOD
+};
+
+/**
+ * Converts an rewrite to a string. Note: This function is also used in
+ * `safe_print()`. Changing this functions name or signature will result in
+ * `safe_print()` printing "<unsupported>" instead of the proper strings for
+ * the enum values.
+ *
+ * @param r The rewrite
+ * @return The name of the rewrite
+ */
+const char* toString(Rewrite r);
+
+/**
+ * Writes an rewrite name to a stream.
+ *
+ * @param out The stream to write to
+ * @param r The rewrite to write to the stream
+ * @return The stream
+ */
+std::ostream& operator<<(std::ostream& out, Rewrite r);
+
+}  // namespace arith
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* CVC4__THEORY__ARITH__REWRITES_H */