* New normal form for arithmetic is in place.
authorTim King <taking@cs.nyu.edu>
Mon, 13 Sep 2010 16:08:21 +0000 (16:08 +0000)
committerTim King <taking@cs.nyu.edu>
Mon, 13 Sep 2010 16:08:21 +0000 (16:08 +0000)
* src/theory/arith/normal_form.{h,cpp} contains the description for the new
  normal form as well as utilities for dealing with the normal form.
* src/theory/arith/next_arith_rewriter.{h,cpp} contains the new rewriter.
  The new rewriter implements preRewrite() and postRewrite() for arithmetic.
* src/theory/arith/arith_rewriter.{h,cpp} have been removed.
* TheoryArith::rewrite() has been removed.
* Arithmetic with the new normal form outperforms the trunk where the branch
  occurred (-r797) on 46% of the examples in QF_LRA. (33% have no noticeable
  difference.) Some important optimizations are stilling pending to the code
  for handling the new normal form. (Bug 196.)

14 files changed:
src/expr/node_manager.cpp
src/theory/arith/Makefile.am
src/theory/arith/arith_rewriter.cpp [deleted file]
src/theory/arith/arith_rewriter.h [deleted file]
src/theory/arith/arith_utilities.h
src/theory/arith/kinds
src/theory/arith/next_arith_rewriter.cpp [new file with mode: 0644]
src/theory/arith/next_arith_rewriter.h [new file with mode: 0644]
src/theory/arith/normal_form.cpp [new file with mode: 0644]
src/theory/arith/normal_form.h [new file with mode: 0644]
src/theory/arith/tableau.h
src/theory/arith/theory_arith.cpp
src/theory/arith/theory_arith.h
test/unit/theory/theory_arith_white.h

index d017ad799d2c2b7690f1430e0b4309ed4be6648f..37ed4fe20905e548b45412130355f04baf6ef453 100644 (file)
@@ -235,6 +235,9 @@ TypeNode NodeManager::getType(TNode n, bool check)
     case kind::APPLY_UF:
       typeNode = CVC4::theory::uf::UfTypeRule::computeType(this, n, check);
       break;
+    case kind::IDENTITY:
+      typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
+      break;
     case kind::PLUS:
       typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check);
       break;
index e500f5cf81776b62270d61395a7169980b249b79..ead39082c0ae8e7c6900bcc5b6947f724af7a4f5 100644 (file)
@@ -7,8 +7,10 @@ noinst_LTLIBRARIES = libarith.la
 
 libarith_la_SOURCES = \
        theory_arith_type_rules.h \
-       arith_rewriter.h \
-       arith_rewriter.cpp \
+       next_arith_rewriter.h \
+       next_arith_rewriter.cpp \
+       normal_form.h\
+       normal_form.cpp \
        arith_utilities.h \
        arith_constants.h \
        arith_activity.h \
diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp
deleted file mode 100644 (file)
index ba1445d..0000000
+++ /dev/null
@@ -1,557 +0,0 @@
-/*********************                                                        */
-/*! \file arith_rewriter.cpp
- ** \verbatim
- ** Original author: taking
- ** Major contributors: none
- ** Minor contributors (to current version): mdeters
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010  The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
- **
- ** \brief [[ Add one-line brief description here ]]
- **
- ** [[ Add lengthier description here ]]
- ** \todo document this file
- **/
-
-
-#include "theory/arith/arith_rewriter.h"
-#include "theory/arith/arith_utilities.h"
-
-#include <vector>
-#include <set>
-#include <stack>
-
-
-using namespace CVC4;
-using namespace CVC4::theory;
-using namespace CVC4::theory::arith;
-
-
-
-
-
-Kind multKind(Kind k, int sgn);
-
-/**
- * Performs a quick check to see if it is easy to rewrite to
- * this normal form
- *        v |><| b
- * Also writes relations with constants on both sides to TRUE or FALSE.
- * If it can, it returns true and sets res to this value.
- *
- * This is for optimizing rewriteAtom() to avoid the more computationally
- * expensive general rewriting procedure.
- *
- * If simplification is not done, it returns Node::null()
- */
-Node almostVarOrConstEqn(TNode atom, Kind k, TNode left, TNode right){
-  Assert(atom.getKind() == k);
-  Assert(isRelationOperator(k));
-  Assert(atom[0] == left);
-  Assert(atom[1] == right);
-  bool leftIsConst  =  left.getMetaKind() == kind::metakind::CONSTANT;
-  bool rightIsConst = right.getMetaKind() == kind::metakind::CONSTANT;
-
-  bool leftIsVar = left.getMetaKind() == kind::metakind::VARIABLE;
-  bool rightIsVar = right.getMetaKind() == kind::metakind::VARIABLE;
-
-  if(leftIsConst && rightIsConst){
-    Rational lc = coerceToRational(left);
-    Rational rc = coerceToRational(right);
-    bool res = evaluateConstantPredicate(k,lc, rc);
-    return mkBoolNode(res);
-  }else if(leftIsVar && rightIsConst){
-    if(right.getKind() == kind::CONST_RATIONAL){
-      return atom;
-    }else{
-      return NodeManager::currentNM()->mkNode(k,left,coerceToRationalNode(right));
-    }
-  }else if(leftIsConst && rightIsVar){
-    if(left.getKind() == kind::CONST_RATIONAL){
-      return NodeManager::currentNM()->mkNode(multKind(k,-1),right,left);
-    }else{
-      Node q_left = coerceToRationalNode(left);
-      return NodeManager::currentNM()->mkNode(multKind(k,-1),right,q_left);
-    }
-  }
-
-  return Node::null();
-}
-
-Node ArithRewriter::rewriteAtomCore(TNode atom){
-
-  Kind k = atom.getKind();
-  Assert(isRelationOperator(k));
-
-  // left |><| right
-  TNode left = atom[0];
-  TNode right = atom[1];
-
-  Node nf = almostVarOrConstEqn(atom, k,left,right);
-  if(nf != Node::null() ){
-    return nf;
-  }
-
-
-  //Transform this to: (left- right) |><| 0
-  Node diff = makeSubtractionNode(left, right);
-
-  Node rewritten = rewrite(diff);
-  // rewritten =_{Reals} left - right => rewritten |><| 0
-
-  if(rewritten.getMetaKind() == kind::metakind::CONSTANT){
-    // Case 1 rewritten : c
-    Rational c = rewritten.getConst<Rational>();
-    bool res = evaluateConstantPredicate(k, c, d_constants->d_ZERO);
-    nf = mkBoolNode(res);
-  }else if(rewritten.getMetaKind() == kind::metakind::VARIABLE){
-    // Case 2 rewritten : v
-    nf = NodeManager::currentNM()->mkNode(k, rewritten, d_constants->d_ZERO_NODE);
-  }else{
-    // Case 3 rewritten : (+ c p_1 p_2 ... p_N)  |  not(N=1 and c=0 and p_1.d=1)
-    Rational c = rewritten[0].getConst<Rational>();
-    c = -c;
-    TNode p_1 = rewritten[1];
-    Rational d = p_1[0].getConst<Rational>();
-    d = d.inverse();
-    c = c * d;
-    Node newRight = mkRationalNode(c);
-    Kind newKind = multKind(k, d.sgn());
-    int N = rewritten.getNumChildren() - 1;
-
-    if(N==1){
-      int M = p_1.getNumChildren()-1;
-      if(M == 1){ // v |><| b
-        TNode v = p_1[1];
-        nf = NodeManager::currentNM()->mkNode(newKind, v, newRight);
-      }else{ // p |><| b
-        Node newLeft = multPnfByNonZero(p_1, d);
-        nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight);
-      }
-    }else{ //(+ p_1 .. p_N)  |><| b
-      NodeBuilder<> plus(kind::PLUS);
-      for(int i=1; i<=N; ++i){
-        TNode p_i = rewritten[i];
-        plus << multPnfByNonZero(p_i, d);
-      }
-      Node newLeft = plus;
-      nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight);
-    }
-  }
-
-  return nf;
-}
-
-Node ArithRewriter::rewriteAtom(TNode atom){
-  Node rewritten = rewriteAtomCore(atom);
-  if(rewritten.getKind() == kind::LT){
-    Node geq = NodeManager::currentNM()->mkNode(kind::GEQ, rewritten[0], rewritten[1]);
-    return NodeManager::currentNM()->mkNode(kind::NOT, geq);
-  }else if(rewritten.getKind() == kind::GT){
-    Node leq = NodeManager::currentNM()->mkNode(kind::LEQ, rewritten[0], rewritten[1]);
-    return NodeManager::currentNM()->mkNode(kind::NOT, leq);
-  }else{
-    return rewritten;
-  }
-}
-
-
-/* cmp( (* d v_1 v_2 ... v_M), (* d' v'_1 v'_2 ... v'_M'):
- *      if(M == M'):
- *      then tupleCompare(v_i, v'_i)
- *      else M -M'
- */
-struct pnfLessThan {
-  bool operator()(Node p0, Node p1) {
-    int p0_M = p0.getNumChildren() -1;
-    int p1_M = p1.getNumChildren() -1;
-    if(p0_M == p1_M){
-      for(int i=1; i<= p0_M; ++i){
-        if(p0[i] != p1[i]){
-          return p0[i] < p1[i];
-        }
-      }
-      return false; //p0 == p1 in this order
-    }else{
-      return p0_M < p1_M;
-    }
-  }
-};
-
-//Two pnfs are equal up to their coefficients
-bool pnfsMatch(TNode p0, TNode p1){
-
-  unsigned M = p0.getNumChildren()-1;
-  if (M+1 != p1.getNumChildren()){
-    return false;
-  }
-
-  for(unsigned i=1; i <= M; ++i){
-    if(p0[i] != p1[i])
-      return false;
-  }
-  return true;
-}
-
-Node addMatchingPnfs(TNode p0, TNode p1){
-  Assert(pnfsMatch(p0,p1));
-
-  unsigned M = p0.getNumChildren()-1;
-
-  Rational c0 = p0[0].getConst<Rational>();
-  Rational c1 = p1[0].getConst<Rational>();
-
-  Rational addedC = c0 + c1;
-  Node newC = mkRationalNode(addedC);
-  NodeBuilder<> nb(kind::MULT);
-  nb << newC;
-  for(unsigned i=1; i <= M; ++i){
-    nb << p0[i];
-  }
-  Node newPnf = nb;
-  return newPnf;
-}
-
-void ArithRewriter::sortAndCombineCoefficients(std::vector<Node>& pnfs){
-  using namespace std;
-
-  /* combined contains exactly 1 representative per for each pnf.
-   * This is maintained by combining the coefficients for pnfs.
-   * that is equal according to pnfLessThan.
-   */
-  typedef set<Node, pnfLessThan> PnfSet;
-  PnfSet combined;
-
-  for(vector<Node>::iterator i=pnfs.begin(); i != pnfs.end(); ++i){
-    Node pnf = *i;
-    PnfSet::iterator pos = combined.find(pnf);
-
-    if(pos == combined.end()){
-      combined.insert(pnf);
-    }else{
-      Node current = *pos;
-      Node sum = addMatchingPnfs(pnf, current);
-      combined.erase(pos);
-      combined.insert(sum);
-    }
-  }
-  pnfs.clear();
-  for(PnfSet::iterator i=combined.begin(); i != combined.end(); ++i){
-    Node pnf = *i;
-    if(pnf[0].getConst<Rational>() != d_constants->d_ZERO){
-      //after combination the coefficient may be zero
-      pnfs.push_back(pnf);
-    }
-  }
-}
-
-Node ArithRewriter::var2pnf(TNode variable){
-  return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_ONE_NODE,variable);
-}
-
-Node ArithRewriter::rewritePlus(TNode t){
-  using namespace std;
-
-  Rational accumulator;
-  vector<Node> pnfs;
-
-  for(TNode::iterator i = t.begin(); i!= t.end(); ++i){
-    TNode child = *i;
-    Node rewrittenChild = rewrite(child);
-
-    if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c
-      Rational c = rewrittenChild.getConst<Rational>();
-      accumulator = accumulator + c;
-    }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v
-      Node pnf = var2pnf(rewrittenChild);
-      pnfs.push_back(pnf);
-    }else{ //(+ c p_1 p_2 ... p_N)
-      Rational c = rewrittenChild[0].getConst<Rational>();
-      accumulator = accumulator + c;
-      int N = rewrittenChild.getNumChildren() - 1;
-      for(int i=1; i<=N; ++i){
-        TNode pnf = rewrittenChild[i];
-        pnfs.push_back(pnf);
-      }
-    }
-  }
-  sortAndCombineCoefficients(pnfs);
-  if(pnfs.size() == 0){
-    return mkRationalNode(accumulator);
-  }
-
-  // pnfs.size() >= 1
-
-  //Enforce not(N=1 and c=0 and p_1.d=1)
-  if(pnfs.size() == 1){
-    Node p_1 = *(pnfs.begin());
-    if(p_1[0].getConst<Rational>() == d_constants->d_ONE){
-      if(accumulator == d_constants->d_ZERO){  // 0 + (* 1 var) |-> var
-        Node var = p_1[1];
-        return var;
-      }
-    }
-  }
-
-  //We must be in this case
-  //(+ c p_1 p_2 ... p_N)  |  not(N=1 and c=0 and p_1.d=1)
-
-  NodeBuilder<> nb(kind::PLUS);
-  nb << mkRationalNode(accumulator);
-  Debug("arithrewrite") << mkRationalNode(accumulator) << std::endl;
-  for(vector<Node>::iterator i = pnfs.begin(); i != pnfs.end(); ++i){
-    nb << *i;
-    Debug("arithrewrite") << (*i) << std::endl;
-
-  }
-
-  Node normalForm = nb;
-  return normalForm;
-}
-
-//Does not enforce
-//5) v_i are of metakind VARIABLE,
-//6) v_i are in increasing (not strict) nodeOrder,
-Node toPnf(Rational& c, std::set<Node>& variables){
-  NodeBuilder<> nb(kind::MULT);
-  nb << mkRationalNode(c);
-
-  for(std::set<Node>::iterator i = variables.begin(); i != variables.end(); ++i){
-    nb << *i;
-  }
-  Node pnf = nb;
-  return pnf;
-}
-
-Node distribute(TNode n, TNode sum){
-  NodeBuilder<> nb(kind::PLUS);
-  for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){
-    Node prod = NodeManager::currentNM()->mkNode(kind::MULT, n, *i);
-    nb << prod;
-  }
-  return nb;
-}
-Node distributeSum(TNode sum, TNode distribSum){
-  NodeBuilder<> nb(kind::PLUS);
-  for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){
-    Node dist = distribute(*i, distribSum);
-    for(Node::iterator j=dist.begin(); j!=dist.end(); ++j){
-      nb << *j;
-    }
-  }
-  return nb;
-}
-
-Node ArithRewriter::rewriteMult(TNode t){
-
-  using namespace std;
-
-  Rational accumulator(1,1);
-  set<Node> variables;
-  vector<Node> sums;
-
-  //These stacks need to be kept in lock step
-  stack<TNode> mult_iterators_nodes;
-  stack<TNode::const_iterator> mult_iterators_iters;
-
-  mult_iterators_nodes.push(t);
-  mult_iterators_iters.push(t.begin());
-
-  while(!mult_iterators_nodes.empty()){
-    TNode mult = mult_iterators_nodes.top();
-    TNode::const_iterator i = mult_iterators_iters.top();
-
-    mult_iterators_nodes.pop();
-    mult_iterators_iters.pop();
-
-    for(; i != mult.end(); ++i){
-      TNode child = *i;
-      if(child.getKind() == kind::MULT){ //TODO add not rewritten already checks
-        ++i;
-        mult_iterators_nodes.push(mult);
-        mult_iterators_iters.push(i);
-
-        mult_iterators_nodes.push(child);
-        mult_iterators_iters.push(child.begin());
-        break;
-      }
-      Node rewrittenChild = rewrite(child);
-
-      if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c
-        Rational c = rewrittenChild.getConst<Rational>();
-        accumulator = accumulator * c;
-        if(accumulator == d_constants->d_ZERO){
-          return d_constants->d_ZERO_NODE;
-        }
-      }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v
-        variables.insert(rewrittenChild);
-      }else{ //(+ c p_1 p_2 ... p_N)
-        sums.push_back(rewrittenChild);
-      }
-    }
-  }
-  // accumulator * (\prod var_i) *(\prod sum_j)
-
-  if(sums.size() == 0){ //accumulator * (\prod var_i)
-    if(variables.size() == 0){ //accumulator
-      return mkRationalNode(accumulator);
-    }else if(variables.size() == 1 && accumulator == d_constants->d_ONE){ // var_1
-      Node var = *(variables.begin());
-      return var;
-    }else{
-      //We need to return (+ c p_1 p_2 ... p_N)
-      //To accomplish this:
-      //  let pnf = pnf(accumulator * (\prod var_i)) in (+ 0 pnf)
-      Node pnf = toPnf(accumulator, variables);
-      Node normalForm = NodeManager::currentNM()->mkNode(kind::PLUS, d_constants->d_ZERO_NODE, pnf);
-      return normalForm;
-    }
-  }else{
-    vector<Node>::iterator sum_iter = sums.begin();
-    // \sum t
-    // t \in Q \cup A
-    // where A = lfp {\prod s | s \in Q \cup Variables \cup A}
-    Node distributed = *sum_iter;
-    ++sum_iter;
-    while(sum_iter != sums.end()){
-      Node curr = *sum_iter;
-      distributed = distributeSum(curr, distributed);
-      ++sum_iter;
-    }
-    if(variables.size() >= 1){
-      Node pnf = toPnf(accumulator, variables);
-      distributed = distribute(pnf, distributed);
-    }else{
-      Node constant = mkRationalNode(accumulator);
-      distributed = distribute(constant, distributed);
-    }
-
-    Node nf_distributed = rewrite(distributed);
-    return nf_distributed;
-  }
-}
-
-Node ArithRewriter::rewriteDivByConstant(TNode t){
-  Assert(t.getKind()== kind::DIVISION);
-
-  Node left = t[0];
-  Node reRight = rewrite(t[1]);
-  Assert(reRight.getKind()== kind::CONST_RATIONAL);
-
-
-  Rational den = reRight.getConst<Rational>();
-
-  Assert(den != d_constants->d_ZERO);
-
-  Rational div = den.inverse();
-
-  Node result = mkRationalNode(div);
-
-  Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
-
-  Node reMult = rewrite(mult);
-
-  return reMult;
-}
-
-Node ArithRewriter::rewriteTerm(TNode t){
-  if(t.getMetaKind() == kind::metakind::CONSTANT){
-    return coerceToRationalNode(t);
-  }else if(t.getMetaKind() == kind::metakind::VARIABLE){
-    return t;
-  }else if(t.getKind() == kind::MULT){
-    return rewriteMult(t);
-  }else if(t.getKind() == kind::PLUS){
-    return rewritePlus(t);
-  }else if(t.getKind() == kind::DIVISION){
-    return rewriteDivByConstant(t);
-  }else if(t.getKind() == kind::MINUS){
-    Node sub = makeSubtractionNode(t[0],t[1]);
-    return rewrite(sub);
-  }else if(t.getKind() == kind::UMINUS){
-    Node sub = makeUnaryMinusNode(t[0]);
-    return rewrite(sub);
-  }else{
-    Unhandled(t);
-  }
-}
-
-
-/**
- * Given a node in PNF pnf = (* d p_1 p_2 .. p_M) and a rational q != 0
- * constuct a node equal to q * pnf that is in pnf.
- *
- * The claim is that this is always okay:
- * If d' = q*d, p' = (* d' p_1 p_2 .. p_M) =_{Reals} q * pnf.
- */
-Node ArithRewriter::multPnfByNonZero(TNode pnf, Rational& q){
-  Assert(q != d_constants->d_ZERO);
-  //TODO Assert(isPNF(pnf) );
-
-  int M = pnf.getNumChildren()-1;
-  Rational d = pnf[0].getConst<Rational>();
-  Rational new_d = d*q;
-
-
-  NodeBuilder<> mult(kind::MULT);
-  mult << mkRationalNode(new_d);
-  for(int i=1; i<=M; ++i){
-    mult << pnf[i];
-  }
-
-  Node result = mult;
-  return result;
-}
-
-Node ArithRewriter::makeUnaryMinusNode(TNode n){
-  Node tmp = NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_NEGATIVE_ONE_NODE,n);
-  return tmp;
-}
-
-Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
-  Node negR = makeUnaryMinusNode(r);
-  Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
-
-  return diff;
-}
-
-
-Kind multKind(Kind k, int sgn){
-  using namespace kind;
-
-  if(sgn < 0){
-
-    switch(k){
-    case LT: return GT;
-    case LEQ: return GEQ;
-    case EQUAL: return EQUAL;
-    case GEQ: return LEQ;
-    case GT: return LT;
-    default:
-      Unhandled(k);
-    }
-    return NULL_EXPR;
-  }else{
-    return k;
-  }
-}
-
-Node ArithRewriter::rewrite(TNode n){
-  Debug("arithrewriter") << "Trace rewrite:" << n << std::endl;
-
-  Node res;
-
-  if(isRelationOperator(n.getKind())){
-    res = rewriteAtom(n);
-  }else{
-    res = rewriteTerm(n);
-  }
-
-  Debug("arithrewriter") << "Trace rewrite:" << n << "|->"<< res << std::endl;
-
-  return res;
-}
diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h
deleted file mode 100644 (file)
index a76ee6e..0000000
+++ /dev/null
@@ -1,123 +0,0 @@
-/*********************                                                        */
-/*! \file arith_rewriter.h
- ** \verbatim
- ** Original author: taking
- ** Major contributors: mdeters
- ** Minor contributors (to current version): none
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010  The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
- **
- ** \brief [[ Add one-line brief description here ]]
- **
- ** [[ Add lengthier description here ]]
- ** \todo document this file
- **/
-
-
-#include "expr/node.h"
-#include "util/rational.h"
-#include "theory/arith/arith_constants.h"
-
-#ifndef __CVC4__THEORY__ARITH__REWRITER_H
-#define __CVC4__THEORY__ARITH__REWRITER_H
-
-namespace CVC4 {
-namespace theory {
-namespace arith {
-
-
-/***********************************************/
-/***************** Normal Form *****************/
-/***********************************************/
-/***********************************************/
-
-/**
- * Normal form for predicates:
- *    TRUE
- *    FALSE
- *    v |><| b
- *    p |><| b
- *    (+ p_1 .. p_N)  |><| b
- *  where
- *   1) b is of type CONST_RATIONAL
- *   2) |><| is of kind <, <=, =, >= or >
- *   3) p, p_i is in PNF,
- *   4) p.M >= 2
- *   5) p_i's are in strictly ascending <p,
- *   6) N >= 2,
- *   7) the kind of (+ p_1 .. p_N) is an N arity PLUS,
- *   8) p.d, p_1.d are 1,
- *   9) v has metakind variable, and
- *
- * PNF(t):
- *    (* d v_1 v_2 ... v_M)
- *  where
- *   1) d is of type CONST_RATIONAL,
- *   2) d != 0,
- *   4) M>=1,
- *   5) v_i are of metakind VARIABLE,
- *   6) v_i are in increasing (not strict) nodeOrder, and
- *   7) the kind of t is an M+1 arity MULT.
- *
- * <p is defined over PNF as follows (skipping some symmetry):
- *   cmp( (* d v_1 v_2 ... v_M), (* d' v'_1 v'_2 ... v'_M'):
- *      if(M == M'):
- *      then tupleCompare(v_i, v'_i)
- *      else M -M'
- *
- * Rewrite Normal Form for Terms:
- *    b
- *    v
- *    (+ c p_1 p_2 ... p_N)  |  not(N=1 and c=0 and p_1.d=1)
- *  where
- *   1) b,c is of type CONST_RATIONAL,
- *   3) p_i is in PNF,
- *   4) N >= 1
- *   5) the kind of (+ c p_1 p_2 ... p_N) is an N+1 arity PLUS,
- *   6) and p_i's are in strictly <p.
- *
- */
-
-class ArithRewriter{
-private:
-  ArithConstants* d_constants;
-
-  //This is where the core of the work is done for rewriteAtom
-  //With a few additional checks done by rewriteAtom
-  Node rewriteAtomCore(TNode atom);
-  Node rewriteAtom(TNode atom);
-
-  Node rewriteTerm(TNode t);
-  Node rewriteMult(TNode t);
-  Node rewritePlus(TNode t);
-  Node rewriteMinus(TNode t);
-  Node makeSubtractionNode(TNode l, TNode r);
-  Node makeUnaryMinusNode(TNode n);
-
-
-  Node var2pnf(TNode variable);
-
-  Node multPnfByNonZero(TNode pnf, Rational& q);
-
-  Node rewriteDivByConstant(TNode t);
-  void sortAndCombineCoefficients(std::vector<Node>& pnfs);
-
-
-public:
-  ArithRewriter(ArithConstants* ac) :
-    d_constants(ac)
-  {}
-  Node rewrite(TNode t);
-
-};
-
-
-}; /* namesapce arith */
-}; /* namespace theory */
-}; /* namespace CVC4 */
-
-#endif /* __CVC4__THEORY__ARITH__REWRITER_H */
index fa3356c6075f1f346eeb37d35437c8435ef2888f..6706ad76a02d4360d8d636f79f5bf39cf558338a 100644 (file)
@@ -27,7 +27,7 @@ namespace CVC4 {
 namespace theory {
 namespace arith {
 
-inline Node mkRationalNode(Rational& q){
+inline Node mkRationalNode(const Rational& q){
   return NodeManager::currentNM()->mkConst<Rational>(q);
 }
 
@@ -87,6 +87,21 @@ inline bool isRelationOperator(Kind k){
   }
 }
 
+/** is k \in {LT, LEQ, EQ, GEQ, GT} */
+inline Kind negateRelationKind(Kind k){
+  using namespace kind;
+
+  switch(k){
+  case LT: return GT;
+  case LEQ: return GEQ;
+  case EQUAL: return EQUAL;
+  case GEQ: return LEQ;
+  case GT: return LT;
+
+  default:
+    Unreachable();
+  }
+}
 inline bool evaluateConstantPredicate(Kind k, const Rational& left, const Rational& right){
   using namespace kind;
 
index 99f7258da07d8194b1cd510c252f4ebc8717e0fb..07d48b1f6bca527d3d2b8672d425b08c11bea032 100644 (file)
@@ -12,6 +12,8 @@ operator MINUS 2 "arithmetic binary subtraction operator"
 operator UMINUS 1 "arithmetic unary negation"
 operator DIVISION 2 "arithmetic division"
 
+operator IDENTITY 1 "identity function"
+
 constant CONST_RATIONAL \
     ::CVC4::Rational \
     ::CVC4::RationalHashStrategy \
diff --git a/src/theory/arith/next_arith_rewriter.cpp b/src/theory/arith/next_arith_rewriter.cpp
new file mode 100644 (file)
index 0000000..c14f806
--- /dev/null
@@ -0,0 +1,326 @@
+/*********************                                                        */
+/*! \file arith_rewriter.cpp
+ ** \verbatim
+ ** Original author: taking
+ ** Major contributors: none
+ ** Minor contributors (to current version): mdeters
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009, 2010  The Analysis of Computer Systems Group (ACSys)
+ ** Courant Institute of Mathematical Sciences
+ ** New York University
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief [[ Add one-line brief description here ]]
+ **
+ ** [[ Add lengthier description here ]]
+ ** \todo document this file
+ **/
+
+
+#include "theory/theory.h"
+#include "theory/arith/normal_form.h"
+#include "theory/arith/next_arith_rewriter.h"
+#include "theory/arith/arith_utilities.h"
+
+#include <vector>
+#include <set>
+#include <stack>
+
+
+using namespace CVC4;
+using namespace CVC4::theory;
+using namespace CVC4::theory::arith;
+
+bool isVariable(TNode t){
+  return t.getMetaKind() == kind::metakind::VARIABLE;
+}
+
+RewriteResponse NextArithRewriter::rewriteConstant(TNode t){
+  Assert(t.getMetaKind() == kind::metakind::CONSTANT);
+  Node val = coerceToRationalNode(t);
+
+  return RewriteComplete(val);
+}
+
+RewriteResponse NextArithRewriter::rewriteVariable(TNode t){
+  Assert(isVariable(t));
+
+  return RewriteComplete(t);
+}
+
+RewriteResponse NextArithRewriter::rewriteMinus(TNode t, bool pre){
+  Assert(t.getKind()== kind::MINUS);
+
+  if(t[0] == t[1]) return RewriteComplete(d_constants->d_ZERO_NODE);
+
+  Node noMinus = makeSubtractionNode(t[0],t[1]);
+  if(pre){
+    return RewriteComplete(noMinus);
+  }else{
+    return FullRewriteNeeded(noMinus);
+  }
+}
+
+RewriteResponse NextArithRewriter::rewriteUMinus(TNode t, bool pre){
+  Assert(t.getKind()== kind::UMINUS);
+
+  Node noUminus = makeUnaryMinusNode(t[0]);
+  if(pre)
+    return RewriteComplete(noUminus);
+  else
+    return RewriteAgain(noUminus);
+}
+
+RewriteResponse NextArithRewriter::preRewriteTerm(TNode t){
+  if(t.getMetaKind() == kind::metakind::CONSTANT){
+    return rewriteConstant(t);
+  }else if(isVariable(t)){
+    return rewriteVariable(t);
+  }else if(t.getKind() == kind::MINUS){
+    return rewriteMinus(t, true);
+  }else if(t.getKind() == kind::UMINUS){
+    return rewriteUMinus(t, true);
+  }else if(t.getKind() == kind::DIVISION){
+    if(t[0].getKind()== kind::CONST_RATIONAL){
+      return rewriteDivByConstant(t, true);
+    }else{
+      return RewriteComplete(t);
+    }
+  }else if(t.getKind() == kind::PLUS){
+    return preRewritePlus(t);
+  }else if(t.getKind() == kind::MULT){
+    return preRewriteMult(t);
+  }else{
+    Unreachable();
+  }
+}
+RewriteResponse NextArithRewriter::postRewriteTerm(TNode t){
+  if(t.getMetaKind() == kind::metakind::CONSTANT){
+    return rewriteConstant(t);
+  }else if(isVariable(t)){
+    return rewriteVariable(t);
+  }else if(t.getKind() == kind::MINUS){
+    return rewriteMinus(t, false);
+  }else if(t.getKind() == kind::UMINUS){
+    return rewriteUMinus(t, false);
+  }else if(t.getKind() == kind::DIVISION){
+    return rewriteDivByConstant(t, false);
+  }else if(t.getKind() == kind::PLUS){
+    return postRewritePlus(t);
+  }else if(t.getKind() == kind::MULT){
+    return postRewriteMult(t);
+  }else{
+    Unreachable();
+  }
+}
+
+RewriteResponse NextArithRewriter::preRewriteMult(TNode t){
+  Assert(t.getKind()== kind::MULT);
+
+  // Rewrite multiplications with a 0 argument and to 0
+  Integer intZero;
+
+  for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
+    if((*i).getKind() == kind::CONST_RATIONAL) {
+      if((*i).getConst<Rational>() == d_constants->d_ZERO) {
+        return RewriteComplete(d_constants->d_ZERO_NODE);
+      }
+    } else if((*i).getKind() == kind::CONST_INTEGER) {
+      if((*i).getConst<Integer>() == intZero) {
+        if(t.getType().isInteger()) {
+          return RewriteComplete(NodeManager::currentNM()->mkConst(intZero));
+        } else {
+          return RewriteComplete(d_constants->d_ZERO_NODE);
+        }
+      }
+    }
+  }
+  return RewriteComplete(t);
+}
+RewriteResponse NextArithRewriter::preRewritePlus(TNode t){
+  Assert(t.getKind()== kind::PLUS);
+
+  return RewriteComplete(t);
+}
+
+RewriteResponse NextArithRewriter::postRewritePlus(TNode t){
+  Assert(t.getKind()== kind::PLUS);
+
+  Polynomial res = Polynomial::mkZero();
+
+  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
+    Node curr = *i;
+    Polynomial currPoly = Polynomial::parsePolynomial(curr);
+
+    res = res + currPoly;
+  }
+
+  return RewriteComplete(res.getNode());
+}
+
+RewriteResponse NextArithRewriter::postRewriteMult(TNode t){
+  Assert(t.getKind()== kind::MULT);
+
+  Polynomial res = Polynomial::mkOne();
+
+  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
+    Node curr = *i;
+    Polynomial currPoly = Polynomial::parsePolynomial(curr);
+
+    res = res * currPoly;
+  }
+
+  return RewriteComplete(res.getNode());
+}
+
+RewriteResponse NextArithRewriter::postRewriteAtomConstantRHS(TNode t){
+  TNode left  = t[0];
+  TNode right = t[1];
+
+
+  Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right));
+
+  if(cmp.isBoolean()){
+    return RewriteComplete(cmp.getNode());
+  }
+
+  if(cmp.getLeft().containsConstant()){
+    Monomial constantHead = cmp.getLeft().getHead();
+    Assert(constantHead.isConstant());
+
+    Constant constant = constantHead.getConstant();
+
+    Constant negativeConstantHead = -constant;
+
+    cmp = cmp.addConstant(negativeConstantHead);
+  }
+  Assert(!cmp.getLeft().containsConstant());
+
+  if(!cmp.getLeft().getHead().coefficientIsOne()){
+    Monomial constantHead = cmp.getLeft().getHead();
+    Assert(!constantHead.isConstant());
+    Constant constant = constantHead.getConstant();
+
+    Constant inverse = Constant::mkConstant(constant.getValue().inverse());
+
+    cmp = cmp.multiplyConstant(inverse);
+  }
+  Assert(cmp.getLeft().getHead().coefficientIsOne());
+
+  Assert(cmp.isBoolean() || cmp.isNormalForm());
+  return RewriteComplete(cmp.getNode());
+}
+
+RewriteResponse NextArithRewriter::postRewriteAtom(TNode atom){
+  // left |><| right
+  TNode left = atom[0];
+  TNode right = atom[1];
+
+  if(right.getMetaKind() == kind::metakind::CONSTANT){
+    return postRewriteAtomConstantRHS(atom);
+  }else{
+    //Transform this to: (left - right) |><| 0
+    Node diff = makeSubtractionNode(left, right);
+    Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, d_constants->d_ZERO_NODE);
+    return FullRewriteNeeded(reduction);
+  }
+}
+
+RewriteResponse NextArithRewriter::preRewriteAtom(TNode atom){
+  Assert(isAtom(atom));
+  NodeManager* currNM = NodeManager::currentNM();
+
+  if(atom.getKind() == kind::EQUAL) {
+    if(atom[0] == atom[1]) {
+      return RewriteComplete(currNM->mkConst(true));
+    }
+  }
+
+  Node reduction = atom;
+
+  if(atom[1].getMetaKind() != kind::metakind::CONSTANT){
+    // left |><| right
+    TNode left = atom[0];
+    TNode right = atom[1];
+
+    //Transform this to: (left - right) |><| 0
+    Node diff = makeSubtractionNode(left, right);
+    reduction = currNM->mkNode(atom.getKind(), diff, d_constants->d_ZERO_NODE);
+  }
+
+  if(reduction.getKind() == kind::GT){
+    Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]);
+    reduction = currNM->mkNode(kind::NOT, leq);
+  }else if(reduction.getKind() == kind::LT){
+    Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]);
+    reduction = currNM->mkNode(kind::NOT, geq);
+  }
+
+  return RewriteComplete(reduction);
+}
+
+RewriteResponse NextArithRewriter::postRewrite(TNode t){
+  if(isTerm(t)){
+    RewriteResponse response = postRewriteTerm(t);
+    if(Debug.isOn("arith::rewriter") && response.isDone()) {
+      Polynomial::parsePolynomial(response.getNode());
+    }
+    return response;
+  }else if(isAtom(t)){
+    RewriteResponse response = postRewriteAtom(t);
+    if(Debug.isOn("arith::rewriter") && response.isDone()) {
+      Comparison::parseNormalForm(response.getNode());
+    }
+    return response;
+  }else{
+    Unreachable();
+    return RewriteComplete(Node::null());
+  }
+}
+
+RewriteResponse NextArithRewriter::preRewrite(TNode t){
+  if(isTerm(t)){
+    return preRewriteTerm(t);
+  }else if(isAtom(t)){
+    return preRewriteAtom(t);
+  }else{
+    Unreachable();
+    return RewriteComplete(Node::null());
+  }
+}
+
+Node NextArithRewriter::makeUnaryMinusNode(TNode n){
+  return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_NEGATIVE_ONE_NODE,n);
+}
+
+Node NextArithRewriter::makeSubtractionNode(TNode l, TNode r){
+  Node negR = makeUnaryMinusNode(r);
+  Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
+
+  return diff;
+}
+
+RewriteResponse NextArithRewriter::rewriteDivByConstant(TNode t, bool pre){
+  Assert(t.getKind()== kind::DIVISION);
+
+  Node left = t[0];
+  Node right = t[1];
+  Assert(right.getKind()== kind::CONST_RATIONAL);
+
+
+  const Rational& den = right.getConst<Rational>();
+
+  Assert(den != d_constants->d_ZERO);
+
+  Rational div = den.inverse();
+
+  Node result = mkRationalNode(div);
+
+  Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
+  if(pre){
+    return RewriteComplete(mult);
+  }else{
+    return RewriteAgain(mult);
+  }
+}
diff --git a/src/theory/arith/next_arith_rewriter.h b/src/theory/arith/next_arith_rewriter.h
new file mode 100644 (file)
index 0000000..7f1ec0f
--- /dev/null
@@ -0,0 +1,74 @@
+/*********************                                                        */
+/*! \file arith_rewriter.h
+ ** \verbatim
+ ** Original author: taking
+ ** Major contributors: mdeters
+ ** Minor contributors (to current version): none
+ ** This file is part of the CVC4 prototype.
+ ** Copyright (c) 2009, 2010  The Analysis of Computer Systems Group (ACSys)
+ ** Courant Institute of Mathematical Sciences
+ ** New York University
+ ** See the file COPYING in the top-level source directory for licensing
+ ** information.\endverbatim
+ **
+ ** \brief [[ Add one-line brief description here ]]
+ **
+ ** [[ Add lengthier description here ]]
+ ** \todo document this file
+ **/
+
+#include "theory/arith/arith_constants.h"
+#include "theory/theory.h"
+#include "theory/arith/normal_form.h"
+
+#ifndef __CVC4__THEORY__ARITH__REWRITER_NEXT_H
+#define __CVC4__THEORY__ARITH__REWRITER_NEXT_H
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+
+class NextArithRewriter{
+private:
+  ArithConstants* d_constants;
+
+  Node makeSubtractionNode(TNode l, TNode r);
+  Node makeUnaryMinusNode(TNode n);
+
+  RewriteResponse preRewriteTerm(TNode t);
+  RewriteResponse postRewriteTerm(TNode t);
+
+  RewriteResponse rewriteVariable(TNode t);
+  RewriteResponse rewriteConstant(TNode t);
+  RewriteResponse rewriteMinus(TNode t, bool pre);
+  RewriteResponse rewriteUMinus(TNode t, bool pre);
+  RewriteResponse rewriteDivByConstant(TNode t, bool pre);
+
+  RewriteResponse preRewritePlus(TNode t);
+  RewriteResponse postRewritePlus(TNode t);
+
+  RewriteResponse preRewriteMult(TNode t);
+  RewriteResponse postRewriteMult(TNode t);
+
+
+  RewriteResponse preRewriteAtom(TNode t);
+  RewriteResponse postRewriteAtom(TNode t);
+  RewriteResponse postRewriteAtomConstantRHS(TNode t);
+
+public:
+  NextArithRewriter(ArithConstants* ac) : d_constants(ac) {}
+
+  RewriteResponse preRewrite(TNode n);
+  RewriteResponse postRewrite(TNode n);
+
+private:
+  bool isAtom(TNode n) const { return isRelationOperator(n.getKind()); }
+  bool isTerm(TNode n) const { return !isAtom(n); }
+};
+
+
+}; /* namesapce arith */
+}; /* namespace theory */
+}; /* namespace CVC4 */
+
+#endif /* __CVC4__THEORY__ARITH__REWRITER_NEXT_H */
diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp
new file mode 100644 (file)
index 0000000..18e3184
--- /dev/null
@@ -0,0 +1,250 @@
+
+#include "theory/arith/normal_form.h"
+#include <list>
+
+using namespace std;
+using namespace CVC4;
+using namespace CVC4::theory;
+using namespace CVC4::theory::arith;
+
+bool VarList::isSorted(iterator start, iterator end){
+  return __gnu_cxx::is_sorted(start, end);
+}
+
+bool VarList::isMember(Node n){
+  if(n.getNumChildren() == 0){
+    return Variable::isMember(n);
+  }else if(n.getKind() == kind::MULT){
+    Node::iterator curr = n.begin(), end = n.end();
+    Node prev = *curr;
+    if(!Variable::isMember(prev)) return false;
+
+    while( (++curr) != end){
+      if(!Variable::isMember(*curr)) return false;
+      if(!(prev <= *curr)) return false;
+      prev = *curr;
+    }
+    return true;
+  }else{
+    return false;
+  }
+}
+int VarList::cmp(const VarList& vl) const{
+  int dif = this->size() - vl.size();
+  if (dif == 0){
+    return this->getNode().getId() - vl.getNode().getId();
+  }else if(dif < 0){
+    return -1;
+  }else{
+    return 1;
+  }
+}
+
+VarList VarList::parseVarList(Node n){
+  if(n.getNumChildren() == 0){
+    return VarList(Variable(n));
+  }else{
+    Assert(n.getKind() == kind::MULT);
+    for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i){
+      Assert(Variable::isMember(*i));
+    }
+    return VarList(n);
+  }
+}
+
+VarList VarList::operator*(const VarList& vl) const{
+  if(this->empty()){
+    return vl;
+  }else if(vl.empty()){
+    return *this;
+  }else{
+    vector<Variable> result;
+    vector<Variable> thisAsVec = this->asList();
+    vector<Variable> vlAsVec = vl.asList();
+    back_insert_iterator<vector<Variable> > bii(result);
+
+    merge(thisAsVec.begin(), thisAsVec.end(), vlAsVec.begin(), vlAsVec.end(), bii);
+
+    return VarList::mkVarList(result);
+  }
+}
+
+std::vector<Variable> VarList::asList() const {
+  vector<Variable> res;
+  for(iterator i = begin(), e = end(); i != e; ++i){
+    res.push_back(*i);
+  }
+  return res;
+}
+
+Monomial Monomial::mkMonomial(const Constant& c, const VarList& vl){
+  if(c.isZero() || vl.empty() ){
+    return Monomial(c);
+  }else if(c.isOne()){
+    return Monomial(vl);
+  }else{
+    return Monomial(c, vl);
+  }
+}
+Monomial Monomial::parseMonomial(Node n){
+  if(n.getKind() == kind::CONST_RATIONAL){
+    return Monomial(Constant(n));
+  }else if(multStructured(n)){
+    return Monomial::mkMonomial(Constant(n[0]),VarList::parseVarList(n[1]));
+  }else{
+    return Monomial(VarList::parseVarList(n));
+  }
+}
+
+Monomial Monomial::operator*(const Monomial& mono) const {
+  Constant newConstant = this->getConstant() * mono.getConstant();
+  VarList newVL = this->getVarList() * mono.getVarList();
+
+  return Monomial::mkMonomial(newConstant, newVL);
+}
+
+vector<Monomial> Monomial::sumLikeTerms(const vector<Monomial> & monos){
+  Assert(isSorted(monos));
+
+  Debug("blah") << "start sumLikeTerms" << std::endl;
+  printList(monos);
+  vector<Monomial> outMonomials;
+  typedef vector<Monomial>::const_iterator iterator;
+  for(iterator rangeIter = monos.begin(), end=monos.end(); rangeIter != end;){
+    Rational constant = (*rangeIter).getConstant().getValue();
+    VarList varList  = (*rangeIter).getVarList();
+    ++rangeIter;
+    while(rangeIter != end && varList == (*rangeIter).getVarList()){
+      constant += (*rangeIter).getConstant().getValue();
+      ++rangeIter;
+    }
+    if(constant != 0){
+      Constant asConstant = Constant::mkConstant(constant);
+      Monomial nonZero = Monomial::mkMonomial(asConstant, varList);
+      outMonomials.push_back(nonZero);
+    }
+  }
+  Debug("blah") << "outmonomials" << std::endl;
+  printList(monos);
+  Debug("blah") << "end sumLikeTerms" << std::endl;
+
+  Assert(isStrictlySorted(outMonomials));
+  return outMonomials;
+}
+
+void Monomial::printList(const std::vector<Monomial>& monos){
+  typedef std::vector<Monomial>::const_iterator iterator;
+  for(iterator i = monos.begin(), end = monos.end(); i != end; ++i){
+    Debug("blah") <<  ((*i).getNode()) << std::endl;
+  }
+}
+
+Polynomial Polynomial::operator+(const Polynomial& vl) const{
+  this->printList();
+  vl.printList();
+
+  std::vector<Monomial> sortedMonos;
+  std::back_insert_iterator<std::vector<Monomial> > bii(sortedMonos);
+  std::merge(begin(), end(), vl.begin(), vl.end(), bii);
+
+  std::vector<Monomial> combined = Monomial::sumLikeTerms(sortedMonos);
+
+  Polynomial result = mkPolynomial(combined);
+  result.printList();
+  return result;
+}
+
+Polynomial Polynomial::operator*(const Monomial& mono) const{
+  if(mono.isZero()){
+    return Polynomial(mono); //Don't multiply by zero
+  }else{
+    std::vector<Monomial> newMonos;
+    for(iterator i = this->begin(), end = this->end(); i != end; ++i){
+      newMonos.push_back(mono * (*i));
+    }
+    return Polynomial::mkPolynomial(newMonos);
+  }
+}
+
+Polynomial Polynomial::operator*(const Polynomial& poly) const{
+
+  Polynomial res = Polynomial::mkZero();
+  for(iterator i = this->begin(), end = this->end(); i != end; ++i){
+    Monomial curr = *i;
+    Polynomial prod = poly * curr;
+    Polynomial sum  = res + prod;
+    res = sum;
+  }
+  return res;
+}
+
+
+Node Comparison::toNode(Kind k, const Polynomial& l, const Constant& r){
+  Assert(!l.isConstant());
+  Assert(isRelationOperator(k));
+  switch(k){
+  case kind::GEQ:
+  case kind::EQUAL:
+  case kind::LEQ:
+    return NodeManager::currentNM()->mkNode(k, l.getNode(),r.getNode());
+  case kind::LT:
+    return NodeManager::currentNM()->mkNode(kind::NOT, toNode(kind::GEQ,l,r));
+  case kind::GT:
+    return NodeManager::currentNM()->mkNode(kind::NOT, toNode(kind::LEQ,l,r));
+  default:
+    Unreachable();
+  }
+}
+
+Comparison Comparison::parseNormalForm(TNode n){
+  if(n.getKind() == kind::CONST_BOOLEAN){
+    return Comparison(n.getConst<bool>());
+  }else{
+    bool negated = n.getKind() == kind::NOT;
+    Node relation = negated ? n[0] : n;
+    Assert( !negated ||
+            relation.getKind() == kind::LEQ ||
+            relation.getKind() == kind::GEQ);
+
+    Polynomial left = Polynomial::parsePolynomial(relation[0]);
+    Constant right(relation[1]);
+
+    Kind newOperator = relation.getKind();
+    if(negated){
+      if(newOperator == kind::LEQ){
+        newOperator = kind::GT;
+      }else{
+        newOperator = kind::LT;
+      }
+    }
+    return Comparison(n, newOperator, left, right);
+  }
+}
+
+Comparison Comparison::mkComparison(Kind k, const Polynomial& left, const Constant& right){
+  Assert(isRelationOperator(k));
+  if(left.isConstant()){
+    const Rational& rConst =  left.getNode().getConst<Rational>();
+    const Rational& lConst = right.getNode().getConst<Rational>();
+    bool res = evaluateConstantPredicate(k, lConst, rConst);
+    return Comparison(res);
+  }else{
+    return Comparison(toNode(k, left, right), k, left, right);
+  }
+}
+
+Comparison Comparison::addConstant(const Constant& constant) const{
+  Assert(!isBoolean());
+  Monomial mono(constant);
+  Polynomial constAsPoly( mono );
+  Polynomial newLeft =  getLeft() + constAsPoly;
+  Constant newRight = getRight() + constant;
+  return mkComparison(oper, newLeft, newRight);
+}
+
+Comparison Comparison::multiplyConstant(const Constant& constant) const{
+  Assert(!isBoolean());
+  Kind newOper = (constant.getValue() < 0) ? negateRelationKind(oper) : oper;
+
+  return mkComparison(newOper, left*Monomial(constant), right*constant);
+}
diff --git a/src/theory/arith/normal_form.h b/src/theory/arith/normal_form.h
new file mode 100644 (file)
index 0000000..1f7bc6b
--- /dev/null
@@ -0,0 +1,613 @@
+
+#include "expr/node.h"
+#include "util/rational.h"
+#include "theory/arith/arith_constants.h"
+#include "theory/arith/arith_utilities.h"
+
+#include <list>
+#include <algorithm>
+#include <ext/algorithm>
+
+#ifndef __CVC4__THEORY__ARITH__NORMAL_FORM_H
+#define __CVC4__THEORY__ARITH__NORMAL_FORM_H
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+
+/***********************************************/
+/***************** Normal Form *****************/
+/***********************************************/
+/***********************************************/
+
+/**
+ * Section 1: Languages
+ * The normal form for arithmetic nodes is defined by the language
+ * accepted by the following BNFs with some guard conditions.
+ * (The guard conditions are in Section 3 for completeness.)
+ *
+ * variable := n
+ *   where
+ *     n.getMetaKind() == metakind::VARIABLE
+
+ * constant := n
+ *   where
+ *     n.getKind() == kind::CONST_RATIONAL
+
+ * var_list := variable | (* [variable])
+ *   where
+ *     len [variable] >= 2
+ *     isSorted varOrder [variable]
+
+ * monomial := constant | var_list | (* constant' var_list')
+ *   where
+ *     constant' \not\in {0,1}
+
+ * polynomial := monomial | (+ [monomial])
+ *   where
+ *     len [monomial] >= 2
+ *     isStrictlySorted monoOrder [monomial]
+
+ * restricted_cmp := (|><| polynomial constant)
+ *   where
+ *     |><| is GEQ, EQ, or EQ
+ *     not (exists constantMonomial (monomialList polynomial))
+ *     monomialCoefficient (head (monomialList polynomial)) == 1
+
+ * comparison := TRUE | FALSE | restricted_cmp | (not restricted_cmp)
+
+ * Normal Form for terms := polynomial
+ * Normal Form for atoms := comparison
+ */
+
+/**
+ * Section 2: Helper Classes
+ * The langauges accepted by each of these defintions
+ * roughly corresponds to one of the following helper classes:
+ *  Variable
+ *  Constant
+ *  VarList
+ *  Monomial
+ *  Polynomial
+ *  Comparison
+ *
+ * Each of the classes obeys the following contracts/design decisions:
+ * -Calling isMember(Node node) on a node returns true iff that node is a
+ *  a member of the language. Note: isMember is O(n).
+ * -Calling isNormalForm() on a helper class object returns true iff that
+ *  helper class currently represents a normal form object.
+ * -If isNormalForm() is false, then this object must have been made
+ *  using a mk*() factory function.
+ * -If isNormalForm() is true, calling getNode() on all of these classes
+ *  returns a node that would be accepted by the corresponding language.
+ *  And if isNormalForm() is false, returns Node::null().
+ * -Each of the classes is immutable.
+ * -Public facing constuctors have a 1-to-1 correspondence with one of
+ *  production rules in the above grammar.
+ * -Public facing constuctors are required to fail in debug mode when the
+ *  guards of the production rule are not strictly met.
+ *  For example: Monomial(Constant(1),VarList(Variable(x))) must fail.
+ * -When a class has a Class parseClass(Node node) function,
+ *  if isMember(node) is true, the function is required to return an instance
+ *  of the helper class, instance, s.t. instance.getNode() == node.
+ *  And if isMember(node) is false, this throws an assertion failure in debug
+ *  mode and has undefined behaviour if not in debug mode.
+ * -Only public facing constructors, parseClass(node), and mk*() functions are
+ *  considered privledged functions for the helper class.
+ * -Only privledged functions may use private constructors, and access
+ *  private data members.
+ * -All non-privledges functions are considered utility functions and
+ *  must use a privledged function in order to create an instance of the class.
+ */
+
+/**
+ * Section 3: Guard Conditions Misc.
+ *
+ *
+ *  var_list_len vl =
+ *    match vl with
+ *       variable -> 1
+ *     | (* [variable]) -> len [variable]
+ *
+ *  order res =
+ *    match res with
+ *       Empty -> (0,Node::null())
+ *     | NonEmpty(vl) -> (var_list_len vl, vl)
+ *
+ *  var_listOrder a b = tuple_cmp (order a) (order b)
+ *
+ *  monomialVarList monomial =
+ *    match monomial with
+ *        constant -> Empty
+ *      | var_list -> NonEmpty(var_list)
+ *      | (* constant' var_list') -> NonEmpty(var_list')
+ *
+ *  monoOrder m0 m1 = var_listOrder (monomialVarList m0) (monomialVarList m1)
+ *
+ *  constantMonomial monomial =
+ *    match monomial with
+ *        constant -> true
+ *      | var_list -> false
+ *      | (* constant' var_list') -> false
+ *
+ *  monomialCoefficient monomial =
+ *    match monomial with
+ *        constant -> constant
+ *      | var_list -> Constant(1)
+ *      | (* constant' var_list') -> constant'
+ *
+ *  monomialList polynomial =
+ *    match polynomial with
+ *        monomial -> monomial::[]
+ *      | (+ [monomial]) -> [monomial]
+ */
+
+/**
+ * A NodeWrapper is a class that is a thinly veiled container of a Node object.
+ */
+class NodeWrapper {
+private:
+  Node node;
+public:
+  NodeWrapper(Node n) : node(n) {}
+  const Node& getNode() const { return node; }
+};
+
+class Variable : public NodeWrapper {
+public:
+  Variable(Node n) : NodeWrapper(n) {
+    Assert(isMember(getNode()));
+  }
+
+  static bool isMember(Node n) {
+    return n.getMetaKind() == kind::metakind::VARIABLE;
+  }
+
+  bool isNormalForm() { return isMember(getNode()); }
+
+  bool operator<(const Variable& v) const{ return getNode() < v.getNode();}
+  bool operator==(const Variable& v) const{ return getNode() == v.getNode();}
+
+};
+
+class Constant : public NodeWrapper {
+public:
+  Constant(Node n) : NodeWrapper(n) {
+    Assert(isMember(getNode()));
+  }
+
+  static bool isMember(Node n) {
+    return n.getKind() == kind::CONST_RATIONAL;
+  }
+
+  bool isNormalForm() { return isMember(getNode()); }
+
+  static Constant mkConstant(Node n) {
+    return Constant(coerceToRationalNode(n));
+  }
+
+  static Constant mkConstant(const Rational& rat){
+    return Constant(mkRationalNode(rat));
+  }
+
+  const Rational& getValue() const {
+    return getNode().getConst<Rational>();
+  }
+
+  bool isZero() const{ return getValue() == 0; }
+  bool isOne() const{ return getValue() == 1; }
+
+  Constant operator*(const Constant& other) const{
+    return mkConstant(getValue() * other.getValue());
+  }
+  Constant operator+(const Constant& other) const{
+    return mkConstant(getValue() + other.getValue());
+  }
+  Constant operator-() const{
+    return mkConstant(-getValue());
+  }
+};
+
+template <class GetNodeIterator>
+inline Node makeNode(Kind k, GetNodeIterator start, GetNodeIterator end){
+  NodeBuilder<> nb(k);
+
+  while(start != end){
+    nb << (*start).getNode();
+    ++start;
+  }
+  return Node(nb);
+}
+
+/**
+ * A VarList is a sorted list of variables representing a product.
+ * If the VarList is empty, it represents an empty product or 1.
+ * If the VarList has size 1, it represents a single variable.
+ *
+ * A non-sorted VarList can never be successfully made in debug mode.
+ */
+class VarList {
+private:
+  Node backingNode;
+
+  static Node multList(const std::vector<Variable>& list){
+    Assert(list.size() >= 2);
+
+    return makeNode(kind::MULT, list.begin(), list.end());
+  }
+  static Node makeTuple(Node n){
+    return NodeManager::currentNM()->mkNode(kind::IDENTITY, n);
+  }
+
+  VarList() : backingNode(Node::null()){}
+
+  VarList(Node n){
+    backingNode = (Variable::isMember(n)) ? makeTuple(n) : n;
+
+    Assert(isSorted(begin(), end()));
+  }
+
+public:
+  class iterator  {
+  private:
+    Node::iterator d_iter;
+  public:
+    explicit iterator(Node::iterator i) : d_iter(i) {}
+
+    inline Variable operator*(){
+      return Variable(*d_iter);
+    }
+
+    bool operator==(const iterator& i){
+      return d_iter == i.d_iter;
+    }
+
+    bool operator!=(const iterator& i){
+      return d_iter != i.d_iter;
+    }
+
+    iterator operator++() {
+      ++d_iter;
+      return *this;
+    }
+
+    iterator operator++(int) {
+      return iterator(d_iter++);
+    }
+  };
+
+  Node getNode() const{
+    if(singleton()){
+      return backingNode[0];
+    }else{
+      return backingNode;
+    }
+  }
+
+  iterator begin() const{
+    return iterator(backingNode.begin());
+  }
+  iterator end() const{
+    return iterator(backingNode.end());
+  }
+
+  VarList(Variable v) : backingNode(makeTuple(v.getNode())){
+    Assert(isSorted(begin(), end()));
+  }
+  VarList(const std::vector<Variable>& l) : backingNode(multList(l)){
+    Assert(l.size() >= 2);
+    Assert(isSorted(begin(), end()));
+  }
+
+  static bool isMember(Node n);
+
+  bool isNormalForm() const{
+    return !empty();
+  }
+
+  static VarList mkEmptyVarList(){
+    return VarList();
+  }
+
+
+  /** There are no restrictions on the size of l */
+  static VarList mkVarList(const std::vector<Variable>& l){
+    if(l.size() == 0){
+      return mkEmptyVarList();
+    }else if(l.size() == 1){
+      return VarList((*l.begin()).getNode());
+    }else{
+      return VarList(l);
+    }
+  }
+
+  int size() const{ return backingNode.getNumChildren(); }
+  bool empty() const { return size() == 0; }
+  bool singleton() const { return backingNode.getKind() == kind::IDENTITY; }
+
+  static VarList parseVarList(Node n);
+
+  VarList operator*(const VarList& vl) const;
+
+  int cmp(const VarList& vl) const;
+
+  bool operator<(const VarList& vl) const{ return cmp(vl) < 0; }
+
+  bool operator==(const VarList& vl) const{ return cmp(vl) == 0; }
+
+  std::vector<Variable> asList() const;
+
+private:
+  bool isSorted(iterator start, iterator end);
+};
+
+class Monomial : public NodeWrapper {
+private:
+  Constant constant;
+  VarList varList;
+  Monomial(Node n, const Constant& c, const VarList& vl):
+    NodeWrapper(n), constant(c), varList(vl)
+  {
+    Assert(!c.isZero() ||  vl.empty() );
+    Assert( c.isZero() || !vl.empty() );
+
+    Assert(!c.isOne() || !multStructured(n));
+  }
+
+  static Node makeMultNode(const Constant& c, const VarList& vl){
+    Assert(!c.isZero());
+    Assert(!c.isOne());
+    Assert(!vl.empty());
+    return NodeManager::currentNM()->mkNode(kind::MULT, c.getNode(), vl.getNode());
+  }
+
+  static bool multStructured(Node n){
+    return n.getKind() ==  kind::MULT &&
+      n[0].getKind() == kind::CONST_RATIONAL &&
+      n.getNumChildren() == 2;
+  }
+
+public:
+
+  Monomial(const Constant& c):
+    NodeWrapper(c.getNode()), constant(c), varList(VarList::mkEmptyVarList())
+  { }
+
+  Monomial(const VarList& vl):
+    NodeWrapper(vl.getNode()), constant(Constant::mkConstant(1)), varList(vl)
+  {
+    Assert( !varList.empty() );
+  }
+
+  Monomial(const Constant& c, const VarList& vl):
+    NodeWrapper(makeMultNode(c,vl)), constant(c), varList(vl)
+  {
+    Assert( !c.isZero() );
+    Assert( !c.isOne() );
+    Assert( !varList.empty() );
+
+    Assert(multStructured(getNode()));
+  }
+
+  /** Makes a monomial with no restrictions on c and vl. */
+  static Monomial mkMonomial(const Constant& c, const VarList& vl);
+
+
+  static Monomial parseMonomial(Node n);
+
+  static Monomial mkZero(){
+    return Monomial(Constant::mkConstant(0));
+  }
+  static Monomial mkOne(){
+    return Monomial(Constant::mkConstant(1));
+  }
+  const Constant& getConstant() const{ return constant; }
+  const VarList& getVarList() const{ return varList; }
+
+  bool isConstant() const{
+    return varList.empty();
+  }
+
+  bool isZero() const{
+    return constant.isZero();
+  }
+
+  bool coefficientIsOne() const {
+    return constant.isOne();
+  }
+
+  Monomial operator*(const Monomial& mono) const;
+
+
+  int cmp(const Monomial& mono) const{
+    return getVarList().cmp(mono.getVarList());
+  }
+
+  bool operator<(const Monomial& vl) const{
+    return cmp(vl) < 0;
+  }
+
+  bool operator==(const Monomial& vl) const{
+    return cmp(vl) == 0;
+  }
+
+  static bool isSorted(const std::vector<Monomial>& m){
+    return __gnu_cxx::is_sorted(m.begin(), m.end());
+  }
+
+  static bool isStrictlySorted(const std::vector<Monomial>& m){
+    return isSorted(m) && std::adjacent_find(m.begin(),m.end()) == m.end();
+  }
+
+  /**
+   * Given a sorted list of monomials, this function transforms this
+   * into a strictly sorted list of monomials that does not contain zero.
+   */
+  static std::vector<Monomial> sumLikeTerms(const std::vector<Monomial>& monos);
+
+  static void printList(const std::vector<Monomial>& monos);
+};
+
+class Polynomial : public NodeWrapper {
+private:
+  std::vector<Monomial> monos;
+
+  Polynomial(Node n, const std::vector<Monomial>& m):
+    NodeWrapper(n), monos(m)
+  {
+    Assert( !monos.empty() );
+    Assert( Monomial::isStrictlySorted(monos) );
+  }
+
+  static Node makePlusNode(const std::vector<Monomial>& m){
+    Assert(m.size() >= 2);
+
+    return makeNode(kind::PLUS, m.begin(), m.end());
+  }
+
+public:
+  typedef std::vector<Monomial>::const_iterator iterator;
+
+  iterator begin() const{ return monos.begin(); }
+  iterator end() const{ return monos.end(); }
+
+  Polynomial(const Monomial& m):
+    NodeWrapper(m.getNode()), monos()
+  {
+    monos.push_back(m);
+  }
+  Polynomial(const std::vector<Monomial>& m):
+    NodeWrapper(makePlusNode(m)), monos(m)
+  {
+    Assert( monos.size() >= 2);
+    Assert( Monomial::isStrictlySorted(monos) );
+  }
+
+
+  static Polynomial mkPolynomial(const std::vector<Monomial>& m){
+    if(m.size() == 0){
+      return Polynomial(Monomial::mkZero());
+    }else if(m.size() == 1){
+      return Polynomial((*m.begin()));
+    }else{
+      return Polynomial(m);
+    }
+  }
+
+  static Polynomial parsePolynomial(Node n){
+    std::vector<Monomial> monos;
+    if(n.getKind() == kind::PLUS){
+      for(Node::iterator i=n.begin(), end=n.end(); i != end; ++i){
+        monos.push_back(Monomial::parseMonomial(*i));
+      }
+    }else{
+      monos.push_back(Monomial::parseMonomial(n));
+    }
+    return Polynomial(n,monos);
+  }
+
+  static Polynomial mkZero(){
+    return Polynomial(Monomial::mkZero());
+  }
+  static Polynomial mkOne(){
+    return Polynomial(Monomial::mkOne());
+  }
+  bool isZero() const{
+    return (monos.size() == 1) && (getHead().isZero());
+  }
+
+  bool isConstant() const{
+    return (monos.size() == 1) && (getHead().isConstant());
+  }
+
+  bool containsConstant() const{
+    return getHead().isConstant();
+  }
+
+  Monomial getHead() const{
+    return *(begin());
+  }
+
+  Polynomial getTail() const{
+    Assert(monos.size() >= 1);
+
+    iterator start = begin()+1;
+    std::vector<Monomial> subrange(start, end());
+    return mkPolynomial(subrange);
+  }
+
+  void printList() const{
+    Debug("blah") << "start list" << std::endl;
+    Monomial::printList(monos);
+    Debug("blah") << "end list" << std::endl;
+  }
+
+  Polynomial operator+(const Polynomial& vl) const;
+
+  Polynomial operator*(const Monomial& mono) const;
+
+  Polynomial operator*(const Polynomial& poly) const;
+
+};
+
+class Comparison : public NodeWrapper {
+private:
+  Kind oper;
+  Polynomial left;
+  Constant right;
+
+  static Node toNode(Kind k, const Polynomial& l, const Constant& r);
+
+  Comparison(TNode n, Kind k, const Polynomial& l, const Constant& r):
+    NodeWrapper(n), oper(k), left(l), right(r)
+  { }
+public:
+  Comparison(bool val) :
+    NodeWrapper(NodeManager::currentNM()->mkConst(val)),
+    oper(kind::CONST_BOOLEAN),
+    left(Polynomial::mkZero()),
+    right(Constant::mkConstant(0))
+  { }
+
+  Comparison(Kind k, const Polynomial& l, const Constant& r):
+    NodeWrapper(toNode(k, l, r)), oper(k), left(l), right(r)
+  {
+    Assert(isRelationOperator(oper));
+    Assert(!left.containsConstant());
+    Assert(left.getHead().getConstant().isOne());
+  }
+
+  static Comparison mkComparison(Kind k, const Polynomial& left, const Constant& right);
+
+  bool isBoolean() const{
+    return (oper == kind::CONST_BOOLEAN);
+  }
+
+  bool isNormalForm() const{
+    if(isBoolean()){
+      return true;
+    }else if(left.containsConstant()){
+      return false;
+    }else if(left.getHead().getConstant().isOne()){
+      return true;
+    }else{
+      return false;
+    }
+  }
+
+  const Polynomial& getLeft() const { return left; }
+  const Constant& getRight() const { return right; }
+
+  Comparison addConstant(const Constant& constant) const;
+  Comparison multiplyConstant(const Constant& constant) const;
+
+  static Comparison parseNormalForm(TNode n);
+};
+
+
+
+}; /* namesapce arith */
+}; /* namespace theory */
+}; /* namespace CVC4 */
+
+#endif /* __CVC4__THEORY__ARITH__NORMAL_FORM_H */
index 12d93d9fedb196443dd6a12711093d63e9ffb02a..603eb52788a972d54b037cd51287f725301fb418 100644 (file)
@@ -23,7 +23,7 @@
 
 #include "theory/arith/basic.h"
 #include "theory/arith/arith_activity.h"
