Merge branch '1.2.x'
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
index 30568c3cac43623d51d7ae122952f4da1356dd4e..247c09294e5341463850625919c77c6160ced21d 100644 (file)
@@ -1,13 +1,11 @@
 /*********************                                                        */
 /*! \file arith_rewriter.cpp
  ** \verbatim
- ** Original author: taking
+ ** Original author: Tim King
  ** Major contributors: none
- ** Minor contributors (to current version): mdeters, dejan
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010, 2011  The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
+ ** Minor contributors (to current version): Morgan Deters, Dejan Jovanovic
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2013  New York University and The University of Iowa
  ** See the file COPYING in the top-level source directory for licensing
  ** information.\endverbatim
  **
@@ -30,23 +28,20 @@ namespace CVC4 {
 namespace theory {
 namespace arith {
 
-bool isVariable(TNode t){
-  return t.getMetaKind() == kind::metakind::VARIABLE;
-}
-
 bool ArithRewriter::isAtom(TNode n) {
-  return arith::isRelationOperator(n.getKind());
+  Kind k = n.getKind();
+  return arith::isRelationOperator(k) || k == kind::IS_INTEGER || k == kind::DIVISIBLE;
 }
 
 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
-  Assert(t.getMetaKind() == kind::metakind::CONSTANT);
-  Node val = coerceToRationalNode(t);
+  Assert(t.isConst());
+  Assert(t.getKind() == kind::CONST_RATIONAL);
 
-  return RewriteResponse(REWRITE_DONE, val);
+  return RewriteResponse(REWRITE_DONE, t);
 }
 
 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
-  Assert(isVariable(t));
+  Assert(t.isVar());
 
   return RewriteResponse(REWRITE_DONE, t);
 }
@@ -65,7 +60,7 @@ RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
     }
   }else{
     Polynomial minuend = Polynomial::parsePolynomial(t[0]);
-    Polynomial subtrahend = Polynomial::parsePolynomial(t[0]);
+    Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
     Polynomial diff = minuend - subtrahend;
     return RewriteResponse(REWRITE_DONE, diff.getNode());
   }
@@ -74,6 +69,11 @@ RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
   Assert(t.getKind()== kind::UMINUS);
 
+  if(t[0].getKind() == kind::CONST_RATIONAL){
+    Rational neg = -(t[0].getConst<Rational>());
+    return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
+  }
+
   Node noUminus = makeUnaryMinusNode(t[0]);
   if(pre)
     return RewriteResponse(REWRITE_DONE, noUminus);
@@ -82,73 +82,117 @@ RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
 }
 
 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
-  if(t.getMetaKind() == kind::metakind::CONSTANT){
+  if(t.isConst()){
     return rewriteConstant(t);
-  }else if(isVariable(t)){
+  }else if(t.isVar()){
     return rewriteVariable(t);
-  }else if(t.getKind() == kind::MINUS){
-    return rewriteMinus(t, true);
-  }else if(t.getKind() == kind::UMINUS){
-    return rewriteUMinus(t, true);
-  }else if(t.getKind() == kind::DIVISION){
-    if(t[0].getKind()== kind::CONST_RATIONAL){
-      return rewriteDivByConstant(t, true);
-    }else{
+  }else{
+    switch(Kind k = t.getKind()){
+    case kind::MINUS:
+      return rewriteMinus(t, true);
+    case kind::UMINUS:
+      return rewriteUMinus(t, true);
+    case kind::DIVISION:
+    case kind::DIVISION_TOTAL:
+      return rewriteDiv(t,true);
+    case kind::PLUS:
+      return preRewritePlus(t);
+    case kind::MULT:
+      return preRewriteMult(t);
+    case kind::INTS_DIVISION:
+    case kind::INTS_MODULUS:
       return RewriteResponse(REWRITE_DONE, t);
-    }
-  }else if(t.getKind() == kind::PLUS){
-    return preRewritePlus(t);
-  }else if(t.getKind() == kind::MULT){
-    return preRewriteMult(t);
-  }else if(t.getKind() == kind::INTS_DIVISION){
-    Integer intOne(1);
-    if(t[1].getKind()== kind::CONST_INTEGER && t[1].getConst<Integer>() == intOne){
-      return RewriteResponse(REWRITE_AGAIN, t[0]);
-    }else{
+    case kind::INTS_DIVISION_TOTAL:
+    case kind::INTS_MODULUS_TOTAL:
+      return rewriteIntsDivModTotal(t,true);
+    case kind::ABS:
+      if(t[0].isConst()) {
+        const Rational& rat = t[0].getConst<Rational>();
+        if(rat >= 0) {
+          return RewriteResponse(REWRITE_DONE, t[0]);
+        } else {
+          return RewriteResponse(REWRITE_DONE,
+                                 NodeManager::currentNM()->mkConst(-rat));
+        }
+      }
       return RewriteResponse(REWRITE_DONE, t);
-    }
-  }else if(t.getKind() == kind::INTS_MODULUS){
-    Integer intOne(1);
-    if(t[1].getKind()== kind::CONST_INTEGER && t[1].getConst<Integer>() == intOne){
-      Integer intZero(0);
-      return RewriteResponse(REWRITE_AGAIN, mkIntegerNode(intZero));
-    }else{
+    case kind::IS_INTEGER:
+    case kind::TO_INTEGER:
       return RewriteResponse(REWRITE_DONE, t);
+    case kind::TO_REAL:
+      return RewriteResponse(REWRITE_DONE, t[0]);
+    default:
+      Unhandled(k);
     }
-  }else{
-    Unreachable();
   }
 }
 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
-  if(t.getMetaKind() == kind::metakind::CONSTANT){
+  if(t.isConst()){
     return rewriteConstant(t);
-  }else if(isVariable(t)){
+  }else if(t.isVar()){
     return rewriteVariable(t);
-  }else if(t.getKind() == kind::MINUS){
-    return rewriteMinus(t, false);
-  }else if(t.getKind() == kind::UMINUS){
-    return rewriteUMinus(t, false);
-  }else if(t.getKind() == kind::DIVISION){
-    return rewriteDivByConstant(t, false);
-  }else if(t.getKind() == kind::PLUS){
-    return postRewritePlus(t);
-  }else if(t.getKind() == kind::MULT){
-    return postRewriteMult(t);
-  }else if(t.getKind() == kind::INTS_DIVISION){
-    return RewriteResponse(REWRITE_DONE, t);
-  }else if(t.getKind() == kind::INTS_MODULUS){
-    return RewriteResponse(REWRITE_DONE, t);
   }else{
-    Unreachable();
+    switch(t.getKind()){
+    case kind::MINUS:
+      return rewriteMinus(t, false);
+    case kind::UMINUS:
+      return rewriteUMinus(t, false);
+    case kind::DIVISION:
+    case kind::DIVISION_TOTAL:
+      return rewriteDiv(t, false);
+    case kind::PLUS:
+      return postRewritePlus(t);
+    case kind::MULT:
+      return postRewriteMult(t);
+    case kind::INTS_DIVISION:
+    case kind::INTS_MODULUS:
+      return RewriteResponse(REWRITE_DONE, t);
+    case kind::INTS_DIVISION_TOTAL:
+    case kind::INTS_MODULUS_TOTAL:
+      return rewriteIntsDivModTotal(t, false);
+    case kind::ABS:
+      if(t[0].isConst()) {
+        const Rational& rat = t[0].getConst<Rational>();
+        if(rat >= 0) {
+          return RewriteResponse(REWRITE_DONE, t[0]);
+        } else {
+          return RewriteResponse(REWRITE_DONE,
+                                 NodeManager::currentNM()->mkConst(-rat));
+        }
+      }
+    case kind::TO_REAL:
+      return RewriteResponse(REWRITE_DONE, t[0]);
+    case kind::TO_INTEGER:
+      if(t[0].isConst()) {
+        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
+      }
+      if(t[0].getType().isInteger()) {
+        return RewriteResponse(REWRITE_DONE, t[0]);
+      }
+      //Unimplemented("TO_INTEGER, nonconstant");
+      //return rewriteToInteger(t);
+      return RewriteResponse(REWRITE_DONE, t);
+    case kind::IS_INTEGER:
+      if(t[0].isConst()) {
+        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
+      }
+      if(t[0].getType().isInteger()) {
+        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
+      }
+      //Unimplemented("IS_INTEGER, nonconstant");
+      //return rewriteIsInteger(t);
+      return RewriteResponse(REWRITE_DONE, t);
+    default:
+      Unreachable();
+    }
   }
 }
 
+
 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
   Assert(t.getKind()== kind::MULT);
 
   // Rewrite multiplications with a 0 argument and to 0
-  Integer intZero;
-
   Rational qZero(0);
 
   for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
@@ -156,14 +200,6 @@ RewriteResponse ArithRewriter::preRewriteMult(TNode t){
       if((*i).getConst<Rational>() == qZero) {
         return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
       }
-    } else if((*i).getKind() == kind::CONST_INTEGER) {
-      if((*i).getConst<Integer>() == intZero) {
-        if(t.getType().isInteger()) {
-          return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(intZero));
-        } else {
-          return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
-        }
-      }
     }
   }
   return RewriteResponse(REWRITE_DONE, t);
@@ -204,34 +240,36 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){
   return RewriteResponse(REWRITE_DONE, res.getNode());
 }
 
-RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
-  TNode left  = t[0];
-  TNode right = t[1];
-
-  Comparison cmp = Comparison::mkNormalComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
-
-  Assert(cmp.isNormalForm());
-  return RewriteResponse(REWRITE_DONE, cmp.getNode());
-}
-
 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
+  if(atom.getKind() == kind::IS_INTEGER) {
+    if(atom[0].isConst()) {
+      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
+    }
+    if(atom[0].getType().isInteger()) {
+      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
+    }
+    // not supported, but this isn't the right place to complain
+    return RewriteResponse(REWRITE_DONE, atom);
+  } else if(atom.getKind() == kind::DIVISIBLE) {
+    if(atom[0].isConst()) {
+      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
+    }
+    if(atom.getOperator().getConst<Divisible>().k.isOne()) {
+      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
+    }
+    return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
+  }
+
   // left |><| right
   TNode left = atom[0];
   TNode right = atom[1];
 
-  if(right.getMetaKind() == kind::metakind::CONSTANT){
-    return postRewriteAtomConstantRHS(atom);
-  }else{
-    Polynomial pleft = Polynomial::parsePolynomial(left);
-    Polynomial pright = Polynomial::parsePolynomial(right);
-
-    Polynomial diff = pleft - pright;
+  Polynomial pleft = Polynomial::parsePolynomial(left);
+  Polynomial pright = Polynomial::parsePolynomial(right);
 
-    Constant cZero = Constant::mkConstant(Rational(0));
-    Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff.getNode(), cZero.getNode());
-
-    return postRewriteAtomConstantRHS(reduction);
-  }
+  Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
+  Assert(cmp.isNormalForm());
+  return RewriteResponse(REWRITE_DONE, cmp.getNode());
 }
 
 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
@@ -249,6 +287,14 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
   }else if(atom.getKind() == kind::LT){
     Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
     return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
+  }else if(atom.getKind() == kind::IS_INTEGER){
+    if(atom[0].getType().isInteger()){
+      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
+    }
+  }else if(atom.getKind() == kind::DIVISIBLE){
+    if(atom.getOperator().getConst<Divisible>().k.isOne()){
+      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
+    }
   }
 
   return RewriteResponse(REWRITE_DONE, atom);
@@ -296,27 +342,93 @@ Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
   return diff;
 }
 
-RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
-  Assert(t.getKind()== kind::DIVISION);
+RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
+  Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION);
+
 
   Node left = t[0];
   Node right = t[1];
-  Assert(right.getKind()== kind::CONST_RATIONAL);
+  if(right.getKind() == kind::CONST_RATIONAL){
+    const Rational& den = right.getConst<Rational>();
+
+    if(den.isZero()){
+      if(t.getKind() == kind::DIVISION_TOTAL){
+        return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+      }else{
+        // This is unsupported, but this is not a good place to complain
+        return RewriteResponse(REWRITE_DONE, t);
+      }
+    }
+    Assert(den != Rational(0));
 
+    if(left.getKind() == kind::CONST_RATIONAL){
+      const Rational& num = left.getConst<Rational>();
+      Rational div = num / den;
+      Node result =  mkRationalNode(div);
+      return RewriteResponse(REWRITE_DONE, result);
+    }
 
-  const Rational& den = right.getConst<Rational>();
+    Rational div = den.inverse();
 
-  Assert(den != Rational(0));
+    Node result = mkRationalNode(div);
 
-  Rational div = den.inverse();
+    Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
+    if(pre){
+      return RewriteResponse(REWRITE_DONE, mult);
+    }else{
+      return RewriteResponse(REWRITE_AGAIN, mult);
+    }
+  }else{
+    return RewriteResponse(REWRITE_DONE, t);
+  }
+}
 
-  Node result = mkRationalNode(div);
+RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
+  Kind k = t.getKind();
+  // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
+  //        k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+
+  //Leaving the function as before (INTS_MODULUS can be handled),
+  // but restricting its use here
+  Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
+  TNode n = t[0], d = t[1];
+  bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
+  if(dIsConstant && d.getConst<Rational>().isZero()){
+    if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
+      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+    }else{
+      // Do nothing for k == INTS_MODULUS
+      return RewriteResponse(REWRITE_DONE, t);
+    }
+  }else if(dIsConstant && d.getConst<Rational>().isOne()){
+    if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
+      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+    }else{
+      Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+      return RewriteResponse(REWRITE_AGAIN, n);
+    }
+  }else if(dIsConstant && d.getConst<Rational>().isNegativeOne()){
+    if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
+      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+    }else{
+      Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+      return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::UMINUS, n));
+    }
+  }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
+    Assert(d.getConst<Rational>().isIntegral());
+    Assert(n.getConst<Rational>().isIntegral());
+    Assert(!d.getConst<Rational>().isZero());
+    Integer di = d.getConst<Rational>().getNumerator();
+    Integer ni = n.getConst<Rational>().getNumerator();
 
-  Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
-  if(pre){
-    return RewriteResponse(REWRITE_DONE, mult);
+    bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+
+    Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
+
+    Node resultNode = mkRationalNode(Rational(result));
+    return RewriteResponse(REWRITE_DONE, resultNode);
   }else{
-    return RewriteResponse(REWRITE_AGAIN, mult);
+    return RewriteResponse(REWRITE_DONE, t);
   }
 }