Update to the ArithRewriter to remove REWRITE_AGAIN_FULL and limit REWRITE_AGAIN...
authorTim King <taking@cs.nyu.edu>
Wed, 28 Mar 2012 17:16:27 +0000 (17:16 +0000)
committerTim King <taking@cs.nyu.edu>
Wed, 28 Mar 2012 17:16:27 +0000 (17:16 +0000)
src/theory/arith/arith_rewriter.cpp
src/theory/arith/normal_form.cpp
src/theory/arith/normal_form.h

index ca0aa4d14b2dbc32863fc3d98db54bafd747596a..30568c3cac43623d51d7ae122952f4da1356dd4e 100644 (file)
@@ -54,17 +54,20 @@ RewriteResponse ArithRewriter::rewriteVariable(TNode t){
 RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
   Assert(t.getKind()== kind::MINUS);
 
-  if(t[0] == t[1]){
-    Rational zero(0);
-    Node zeroNode  = mkRationalNode(zero);
-    return RewriteResponse(REWRITE_DONE, zeroNode);
-  }
-
-  Node noMinus = makeSubtractionNode(t[0],t[1]);
   if(pre){
-    return RewriteResponse(REWRITE_DONE, noMinus);
+    if(t[0] == t[1]){
+      Rational zero(0);
+      Node zeroNode  = mkRationalNode(zero);
+      return RewriteResponse(REWRITE_DONE, zeroNode);
+    }else{
+      Node noMinus = makeSubtractionNode(t[0],t[1]);
+      return RewriteResponse(REWRITE_DONE, noMinus);
+    }
   }else{
-    return RewriteResponse(REWRITE_AGAIN_FULL, noMinus);
+    Polynomial minuend = Polynomial::parsePolynomial(t[0]);
+    Polynomial subtrahend = Polynomial::parsePolynomial(t[0]);
+    Polynomial diff = minuend - subtrahend;
+    return RewriteResponse(REWRITE_DONE, diff.getNode());
   }
 }
 
@@ -209,39 +212,6 @@ RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){
 
   Assert(cmp.isNormalForm());
   return RewriteResponse(REWRITE_DONE, cmp.getNode());
-
-
-  // Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
-
-  // if(cmp.isBoolean()){
-  //   return RewriteResponse(REWRITE_DONE, cmp.getNode());
-  // }
-
-  // if(cmp.getLeft().containsConstant()){
-  //   Monomial constantHead = cmp.getLeft().getHead();
-  //   Assert(constantHead.isConstant());
-
-  //   Constant constant = constantHead.getConstant();
-
-  //   Constant negativeConstantHead = -constant;
-
-  //   cmp = cmp.addConstant(negativeConstantHead);
-  // }
-  // Assert(!cmp.getLeft().containsConstant());
-
-  // if(!cmp.getLeft().getHead().coefficientIsOne()){
-  //   Monomial constantHead = cmp.getLeft().getHead();
-  //   Assert(!constantHead.isConstant());
-  //   Constant constant = constantHead.getConstant();
-
-  //   Constant inverse = Constant::mkConstant(constant.getValue().inverse());
-
-  //   cmp = cmp.multiplyConstant(inverse);
-  // }
-  // Assert(cmp.getLeft().getHead().coefficientIsOne());
-
-  // Assert(cmp.isBoolean() || cmp.isNormalForm());
-  // return RewriteResponse(REWRITE_DONE, cmp.getNode());
 }
 
 RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
@@ -252,11 +222,15 @@ RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
   if(right.getMetaKind() == kind::metakind::CONSTANT){
     return postRewriteAtomConstantRHS(atom);
   }else{
-    //Transform this to: (left - right) |><| 0
-    Node diff = makeSubtractionNode(left, right);
-    Rational qZero(0);
-    Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, mkRationalNode(qZero));
-    return RewriteResponse(REWRITE_AGAIN_FULL, reduction);
+    Polynomial pleft = Polynomial::parsePolynomial(left);
+    Polynomial pright = Polynomial::parsePolynomial(right);
+
+    Polynomial diff = pleft - pright;
+
+    Constant cZero = Constant::mkConstant(Rational(0));
+    Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff.getNode(), cZero.getNode());
+
+    return postRewriteAtomConstantRHS(reduction);
   }
 }
 
@@ -269,38 +243,15 @@ RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
     if(atom[0] == atom[1]) {
       return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
     }
+  }else if(atom.getKind() == kind::GT){
+    Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
+    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
+  }else if(atom.getKind() == kind::LT){
+    Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
+    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
   }
 
-  Node reduction = atom;
-
-  if(atom[1].getMetaKind() != kind::metakind::CONSTANT) {
-    // left |><| right
-    TNode left = atom[0];
-    TNode right = atom[1];
-
-    //Transform this to: (left - right) |><| 0
-    Node diff = makeSubtractionNode(left, right);
-    Rational qZero(0);
-    reduction = currNM->mkNode(atom.getKind(), diff, mkRationalNode(qZero));
-  }
-
-  if(reduction.getKind() == kind::GT){
-    Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
-    reduction = currNM->mkNode(kind::NOT, leq);
-  }else if(reduction.getKind() == kind::LT){
-    Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
-    reduction = currNM->mkNode(kind::NOT, geq);
-  }
-  /* BREADCRUMB : Move this rewrite into preprocessing
-  else if( Options::current()->rewriteArithEqualities && reduction.getKind() == kind::EQUAL){
-    Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
-    Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
-    reduction = currNM->mkNode(kind::AND, geq, leq);
-  }
-  */
-
-
-  return RewriteResponse(REWRITE_DONE, reduction);
+  return RewriteResponse(REWRITE_DONE, atom);
 }
 
 RewriteResponse ArithRewriter::postRewrite(TNode t){
index a4dc78c9f65846c613c3a00564e9694e1f63ba28..31cd8cd708f227d2f0e632de1359bd0da289a326 100644 (file)
@@ -122,6 +122,14 @@ Monomial Monomial::parseMonomial(Node n) {
     return Monomial(VarList::parseVarList(n));
   }
 }
+Monomial Monomial::operator*(const Constant& c) const {
+  if(c.isZero()){
+    return mkZero();
+  }else{
+    Constant newConstant = this->getConstant() * c;
+    return Monomial::mkMonomial(newConstant, getVarList());
+  }
+}
 
 Monomial Monomial::operator*(const Monomial& mono) const {
   Constant newConstant = this->getConstant() * mono.getConstant();
@@ -174,6 +182,28 @@ Polynomial Polynomial::operator+(const Polynomial& vl) const {
   return result;
 }
 
+Polynomial Polynomial::operator-(const Polynomial& vl) const {
+  Constant negOne = Constant::mkConstant(Rational(-1));
+
+  return *this + (vl*negOne);
+}
+
+Polynomial Polynomial::operator*(const Constant& c) const{
+  if(c.isZero()){
+    return Polynomial::mkZero();
+  }else if(c.isOne()){
+    return *this;
+  }else{
+    std::vector<Monomial> newMonos;
+    for(iterator i = this->begin(), end = this->end(); i != end; ++i) {
+      newMonos.push_back((*i)*c);
+    }
+
+    Assert(Monomial::isStrictlySorted(newMonos));
+    return Polynomial::mkPolynomial(newMonos);
+  }
+}
+
 Polynomial Polynomial::operator*(const Monomial& mono) const {
   if(mono.isZero()) {
     return Polynomial(mono); //Don't multiply by zero
index 71d7c96f496cacfbb23e4bda9b66d1b57f278bd8..a5e1e0cec10c5c47557fb8753377f3c046f93fab 100644 (file)
@@ -572,6 +572,7 @@ public:
     return coefficientIsOne() || constant.getValue() == -1;
   }
 
+  Monomial operator*(const Constant& c) const;
   Monomial operator*(const Monomial& mono) const;
 
 
@@ -854,7 +855,9 @@ public:
   }
 
   Polynomial operator+(const Polynomial& vl) const;
+  Polynomial operator-(const Polynomial& vl) const;
 
+  Polynomial operator*(const Constant& c) const;
   Polynomial operator*(const Monomial& mono) const;
 
   Polynomial operator*(const Polynomial& poly) const;