-
+#include "theory/arith/normal_form.h"
 
 #include <ext/hash_map>
 #include <map>
@@ -52,21 +52,21 @@ public:
    * Construct a row equal to:
    *   basic = \sum_{x_i} c_i * x_i
    */
-  Row(TNode basic, TNode sum):
+  Row(TNode basic, const Polynomial& sum):
     d_x_i(basic),
     d_coeffs(){
 
     Assert(d_x_i.getMetaKind() == kind::metakind::VARIABLE);
-    Assert(sum.getKind() == kind::PLUS);
-
-    for(TNode::iterator iter=sum.begin(); iter != sum.end(); ++iter){
-      TNode pair = *iter;
-      Assert(pair.getKind() == kind::MULT);
-      Assert(pair.getNumChildren() == 2);
-      TNode coeff = pair[0];
-      TNode var_i = pair[1];
+
+    for(Polynomial::iterator iter=sum.begin(), end = sum.end(); iter != end; ++iter){
+      const Monomial& mono = *iter;
+
+      Assert(!mono.isConstant());
+
+      TNode coeff = mono.getConstant().getNode();
+      TNode var_i =  mono.getVarList().getNode();
+
       Assert(coeff.getKind() == kind::CONST_RATIONAL);
-      Assert(var_i.getKind() == kind::VARIABLE);
 
       Assert(!has(var_i));
       d_coeffs[var_i] = coeff.getConst<Rational>();
@@ -192,14 +192,13 @@ private:
 public:
 
   void addRow(TNode eq){
-    Assert(eq.getKind() == kind::EQUAL);
-    Assert(eq.getNumChildren() == 2);
-
     TNode var = eq[0];
-    TNode sum = eq[1];
+    TNode sumNode = eq[1];
 
     Assert(var.getAttribute(IsBasic()));
 
+    Polynomial sum = Polynomial::parsePolynomial(sumNode);
+
     //The new basic variable cannot already be a basic variable
     Assert(!isActiveBasicVariable(var));
     d_activeBasicVars.insert(var);
@@ -208,13 +207,11 @@ public:
 
     //A variable in the row may have been made non-basic already.
     //If this is the case we fake pivoting this variable
-    for(TNode::iterator sumIter = sum.begin(); sumIter!=sum.end(); ++sumIter){
-      TNode child = *sumIter;
-      Assert(child.getKind() == kind::MULT);
-      Assert(child.getNumChildren() == 2);
-      Assert(child[0].getKind() == kind::CONST_RATIONAL);
-      TNode c = child[1];
-      Assert(var.getMetaKind() == kind::metakind::VARIABLE);
+    for(Polynomial::iterator sumIter = sum.begin(); sumIter!= sum.end(); ++sumIter){
+      const Monomial& child = *sumIter;
+
+      Assert(!child.isConstant());
+      TNode c = child.getVarList().getNode();
       if(isActiveBasicVariable(c)){
         Row* row_c = lookup(c);
         row_var->subsitute(*row_c);
index 157c45160a1ca7aed26729a451bfbb9b8caf7283..8f17b01a98ea7770d3a06afd31cb71ec0b08a361 100644 (file)
 #include "theory/arith/basic.h"
 #include "theory/arith/arith_activity.h"
 
-#include "theory/arith/arith_rewriter.h"
+#include "theory/arith/next_arith_rewriter.h"
 #include "theory/arith/arith_propagator.h"
 
 #include "theory/arith/theory_arith.h"
+#include "theory/arith/normal_form.h"
+
 #include <map>
 #include <stdint.h>
 
@@ -55,7 +57,7 @@ TheoryArith::TheoryArith(int id, context::Context* c, OutputChannel& out) :
   d_constants(NodeManager::currentNM()),
   d_partialModel(c),
   d_diseq(c),
-  d_rewriter(&d_constants),
+  d_nextRewriter(&d_constants),
   d_propagator(c),
   d_statistics()
 {
@@ -109,22 +111,9 @@ bool isBasicSum(TNode n){
 
 bool isNormalAtom(TNode n){
 
-  if(!(n.getKind() == LEQ|| n.getKind() == GEQ || n.getKind() == EQUAL)){
-    return false;
-  }
-  TNode left = n[0];
-  TNode right = n[1];
-  if(right.getKind() != CONST_RATIONAL){
-    return false;
-  }
-  if(left.getMetaKind() == metakind::VARIABLE){
-    return true;
-  }else if(isBasicSum(left)){
-    return true;
-  }else{
-    return false;
-  }
+  Comparison parse = Comparison::parseNormalForm(n);
 
+  return parse.isNormalForm();
 }
 
 
@@ -213,7 +202,6 @@ void TheoryArith::preRegisterTerm(TNode n) {
     if(left.getKind() == PLUS){
       //We may need to introduce a slack variable.
       Assert(left.getNumChildren() >= 2);
-      Assert(isBasicSum(left));
       if(!left.hasAttribute(Slack())){
         setupSlack(left);
       }
@@ -229,11 +217,9 @@ void TheoryArith::setupSlack(TNode left){
   left.setAttribute(Slack(), slack);
   makeBasic(slack);
 
-  Node slackEqLeft = NodeManager::currentNM()->mkNode(EQUAL,slack,left);
-
-  Debug("slack") << "slack " << slackEqLeft << endl;
+  Node eq = NodeManager::currentNM()->mkNode(kind::EQUAL, slack, left);
 
-  d_tableau.addRow(slackEqLeft);
+  d_tableau.addRow(eq);
 
   setupVariable(slack);
 }
@@ -316,56 +302,9 @@ DeltaRational TheoryArith::computeRowValueUsingSavedAssignment(TNode x){
 }
 
 RewriteResponse TheoryArith::preRewrite(TNode n, bool topLevel) {
-  // ensure a hard link to the node we're returning
-  Node out;
-
-  // Look for multiplications with a 0 argument and rewrite the whole
-  // thing as 0
-  if(n.getKind() == MULT) {
-    Rational ratZero;
-    Integer intZero;
-    for(TNode::iterator i = n.begin(); i != n.end(); ++i) {
-      if((*i).getKind() == CONST_RATIONAL) {
-        if((*i).getConst<Rational>() == ratZero) {
-          out = NodeManager::currentNM()->mkConst(ratZero);
-          break;
-        }
-      } else if((*i).getKind() == CONST_INTEGER) {
-        if((*i).getConst<Integer>() == intZero) {
-          if(n.getType().isInteger()) {
-            out = NodeManager::currentNM()->mkConst(intZero);
-            break;
-          } else {
-            out = NodeManager::currentNM()->mkConst(ratZero);
-            break;
-          }
-        }
-      }
-    }
-  } else if(n.getKind() == EQUAL) {
-    if(n[0] == n[1]) {
-      out = NodeManager::currentNM()->mkConst(true);
-    }
-  }
-
-  if(out.isNull()) {
-    // no preRewrite to perform
-    return RewriteComplete(Node(n));
-  } else {
-    // out is always a constant, so doesn't need to be rewritten again
-    return RewriteComplete(out);
-  }
+  return d_nextRewriter.preRewrite(n);
 }
 
-Node TheoryArith::rewrite(TNode n){
-  Debug("arith") << "rewrite(" << n << ")" << endl;
-
-  Node result = d_rewriter.rewrite(n);
-  Debug("arith-rewrite") << "rewrite(" << n << ") -> " << result << endl;
-  return result;
-}
-
-
 void TheoryArith::registerTerm(TNode tn){
   Debug("arith") << "registerTerm(" << tn << ")" << endl;
 }
index 7367f57268f4523650b011421411c3741d4099a5..03be7a77b2f2fd32e29d4ed9fa30a55fa338218c 100644 (file)
@@ -28,7 +28,7 @@
 
 #include "theory/arith/delta_rational.h"
 #include "theory/arith/tableau.h"
-#include "theory/arith/arith_rewriter.h"
+#include "theory/arith/next_arith_rewriter.h"
 #include "theory/arith/partial_model.h"
 #include "theory/arith/arith_propagator.h"
 
@@ -94,7 +94,7 @@ private:
   /**
    * The rewriter module for arithmetic.
    */
-  ArithRewriter d_rewriter;
+  NextArithRewriter d_nextRewriter;
 
   ArithUnatePropagator d_propagator;
 
@@ -102,11 +102,6 @@ public:
   TheoryArith(int id, context::Context* c, OutputChannel& out);
   ~TheoryArith();
 
-  /**
-   * Rewrites a node to a unique normal form given in normal_form_notes.txt
-   */
-  Node rewrite(TNode n);
-
   /**
    * Rewriting optimizations.
    */
@@ -116,7 +111,7 @@ public:
    * Plug in old rewrite to the new (pre,post)rewrite interface.
    */
   RewriteResponse postRewrite(TNode n, bool topLevel) {
-    return RewriteComplete(topLevel ? rewrite(n) : Node(n));
+    return d_nextRewriter.postRewrite(n);
   }
 
   /**
index ea1ee698f53c06f3e863263bf1bd3b7198923179..763e03fdb42d1198467fcf8cd0a14f123407488e 100644 (file)
@@ -65,6 +65,32 @@ public:
 
   TheoryArithWhite() : d_level(Theory::FULL_EFFORT), d_zero(0), d_one(1), debug(false) {}
 
+  void fakeTheoryEnginePreprocess(TNode inp){
+    Node rewrite = inp; //FIXME this needs to enforce that inp is fully rewritten already!
+
+    if(debug) cout << rewrite << inp << endl;
+
+    std::list<Node> toPreregister;
+
+    toPreregister.push_back(rewrite);
+    for(std::list<Node>::iterator i = toPreregister.begin(); i != toPreregister.end(); ++i){
+      Node n = *i;
+      preregistered->insert(n);
+
+      for(Node::iterator citer = n.begin(); citer != n.end(); ++citer){
+        Node c = *citer;
+        if(preregistered->find(c) == preregistered->end()){
+          toPreregister.push_back(c);
+        }
+      }
+    }
+    for(std::list<Node>::reverse_iterator i = toPreregister.rbegin(); i != toPreregister.rend(); ++i){
+      Node n = *i;
+      if(debug) cout << n.getId() << " "<< n << endl;
+      d_arith->preRegisterTerm(n);
+    }
+  }
+
   void setUp() {
     d_ctxt = new Context;
     d_nm = new NodeManager(d_ctxt);
@@ -92,42 +118,14 @@ public:
     delete d_ctxt;
   }
 
-  Node fakeTheoryEnginePreprocess(TNode inp){
-    Node rewrite = d_arith->rewrite(inp);
-
-    if(debug) cout << rewrite << inp << endl;
-
-    std::list<Node> toPreregister;
-
-    toPreregister.push_back(rewrite);
-    for(std::list<Node>::iterator i = toPreregister.begin(); i != toPreregister.end(); ++i){
-      Node n = *i;
-      preregistered->insert(n);
-
-      for(Node::iterator citer = n.begin(); citer != n.end(); ++citer){
-        Node c = *citer;
-        if(preregistered->find(c) == preregistered->end()){
-          toPreregister.push_back(c);
-        }
-      }
-    }
-    for(std::list<Node>::reverse_iterator i = toPreregister.rbegin(); i != toPreregister.rend(); ++i){
-      Node n = *i;
-      if(debug) cout << n.getId() << " "<< n << endl;
-      d_arith->preRegisterTerm(n);
-    }
-
-    return rewrite;
-  }
-
   void testAssert() {
     Node x = d_nm->mkVar(*d_realType);
     Node c = d_nm->mkConst<Rational>(d_zero);
 
     Node leq = d_nm->mkNode(LEQ, x, c);
-    Node rLeq = fakeTheoryEnginePreprocess(leq);
+    fakeTheoryEnginePreprocess(leq);
 
-    d_arith->assertFact(rLeq);
+    d_arith->assertFact(leq);
 
     d_arith->check(d_level);
 
@@ -143,51 +141,19 @@ public:
     return dis;
   }
 
-  void testAssertEqualityEagerSplit() {
-    Node x = d_nm->mkVar(*d_realType);
-    Node c = d_nm->mkConst<Rational>(d_zero);
-
-    Node eq = d_nm->mkNode(EQUAL, x, c);
-    Node expectedDisjunct = simulateSplit(x,c);
-
-    Node rEq = fakeTheoryEnginePreprocess(eq);
-
-    d_arith->assertFact(rEq);
-
-    d_arith->check(d_level);
-
-    TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 1u);
-
-    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), expectedDisjunct);
-    TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA);
-
-  }
-  void testLtRewrite() {
-    Node x = d_nm->mkVar(*d_realType);
-    Node c = d_nm->mkConst<Rational>(d_zero);
-
-    Node lt = d_nm->mkNode(LT, x, c);
-    Node geq = d_nm->mkNode(GEQ, x, c);
-    Node expectedRewrite = d_nm->mkNode(NOT, geq);
-
-    Node rewrite = d_arith->rewrite(lt);
-
-    TS_ASSERT_EQUALS(expectedRewrite, rewrite);
-  }
-
   void testBasicConflict() {
     Node x = d_nm->mkVar(*d_realType);
     Node c = d_nm->mkConst<Rational>(d_zero);
 
     Node eq = d_nm->mkNode(EQUAL, x, c);
-    Node lt = d_nm->mkNode(LT, x, c);
+    Node lt = d_nm->mkNode(NOT, d_nm->mkNode(GEQ, x, c));
     Node expectedDisjunct = simulateSplit(x,c);
 
-    Node rEq = fakeTheoryEnginePreprocess(eq);
-    Node rLt = fakeTheoryEnginePreprocess(lt);
+    fakeTheoryEnginePreprocess(eq);
+    fakeTheoryEnginePreprocess(lt);
 
-    d_arith->assertFact(rEq);
-    d_arith->assertFact(rLt);
+    d_arith->assertFact(eq);
+    d_arith->assertFact(lt);
 
 
     d_arith->check(d_level);
@@ -198,7 +164,7 @@ public:
 
     TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), CONFLICT);
 
-    Node expectedClonflict = d_nm->mkNode(AND, rEq, rLt);
+    Node expectedClonflict = d_nm->mkNode(AND, eq, lt);
 
     TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), expectedClonflict);
   }
@@ -208,13 +174,13 @@ public:
     Node c = d_nm->mkConst<Rational>(d_zero);
 
     Node eq = d_nm->mkNode(EQUAL, x, c);
-    Node lt = d_nm->mkNode(LT, x, c);
+    Node lt = d_nm->mkNode(NOT, d_nm->mkNode(GEQ, x, c));
     Node expectedDisjunct = simulateSplit(x,c);
 
-    Node rEq = fakeTheoryEnginePreprocess(eq);
-    Node rLt = fakeTheoryEnginePreprocess(lt);
+    fakeTheoryEnginePreprocess(eq);
+    fakeTheoryEnginePreprocess(lt);
 
-    d_arith->assertFact(rEq);
+    d_arith->assertFact(eq);
 
 
     d_arith->check(d_level);
@@ -236,29 +202,30 @@ public:
 
     Node leq0 = d_nm->mkNode(LEQ, x, c0);
     Node leq1 = d_nm->mkNode(LEQ, x, c1);
-    Node lt1 = d_nm->mkNode(LT, x, c1);
+    Node geq1 = d_nm->mkNode(GEQ, x, c1);
+    Node lt1 = d_nm->mkNode(NOT, geq1);
 
-    Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
-    Node rLt1 = fakeTheoryEnginePreprocess(lt1);
-    Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
+    fakeTheoryEnginePreprocess(leq0);
+    fakeTheoryEnginePreprocess(leq1);
+    fakeTheoryEnginePreprocess(geq1);
 
-    d_arith->assertFact(rLt1);
+    d_arith->assertFact(lt1);
 
 
     d_arith->check(d_level);
     d_arith->propagate(d_level);
 
 #ifdef CVC4_ASSERTIONS
-    TS_ASSERT_THROWS(  d_arith->explain(rLeq0, d_level), AssertionException );
-    TS_ASSERT_THROWS(  d_arith->explain(rLt1, d_level), AssertionException );
+    TS_ASSERT_THROWS(  d_arith->explain(leq0, d_level), AssertionException );
+    TS_ASSERT_THROWS(  d_arith->explain(lt1, d_level), AssertionException );
 #endif
-    d_arith->explain(rLeq1, d_level);
+    d_arith->explain(leq1, d_level);
 
     TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u);
     TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPAGATE);
     TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), EXPLANATION);
     TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), leq1);
-    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), lt1);
   }
 
 
@@ -269,24 +236,25 @@ public:
 
     Node leq0 = d_nm->mkNode(LEQ, x, c0);
     Node leq1 = d_nm->mkNode(LEQ, x, c1);
-    Node lt1 = d_nm->mkNode(LT, x, c1);
+    Node geq1 = d_nm->mkNode(GEQ, x, c1);
+    Node lt1 = d_nm->mkNode(NOT, geq1);
 
-    Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
-    Node rLt1 = fakeTheoryEnginePreprocess(lt1);
-    Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
+    fakeTheoryEnginePreprocess(leq0);
+    fakeTheoryEnginePreprocess(leq1);
+    fakeTheoryEnginePreprocess(geq1);
 
-    d_arith->assertFact(rLeq0);
+    d_arith->assertFact(leq0);
 
 
     d_arith->check(d_level);
     d_arith->propagate(d_level);
 
 
-    d_arith->explain(rLt1, d_level);
+    d_arith->explain(lt1, d_level);
 #ifdef CVC4_ASSERTIONS
-    TS_ASSERT_THROWS(  d_arith->explain(rLeq0, d_level), AssertionException );
+    TS_ASSERT_THROWS(  d_arith->explain(leq0, d_level), AssertionException );
 #endif
-    d_arith->explain(rLeq1, d_level);
+    d_arith->explain(leq1, d_level);
 
     TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 4u);
     TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPAGATE);
@@ -294,12 +262,12 @@ public:
     TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(2), EXPLANATION);
     TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(3), EXPLANATION);
 
-    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1);
-    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), rLeq1);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), lt1);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), leq1);
 
 
-    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(2), rLeq0);
-    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(3), rLeq0);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(2), leq0);
+    TS_ASSERT_EQUALS(d_outputChannel.getIthNode(3), leq0);
   }
   void testTPLeq1() {
     Node x = d_nm->mkVar(*d_realType);
@@ -308,22 +276,23 @@ public:
 
     Node leq0 = d_nm->mkNode(LEQ, x, c0);
     Node leq1 = d_nm->mkNode(LEQ, x, c1);
-    Node lt1 = d_nm->mkNode(LT, x, c1);
+    Node geq1 = d_nm->mkNode(GEQ, x, c1);
+    Node lt1 = d_nm->mkNode(NOT, geq1);
 
-    Node rLeq0 = fakeTheoryEnginePreprocess(leq0);
-    Node rLt1 = fakeTheoryEnginePreprocess(lt1);
-    Node rLeq1 = fakeTheoryEnginePreprocess(leq1);
+    fakeTheoryEnginePreprocess(leq0);
+    fakeTheoryEnginePreprocess(leq1);
+    fakeTheoryEnginePreprocess(geq1);
 
-    d_arith->assertFact(rLeq1);
+    d_arith->assertFact(leq1);
 
 
     d_arith->check(d_level);
     d_arith->propagate(d_level);
 
 #ifdef CVC4_ASSERTIONS
-    TS_ASSERT_THROWS(  d_arith->explain(rLeq0, d_level), AssertionException );
-    TS_ASSERT_THROWS(  d_arith->explain(rLeq1, d_level), AssertionException );
-    TS_ASSERT_THROWS(  d_arith->explain(rLt1, d_level), AssertionException );
+    TS_ASSERT_THROWS(  d_arith->explain(leq0, d_level), AssertionException );
+    TS_ASSERT_THROWS(  d_arith->explain(leq1, d_level), AssertionException );
+    TS_ASSERT_THROWS(  d_arith->explain(lt1, d_level), AssertionException );
 #endif
 
     TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 0u);