Merge branch '1.2.x'
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
index 8d12e78fe137f126c9f668fe57af79cac856d5bf..247c09294e5341463850625919c77c6160ced21d 100644 (file)
@@ -1,13 +1,11 @@
 /*********************                                                        */
 /*! \file arith_rewriter.cpp
  ** \verbatim
- ** Original author: taking
- ** Major contributors: dejan
- ** Minor contributors (to current version): mdeters
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010, 2011  The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
+ ** Original author: Tim King
+ ** Major contributors: none
+ ** Minor contributors (to current version): Morgan Deters, Dejan Jovanovic
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2013  New York University and The University of Iowa
  ** See the file COPYING in the top-level source directory for licensing
  ** information.\endverbatim
  **
 #include <set>
 #include <stack>
 
-using namespace CVC4;
-using namespace CVC4::theory;
-using namespace CVC4::theory::arith;
+namespace CVC4 {
+namespace theory {
+namespace arith {
 
-bool isVariable(TNode t){
-  return t.getMetaKind() == kind::metakind::VARIABLE;
+bool ArithRewriter::isAtom(TNode n) {
+  Kind k = n.getKind();
+  return arith::isRelationOperator(k) || k == kind::IS_INTEGER || k == kind::DIVISIBLE;
 }
 
 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
-  Assert(t.getMetaKind() == kind::metakind::CONSTANT);
-  Node val = coerceToRationalNode(t);
+  Assert(t.isConst());
+  Assert(t.getKind() == kind::CONST_RATIONAL);
 
-  return RewriteResponse(REWRITE_DONE, val);
+  return RewriteResponse(REWRITE_DONE, t);
 }
 
 RewriteResponse ArithRewriter::rewriteVariable(TNode t){
-  Assert(isVariable(t));
+  Assert(t.isVar());
 
   return RewriteResponse(REWRITE_DONE, t);
 }
@@ -50,23 +49,31 @@ RewriteResponse ArithRewriter::rewriteVariable(TNode t){
 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
   Assert(t.getKind()== kind::MINUS);
 
-  if(t[0] == t[1]){
-    Rational zero(0);
-    Node zeroNode  = mkRationalNode(zero);
-    return RewriteResponse(REWRITE_DONE, zeroNode);
-  }
-
-  Node noMinus = makeSubtractionNode(t[0],t[1]);
   if(pre){
-    return RewriteResponse(REWRITE_DONE, noMinus);
+    if(t[0] == t[1]){
+      Rational zero(0);
+      Node zeroNode  = mkRationalNode(zero);
+      return RewriteResponse(REWRITE_DONE, zeroNode);
+    }else{
+      Node noMinus = makeSubtractionNode(t[0],t[1]);
+      return RewriteResponse(REWRITE_DONE, noMinus);
+    }
   }else{
-    return RewriteResponse(REWRITE_AGAIN_FULL, noMinus);
+    Polynomial minuend = Polynomial::parsePolynomial(t[0]);
+    Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
+    Polynomial diff = minuend - subtrahend;
+    return RewriteResponse(REWRITE_DONE, diff.getNode());
   }
 }
 
 RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
   Assert(t.getKind()== kind::UMINUS);
 
+  if(t[0].getKind() == kind::CONST_RATIONAL){
+    Rational neg = -(t[0].getConst<Rational>());
+    return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
+  }
+
   Node noUminus = makeUnaryMinusNode(t[0]);
   if(pre)
     return RewriteResponse(REWRITE_DONE, noUminus);
@@ -75,54 +82,117 @@ RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
 }
 
 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
-  if(t.getMetaKind() == kind::metakind::CONSTANT){
+  if(t.isConst()){
     return rewriteConstant(t);
-  }else if(isVariable(t)){
+  }else if(t.isVar()){
     return rewriteVariable(t);
-  }else if(t.getKind() == kind::MINUS){
-    return rewriteMinus(t, true);
-  }else if(t.getKind() == kind::UMINUS){
-    return rewriteUMinus(t, true);
-  }else if(t.getKind() == kind::DIVISION){
-    if(t[0].getKind()== kind::CONST_RATIONAL){
-      return rewriteDivByConstant(t, true);
-    }else{
+  }else{
+    switch(Kind k = t.getKind()){
+    case kind::MINUS:
+      return rewriteMinus(t, true);
+    case kind::UMINUS:
+      return rewriteUMinus(t, true);
+    case kind::DIVISION:
+    case kind::DIVISION_TOTAL:
+      return rewriteDiv(t,true);
+    case kind::PLUS:
+      return preRewritePlus(t);
+    case kind::MULT:
+      return preRewriteMult(t);
+    case kind::INTS_DIVISION:
+    case kind::INTS_MODULUS:
       return RewriteResponse(REWRITE_DONE, t);
+    case kind::INTS_DIVISION_TOTAL:
+    case kind::INTS_MODULUS_TOTAL:
+      return rewriteIntsDivModTotal(t,true);
+    case kind::ABS:
+      if(t[0].isConst()) {
+        const Rational& rat = t[0].getConst<Rational>();
+        if(rat >= 0) {
+          return RewriteResponse(REWRITE_DONE, t[0]);
+        } else {
+          return RewriteResponse(REWRITE_DONE,
+                                 NodeManager::currentNM()->mkConst(-rat));
+        }
+      }
+      return RewriteResponse(REWRITE_DONE, t);
+    case kind::IS_INTEGER:
+    case kind::TO_INTEGER:
+      return RewriteResponse(REWRITE_DONE, t);
+    case kind::TO_REAL:
+      return RewriteResponse(REWRITE_DONE, t[0]);
+    default:
+      Unhandled(k);
     }
-  }else if(t.getKind() == kind::PLUS){
-    return preRewritePlus(t);
-  }else if(t.getKind() == kind::MULT){
-    return preRewriteMult(t);
-  }else{
-    Unreachable();
   }
 }
 RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
-  if(t.getMetaKind() == kind::metakind::CONSTANT){
+  if(t.isConst()){
     return rewriteConstant(t);
-  }else if(isVariable(t)){
+  }else if(t.isVar()){
     return rewriteVariable(t);
-  }else if(t.getKind() == kind::MINUS){
-    return rewriteMinus(t, false);
-  }else if(t.getKind() == kind::UMINUS){
-    return rewriteUMinus(t, false);
-  }else if(t.getKind() == kind::DIVISION){
-    return rewriteDivByConstant(t, false);
-  }else if(t.getKind() == kind::PLUS){
-    return postRewritePlus(t);
-  }else if(t.getKind() == kind::MULT){
-    return postRewriteMult(t);
   }else{
-    Unreachable();
+    switch(t.getKind()){
+    case kind::MINUS:
+      return rewriteMinus(t, false);
+    case kind::UMINUS:
+      return rewriteUMinus(t, false);
+    case kind::DIVISION:
+    case kind::DIVISION_TOTAL:
+      return rewriteDiv(t, false);
+    case kind::PLUS:
+      return postRewritePlus(t);
+    case kind::MULT:
+      return postRewriteMult(t);
+    case kind::INTS_DIVISION:
+    case kind::INTS_MODULUS:
+      return RewriteResponse(REWRITE_DONE, t);
+    case kind::INTS_DIVISION_TOTAL:
+    case kind::INTS_MODULUS_TOTAL:
+      return rewriteIntsDivModTotal(t, false);
+    case kind::ABS:
+      if(t[0].isConst()) {
+        const Rational& rat = t[0].getConst<Rational>();
+        if(rat >= 0) {
+          return RewriteResponse(REWRITE_DONE, t[0]);
+        } else {
+          return RewriteResponse(REWRITE_DONE,
+                                 NodeManager::currentNM()->mkConst(-rat));
+        }
+      }
+    case kind::TO_REAL:
+      return RewriteResponse(REWRITE_DONE, t[0]);
+    case kind::TO_INTEGER:
+      if(t[0].isConst()) {
+        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
+      }
+      if(t[0].getType().isInteger()) {
+        return RewriteResponse(REWRITE_DONE, t[0]);
+      }
+      //Unimplemented("TO_INTEGER, nonconstant");
+      //return rewriteToInteger(t);
+      return RewriteResponse(REWRITE_DONE, t);
+    case kind::IS_INTEGER:
+      if(t[0].isConst()) {
+        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
+      }
+      if(t[0].getType().isInteger()) {
+        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
+      }
+      //Unimplemented("IS_INTEGER, nonconstant");
+      //return rewriteIsInteger(t);
+      return RewriteResponse(REWRITE_DONE, t);
+    default:
+      Unreachable();
+    }
   }
 }
 
+
 RewriteResponse ArithRewriter::preRewriteMult(TNode t){
   Assert(t.getKind()== kind::MULT);
 
   // Rewrite multiplications with a 0 argument and to 0
-  Integer intZero;
-
   Rational qZero(0);
 
   for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
@@ -130,14 +200,6 @@ RewriteResponse ArithRewriter::preRewriteMult(TNode t){
       if((*i).getConst<Rational>() == qZero) {
         return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
       }
-    } else if((*i).getKind() == kind::CONST_INTEGER) {
-      if((*i).getConst<Integer>() == intZero) {
-        if(t.getType().isInteger()) {
-          return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(intZero));
-        } else {
-          return RewriteResponse(REWRITE_DONE, mkRationalNode(qZero));
-        }
-      }
     }
   }
   return RewriteResponse(REWRITE_DONE, t);
@@ -178,57 +240,36 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){
   return RewriteResponse(REWRITE_DONE, res.getNode());
 }
 
-RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
-  TNode left  = t[0];
-  TNode right = t[1];
-
-  Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
-
-  if(cmp.isBoolean()){
-    return RewriteResponse(REWRITE_DONE, cmp.getNode());
-  }
-
-  if(cmp.getLeft().containsConstant()){
-    Monomial constantHead = cmp.getLeft().getHead();
-    Assert(constantHead.isConstant());
-
-    Constant constant = constantHead.getConstant();
-
-    Constant negativeConstantHead = -constant;
-
-    cmp = cmp.addConstant(negativeConstantHead);
-  }
-  Assert(!cmp.getLeft().containsConstant());
-
-  if(!cmp.getLeft().getHead().coefficientIsOne()){
-    Monomial constantHead = cmp.getLeft().getHead();
-    Assert(!constantHead.isConstant());
-    Constant constant = constantHead.getConstant();
-
-    Constant inverse = Constant::mkConstant(constant.getValue().inverse());
-
-    cmp = cmp.multiplyConstant(inverse);
+RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
+  if(atom.getKind() == kind::IS_INTEGER) {
+    if(atom[0].isConst()) {
+      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
+    }
+    if(atom[0].getType().isInteger()) {
+      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
+    }
+    // not supported, but this isn't the right place to complain
+    return RewriteResponse(REWRITE_DONE, atom);
+  } else if(atom.getKind() == kind::DIVISIBLE) {
+    if(atom[0].isConst()) {
+      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
+    }
+    if(atom.getOperator().getConst<Divisible>().k.isOne()) {
+      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
+    }
+    return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
   }
-  Assert(cmp.getLeft().getHead().coefficientIsOne());
 
-  Assert(cmp.isBoolean() || cmp.isNormalForm());
-  return RewriteResponse(REWRITE_DONE, cmp.getNode());
-}
-
-RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
   // left |><| right
   TNode left = atom[0];
   TNode right = atom[1];
 
-  if(right.getMetaKind() == kind::metakind::CONSTANT){
-    return postRewriteAtomConstantRHS(atom);
-  }else{
-    //Transform this to: (left - right) |><| 0
-    Node diff = makeSubtractionNode(left, right);
-    Rational qZero(0);
-    Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, mkRationalNode(qZero));
-    return RewriteResponse(REWRITE_AGAIN_FULL, reduction);
-  }
+  Polynomial pleft = Polynomial::parsePolynomial(left);
+  Polynomial pright = Polynomial::parsePolynomial(right);
+
+  Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
+  Assert(cmp.isNormalForm());
+  return RewriteResponse(REWRITE_DONE, cmp.getNode());
 }
 
 RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
@@ -240,38 +281,23 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
     if(atom[0] == atom[1]) {
       return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
     }
+  }else if(atom.getKind() == kind::GT){
+    Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
+    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
+  }else if(atom.getKind() == kind::LT){
+    Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
+    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
+  }else if(atom.getKind() == kind::IS_INTEGER){
+    if(atom[0].getType().isInteger()){
+      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
+    }
+  }else if(atom.getKind() == kind::DIVISIBLE){
+    if(atom.getOperator().getConst<Divisible>().k.isOne()){
+      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
+    }
   }
 
-  Node reduction = atom;
-
-  if(atom[1].getMetaKind() != kind::metakind::CONSTANT) {
-    // left |><| right
-    TNode left = atom[0];
-    TNode right = atom[1];
-
-    //Transform this to: (left - right) |><| 0
-    Node diff = makeSubtractionNode(left, right);
-    Rational qZero(0);
-    reduction = currNM->mkNode(atom.getKind(), diff, mkRationalNode(qZero));
-  }
-
-  if(reduction.getKind() == kind::GT){
-    Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
-    reduction = currNM->mkNode(kind::NOT, leq);
-  }else if(reduction.getKind() == kind::LT){
-    Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
-    reduction = currNM->mkNode(kind::NOT, geq);
-  }
-  /* BREADCRUMB : Move this rewrite into preprocessing
-  else if( Options::current()->rewriteArithEqualities && reduction.getKind() == kind::EQUAL){
-    Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
-    Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
-    reduction = currNM->mkNode(kind::AND, geq, leq);
-  }
-  */
-
-
-  return RewriteResponse(REWRITE_DONE, reduction);
+  return RewriteResponse(REWRITE_DONE, atom);
 }
 
 RewriteResponse ArithRewriter::postRewrite(TNode t){
@@ -316,26 +342,96 @@ Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
   return diff;
 }
 
-RewriteResponse ArithRewriter::rewriteDivByConstant(TNode t, bool pre){
-  Assert(t.getKind()== kind::DIVISION);
+RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
+  Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION);
+
 
   Node left = t[0];
   Node right = t[1];
-  Assert(right.getKind()== kind::CONST_RATIONAL);
+  if(right.getKind() == kind::CONST_RATIONAL){
+    const Rational& den = right.getConst<Rational>();
+
+    if(den.isZero()){
+      if(t.getKind() == kind::DIVISION_TOTAL){
+        return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+      }else{
+        // This is unsupported, but this is not a good place to complain
+        return RewriteResponse(REWRITE_DONE, t);
+      }
+    }
+    Assert(den != Rational(0));
 
+    if(left.getKind() == kind::CONST_RATIONAL){
+      const Rational& num = left.getConst<Rational>();
+      Rational div = num / den;
+      Node result =  mkRationalNode(div);
+      return RewriteResponse(REWRITE_DONE, result);
+    }
 
-  const Rational& den = right.getConst<Rational>();
+    Rational div = den.inverse();
+
+    Node result = mkRationalNode(div);
+
+    Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
+    if(pre){
+      return RewriteResponse(REWRITE_DONE, mult);
+    }else{
+      return RewriteResponse(REWRITE_AGAIN, mult);
+    }
+  }else{
+    return RewriteResponse(REWRITE_DONE, t);
+  }
+}
 
-  Assert(den != Rational(0));
+RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
+  Kind k = t.getKind();
+  // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
+  //        k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+
+  //Leaving the function as before (INTS_MODULUS can be handled),
+  // but restricting its use here
+  Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
+  TNode n = t[0], d = t[1];
+  bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
+  if(dIsConstant && d.getConst<Rational>().isZero()){
+    if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
+      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+    }else{
+      // Do nothing for k == INTS_MODULUS
+      return RewriteResponse(REWRITE_DONE, t);
+    }
+  }else if(dIsConstant && d.getConst<Rational>().isOne()){
+    if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
+      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+    }else{
+      Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+      return RewriteResponse(REWRITE_AGAIN, n);
+    }
+  }else if(dIsConstant && d.getConst<Rational>().isNegativeOne()){
+    if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
+      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
+    }else{
+      Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+      return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::UMINUS, n));
+    }
+  }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
+    Assert(d.getConst<Rational>().isIntegral());
+    Assert(n.getConst<Rational>().isIntegral());
+    Assert(!d.getConst<Rational>().isZero());
+    Integer di = d.getConst<Rational>().getNumerator();
+    Integer ni = n.getConst<Rational>().getNumerator();
 
-  Rational div = den.inverse();
+    bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
 
-  Node result = mkRationalNode(div);
+    Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
 
-  Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
-  if(pre){
-    return RewriteResponse(REWRITE_DONE, mult);
+    Node resultNode = mkRationalNode(Rational(result));
+    return RewriteResponse(REWRITE_DONE, resultNode);
   }else{
-    return RewriteResponse(REWRITE_AGAIN, mult);
+    return RewriteResponse(REWRITE_DONE, t);
   }
 }
+
+}/* CVC4::theory::arith namespace */
+}/* CVC4::theory namespace */
+}/* CVC4 namespace */