Fixes for the arithmetic normal form and rewriter to handle arbitrary constants for...
authorTim King <taking@cs.nyu.edu>
Sun, 11 Nov 2012 00:28:05 +0000 (00:28 +0000)
committerTim King <taking@cs.nyu.edu>
Sun, 11 Nov 2012 00:28:05 +0000 (00:28 +0000)
src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h
src/theory/arith/normal_form.cpp

index b6275ba2421c44db678cdf414a7171568be07881..689f231e6cfa9317cdd8f05eaaa65978b25f3910 100644 (file)
@@ -80,58 +80,28 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
     return rewriteConstant(t);
   }else if(t.isVar()){
     return rewriteVariable(t);
-  }else if(t.getKind() == kind::MINUS){
-    return rewriteMinus(t, true);
-  }else if(t.getKind() == kind::UMINUS){
-    return rewriteUMinus(t, true);
-  }else if(t.getKind() == kind::DIVISION){
-    return RewriteResponse(REWRITE_DONE, t); // wait until t[1] is rewritten
-  }else if(t.getKind() == kind::DIVISION_TOTAL){
-    if(t[1].getKind()== kind::CONST_RATIONAL &&
-       t[1].getConst<Rational>().isZero()){
-      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-    }else{
-      return RewriteResponse(REWRITE_DONE, t); // wait until t[1] is rewritten
-    }
-  }else if(t.getKind() == kind::PLUS){
-    return preRewritePlus(t);
-  }else if(t.getKind() == kind::MULT){
-    return preRewriteMult(t);
-  }else if(t.getKind() == kind::INTS_DIVISION){
-    Rational intOne(1);
-    if(t[1].getKind()== kind::CONST_RATIONAL &&
-       t[1].getConst<Rational>().isOne()){
-      return RewriteResponse(REWRITE_AGAIN, t[0]);
-    }else{
-      return RewriteResponse(REWRITE_DONE, t);
-    }
-  }else if(t.getKind() == kind::INTS_DIVISION_TOTAL){
-    if(t[1].getKind()== kind::CONST_RATIONAL){
-      Rational intOne(1), intZero(0);
-      if(t[1].getConst<Rational>().isOne()){
-        return RewriteResponse(REWRITE_AGAIN, t[0]);
-      } else if(t[1].getConst<Rational>().isZero()){
-        return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-      }
-    }
-    return RewriteResponse(REWRITE_DONE, t);
-  }else if(t.getKind() == kind::INTS_MODULUS){
-    Rational intOne(1);
-    if(t[1].getKind()== kind::CONST_RATIONAL &&
-       t[1].getConst<Rational>().isOne()){
-      return RewriteResponse(REWRITE_AGAIN, mkRationalNode(0));
-    }else{
-      return RewriteResponse(REWRITE_DONE, t);
-    }
-  }else if(t.getKind() == kind::INTS_MODULUS_TOTAL){
-    if(t[1].getKind()== kind::CONST_RATIONAL){
-      if(t[1].getConst<Rational>().isOne() || t[1].getConst<Rational>().isZero()){
-        return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-      }
-    }
-    return RewriteResponse(REWRITE_DONE, t);
   }else{
-    Unreachable();
+    switch(t.getKind()){
+    case kind::MINUS:
+      return rewriteMinus(t, true);
+    case kind::UMINUS:
+      return rewriteUMinus(t, true);
+    case kind::DIVISION:
+      return rewriteDiv(t,true);
+    case kind::DIVISION_TOTAL:
+      return rewriteDivTotal(t,true);
+    case kind::PLUS:
+      return preRewritePlus(t);
+    case kind::MULT:
+      return preRewriteMult(t);
+    //case kind::INTS_DIVISION:
+    //case kind::INTS_MODULUS:
+    case kind::INTS_DIVISION_TOTAL:
+    case kind::INTS_MODULUS_TOTAL:
+      return rewriteIntsDivModTotal(t,true);
+    default:
+      Unreachable();
+    }
   }
 }
 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
@@ -139,33 +109,32 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
     return rewriteConstant(t);
   }else if(t.isVar()){
     return rewriteVariable(t);
-  }else if(t.getKind() == kind::MINUS){
-    return rewriteMinus(t, false);
-  }else if(t.getKind() == kind::UMINUS){
-    return rewriteUMinus(t, false);
-  }else if(t.getKind() == kind::DIVISION ||
-           t.getKind() == kind::DIVISION_TOTAL){
-    return rewriteDiv(t, false);
-  }else if(t.getKind() == kind::PLUS){
-    return postRewritePlus(t);
-  }else if(t.getKind() == kind::MULT){
-    return postRewriteMult(t);
-  }else if(t.getKind() == kind::INTS_DIVISION ||
-           t.getKind() == kind::INTS_MODULUS){
-    return RewriteResponse(REWRITE_DONE, t);
-  }else if(t.getKind() == kind::INTS_DIVISION_TOTAL ||
-           t.getKind() == kind::INTS_MODULUS_TOTAL){
-    if(t[1].getKind() == kind::CONST_RATIONAL &&
-       t[1].getConst<Rational>().isZero()){
-      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-    }else{
-      return RewriteResponse(REWRITE_DONE, t);
-    }
   }else{
-    Unreachable();
+    switch(t.getKind()){
+    case kind::MINUS:
+      return rewriteMinus(t, false);
+    case kind::UMINUS:
+      return rewriteUMinus(t, false);
+    case kind::DIVISION:
+      return rewriteDiv(t, false);
+    case kind::DIVISION_TOTAL:
+      return rewriteDivTotal(t, false);
+    case kind::PLUS:
+      return postRewritePlus(t);
+    case kind::MULT:
+      return postRewriteMult(t);
+      //case kind::INTS_DIVISION:
+      //case kind::INTS_MODULUS:
+    case kind::INTS_DIVISION_TOTAL:
+    case kind::INTS_MODULUS_TOTAL:
+      return rewriteIntsDivModTotal(t, false);
+    default:
+      Unreachable();
+    }
   }
 }
 
+
 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
   Assert(t.getKind()== kind::MULT);
 
@@ -217,19 +186,6 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){
   return RewriteResponse(REWRITE_DONE, res.getNode());
 }
 
-// RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
-//   TNode left  = t[0];
-//   TNode right = t[1];
-
-//   Polynomial pLeft = Polynomial::parsePolynomial(left);
-  
-
-//   Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
-
-//   Assert(cmp.isNormalForm());
-//   return RewriteResponse(REWRITE_DONE, cmp.getNode());
-// }
-
 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
   // left |><| right
   TNode left = atom[0];
@@ -304,9 +260,52 @@ Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
 
   return diff;
 }
-
 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
-  Assert(t.getKind()== kind::DIVISION || t.getKind() == kind::DIVISION_TOTAL);
+  Assert(t.getKind()== kind::DIVISION);
+
+  Node left = t[0];
+  Node right = t[1];
+
+  if(right.getKind() == kind::CONST_RATIONAL &&
+     left.getKind() != kind::CONST_RATIONAL){
+
+    const Rational& den = right.getConst<Rational>();
+
+    Assert(!den.isZero());
+
+    Rational div = den.inverse();
+    Node result =  mkRationalNode(div);
+    Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
+    if(pre){
+      return RewriteResponse(REWRITE_DONE, mult);
+    }else{
+      return RewriteResponse(REWRITE_AGAIN, mult);
+    }
+  }
+
+  if(pre){
+    if(right.getKind() != kind::CONST_RATIONAL ||
+       left.getKind() != kind::CONST_RATIONAL){
+      return RewriteResponse(REWRITE_DONE, t);
+    }
+  }
+
+  Assert(right.getKind() == kind::CONST_RATIONAL);
+  Assert(left.getKind() == kind::CONST_RATIONAL);
+
+  const Rational& den = right.getConst<Rational>();
+
+  Assert(!den.isZero());
+
+  const Rational& num = left.getConst<Rational>();
+  Rational div = num / den;
+  Node result =  mkRationalNode(div);
+  return RewriteResponse(REWRITE_DONE, result);
+}
+
+RewriteResponse ArithRewriter::rewriteDivTotal(TNode t, bool pre){
+  Assert(t.getKind() == kind::DIVISION_TOTAL);
+
 
   Node left = t[0];
   Node right = t[1];
@@ -314,14 +313,17 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
     const Rational& den = right.getConst<Rational>();
 
     if(den.isZero()){
-      if(t.getKind() == kind::DIVISION_TOTAL){
-        return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-      }else{
-        return RewriteResponse(REWRITE_DONE, t);
-      }
+      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
     }
     Assert(den != Rational(0));
 
+    if(left.getKind() == kind::CONST_RATIONAL){
+      const Rational& num = left.getConst<Rational>();
+      Rational div = num / den;
+      Node result =  mkRationalNode(div);
+      return RewriteResponse(REWRITE_DONE, result);
+    }
+
     Rational div = den.inverse();
 
     Node result = mkRationalNode(div);
@@ -337,6 +339,48 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
   }
 }
 
+RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
+  Kind k = t.getKind();
+  // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
+  //        k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+
+  //Leaving the function as before (INTS_MODULUS can be handled),
+  // but restricting its use here
+  Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
+  TNode n = t[0], d = t[1];
+  bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
+  if(dIsConstant && d.getConst<Rational>().isZero()){
+    if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
+      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+    }else{
+      // Do nothing for k == INTS_MODULUS
+      return RewriteResponse(REWRITE_DONE, t);
+    }
+  }else if(dIsConstant && d.getConst<Rational>().isOne()){
+    if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
+      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+    }else{
+      Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+      return RewriteResponse(REWRITE_AGAIN, n);
+    }
+  }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
+    Assert(d.getConst<Rational>().isIntegral());
+    Assert(n.getConst<Rational>().isIntegral());
+    Assert(!d.getConst<Rational>().isZero());
+    Integer di = d.getConst<Rational>().getNumerator();
+    Integer ni = n.getConst<Rational>().getNumerator();
+
+    bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+
+    Integer result = isDiv ? ni.floorDivideQuotient(di) : ni.floorDivideRemainder(di);
+
+    Node resultNode = mkRationalNode(Rational(result));
+    return RewriteResponse(REWRITE_DONE, resultNode);
+  }else{
+    return RewriteResponse(REWRITE_DONE, t);
+  }
+}
+
 }/* CVC4::theory::arith namespace */
 }/* CVC4::theory namespace */
 }/* CVC4 namespace */
index 986ff369d91ea6c367fb754c6a2b4a50f8f77e7e..10e2555358598b187ae034a02f6cc4b065ca3fdb 100644 (file)
@@ -50,6 +50,8 @@ private:
   static RewriteResponse rewriteMinus(TNode t, bool pre);
   static RewriteResponse rewriteUMinus(TNode t, bool pre);
   static RewriteResponse rewriteDiv(TNode t, bool pre);
+  static RewriteResponse rewriteDivTotal(TNode t, bool pre);
+  static RewriteResponse rewriteIntsDivModTotal(TNode t, bool pre);
 
   static RewriteResponse preRewritePlus(TNode t);
   static RewriteResponse postRewritePlus(TNode t);
index 9bd0a3b6cd7d7d7bc238474c581edf82acbf0989..c863bf3c5b27c4bf8c43dcb3c476e876266d1e65 100644 (file)
@@ -27,9 +27,9 @@ namespace arith {
 
 bool Variable::isDivMember(Node n){
   switch(n.getKind()){
-  case kind::DIVISION:
-  case kind::INTS_DIVISION:
-  case kind::INTS_MODULUS:
+    //case kind::DIVISION:
+    //case kind::INTS_DIVISION:
+    //case kind::INTS_MODULUS:
   case kind::DIVISION_TOTAL:
   case kind::INTS_DIVISION_TOTAL:
   case kind::INTS_MODULUS_TOTAL: