From: Tim King Date: Mon, 13 Sep 2010 16:08:21 +0000 (+0000) Subject: * New normal form for arithmetic is in place. X-Git-Tag: cvc5-1.0.0~8872 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=0e18d60841c2a7cd5c079b6c0dacf5d61afb4835;p=cvc5.git * New normal form for arithmetic is in place. * src/theory/arith/normal_form.{h,cpp} contains the description for the new normal form as well as utilities for dealing with the normal form. * src/theory/arith/next_arith_rewriter.{h,cpp} contains the new rewriter. The new rewriter implements preRewrite() and postRewrite() for arithmetic. * src/theory/arith/arith_rewriter.{h,cpp} have been removed. * TheoryArith::rewrite() has been removed. * Arithmetic with the new normal form outperforms the trunk where the branch occurred (-r797) on 46% of the examples in QF_LRA. (33% have no noticeable difference.) Some important optimizations are stilling pending to the code for handling the new normal form. (Bug 196.) --- diff --git a/src/expr/node_manager.cpp b/src/expr/node_manager.cpp index d017ad799..37ed4fe20 100644 --- a/src/expr/node_manager.cpp +++ b/src/expr/node_manager.cpp @@ -235,6 +235,9 @@ TypeNode NodeManager::getType(TNode n, bool check) case kind::APPLY_UF: typeNode = CVC4::theory::uf::UfTypeRule::computeType(this, n, check); break; + case kind::IDENTITY: + typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check); + break; case kind::PLUS: typeNode = CVC4::theory::arith::ArithOperatorTypeRule::computeType(this, n, check); break; diff --git a/src/theory/arith/Makefile.am b/src/theory/arith/Makefile.am index e500f5cf8..ead39082c 100644 --- a/src/theory/arith/Makefile.am +++ b/src/theory/arith/Makefile.am @@ -7,8 +7,10 @@ noinst_LTLIBRARIES = libarith.la libarith_la_SOURCES = \ theory_arith_type_rules.h \ - arith_rewriter.h \ - arith_rewriter.cpp \ + next_arith_rewriter.h \ + next_arith_rewriter.cpp \ + normal_form.h\ + normal_form.cpp \ arith_utilities.h \ arith_constants.h \ arith_activity.h \ diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp deleted file mode 100644 index ba1445df8..000000000 --- a/src/theory/arith/arith_rewriter.cpp +++ /dev/null @@ -1,557 +0,0 @@ -/********************* */ -/*! \file arith_rewriter.cpp - ** \verbatim - ** Original author: taking - ** Major contributors: none - ** Minor contributors (to current version): mdeters - ** This file is part of the CVC4 prototype. - ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys) - ** Courant Institute of Mathematical Sciences - ** New York University - ** See the file COPYING in the top-level source directory for licensing - ** information.\endverbatim - ** - ** \brief [[ Add one-line brief description here ]] - ** - ** [[ Add lengthier description here ]] - ** \todo document this file - **/ - - -#include "theory/arith/arith_rewriter.h" -#include "theory/arith/arith_utilities.h" - -#include -#include -#include - - -using namespace CVC4; -using namespace CVC4::theory; -using namespace CVC4::theory::arith; - - - - - -Kind multKind(Kind k, int sgn); - -/** - * Performs a quick check to see if it is easy to rewrite to - * this normal form - * v |><| b - * Also writes relations with constants on both sides to TRUE or FALSE. - * If it can, it returns true and sets res to this value. - * - * This is for optimizing rewriteAtom() to avoid the more computationally - * expensive general rewriting procedure. - * - * If simplification is not done, it returns Node::null() - */ -Node almostVarOrConstEqn(TNode atom, Kind k, TNode left, TNode right){ - Assert(atom.getKind() == k); - Assert(isRelationOperator(k)); - Assert(atom[0] == left); - Assert(atom[1] == right); - bool leftIsConst = left.getMetaKind() == kind::metakind::CONSTANT; - bool rightIsConst = right.getMetaKind() == kind::metakind::CONSTANT; - - bool leftIsVar = left.getMetaKind() == kind::metakind::VARIABLE; - bool rightIsVar = right.getMetaKind() == kind::metakind::VARIABLE; - - if(leftIsConst && rightIsConst){ - Rational lc = coerceToRational(left); - Rational rc = coerceToRational(right); - bool res = evaluateConstantPredicate(k,lc, rc); - return mkBoolNode(res); - }else if(leftIsVar && rightIsConst){ - if(right.getKind() == kind::CONST_RATIONAL){ - return atom; - }else{ - return NodeManager::currentNM()->mkNode(k,left,coerceToRationalNode(right)); - } - }else if(leftIsConst && rightIsVar){ - if(left.getKind() == kind::CONST_RATIONAL){ - return NodeManager::currentNM()->mkNode(multKind(k,-1),right,left); - }else{ - Node q_left = coerceToRationalNode(left); - return NodeManager::currentNM()->mkNode(multKind(k,-1),right,q_left); - } - } - - return Node::null(); -} - -Node ArithRewriter::rewriteAtomCore(TNode atom){ - - Kind k = atom.getKind(); - Assert(isRelationOperator(k)); - - // left |><| right - TNode left = atom[0]; - TNode right = atom[1]; - - Node nf = almostVarOrConstEqn(atom, k,left,right); - if(nf != Node::null() ){ - return nf; - } - - - //Transform this to: (left- right) |><| 0 - Node diff = makeSubtractionNode(left, right); - - Node rewritten = rewrite(diff); - // rewritten =_{Reals} left - right => rewritten |><| 0 - - if(rewritten.getMetaKind() == kind::metakind::CONSTANT){ - // Case 1 rewritten : c - Rational c = rewritten.getConst(); - bool res = evaluateConstantPredicate(k, c, d_constants->d_ZERO); - nf = mkBoolNode(res); - }else if(rewritten.getMetaKind() == kind::metakind::VARIABLE){ - // Case 2 rewritten : v - nf = NodeManager::currentNM()->mkNode(k, rewritten, d_constants->d_ZERO_NODE); - }else{ - // Case 3 rewritten : (+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1) - Rational c = rewritten[0].getConst(); - c = -c; - TNode p_1 = rewritten[1]; - Rational d = p_1[0].getConst(); - d = d.inverse(); - c = c * d; - Node newRight = mkRationalNode(c); - Kind newKind = multKind(k, d.sgn()); - int N = rewritten.getNumChildren() - 1; - - if(N==1){ - int M = p_1.getNumChildren()-1; - if(M == 1){ // v |><| b - TNode v = p_1[1]; - nf = NodeManager::currentNM()->mkNode(newKind, v, newRight); - }else{ // p |><| b - Node newLeft = multPnfByNonZero(p_1, d); - nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight); - } - }else{ //(+ p_1 .. p_N) |><| b - NodeBuilder<> plus(kind::PLUS); - for(int i=1; i<=N; ++i){ - TNode p_i = rewritten[i]; - plus << multPnfByNonZero(p_i, d); - } - Node newLeft = plus; - nf = NodeManager::currentNM()->mkNode(newKind, newLeft, newRight); - } - } - - return nf; -} - -Node ArithRewriter::rewriteAtom(TNode atom){ - Node rewritten = rewriteAtomCore(atom); - if(rewritten.getKind() == kind::LT){ - Node geq = NodeManager::currentNM()->mkNode(kind::GEQ, rewritten[0], rewritten[1]); - return NodeManager::currentNM()->mkNode(kind::NOT, geq); - }else if(rewritten.getKind() == kind::GT){ - Node leq = NodeManager::currentNM()->mkNode(kind::LEQ, rewritten[0], rewritten[1]); - return NodeManager::currentNM()->mkNode(kind::NOT, leq); - }else{ - return rewritten; - } -} - - -/* cmp( (* d v_1 v_2 ... v_M), (* d' v'_1 v'_2 ... v'_M'): - * if(M == M'): - * then tupleCompare(v_i, v'_i) - * else M -M' - */ -struct pnfLessThan { - bool operator()(Node p0, Node p1) { - int p0_M = p0.getNumChildren() -1; - int p1_M = p1.getNumChildren() -1; - if(p0_M == p1_M){ - for(int i=1; i<= p0_M; ++i){ - if(p0[i] != p1[i]){ - return p0[i] < p1[i]; - } - } - return false; //p0 == p1 in this order - }else{ - return p0_M < p1_M; - } - } -}; - -//Two pnfs are equal up to their coefficients -bool pnfsMatch(TNode p0, TNode p1){ - - unsigned M = p0.getNumChildren()-1; - if (M+1 != p1.getNumChildren()){ - return false; - } - - for(unsigned i=1; i <= M; ++i){ - if(p0[i] != p1[i]) - return false; - } - return true; -} - -Node addMatchingPnfs(TNode p0, TNode p1){ - Assert(pnfsMatch(p0,p1)); - - unsigned M = p0.getNumChildren()-1; - - Rational c0 = p0[0].getConst(); - Rational c1 = p1[0].getConst(); - - Rational addedC = c0 + c1; - Node newC = mkRationalNode(addedC); - NodeBuilder<> nb(kind::MULT); - nb << newC; - for(unsigned i=1; i <= M; ++i){ - nb << p0[i]; - } - Node newPnf = nb; - return newPnf; -} - -void ArithRewriter::sortAndCombineCoefficients(std::vector& pnfs){ - using namespace std; - - /* combined contains exactly 1 representative per for each pnf. - * This is maintained by combining the coefficients for pnfs. - * that is equal according to pnfLessThan. - */ - typedef set PnfSet; - PnfSet combined; - - for(vector::iterator i=pnfs.begin(); i != pnfs.end(); ++i){ - Node pnf = *i; - PnfSet::iterator pos = combined.find(pnf); - - if(pos == combined.end()){ - combined.insert(pnf); - }else{ - Node current = *pos; - Node sum = addMatchingPnfs(pnf, current); - combined.erase(pos); - combined.insert(sum); - } - } - pnfs.clear(); - for(PnfSet::iterator i=combined.begin(); i != combined.end(); ++i){ - Node pnf = *i; - if(pnf[0].getConst() != d_constants->d_ZERO){ - //after combination the coefficient may be zero - pnfs.push_back(pnf); - } - } -} - -Node ArithRewriter::var2pnf(TNode variable){ - return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_ONE_NODE,variable); -} - -Node ArithRewriter::rewritePlus(TNode t){ - using namespace std; - - Rational accumulator; - vector pnfs; - - for(TNode::iterator i = t.begin(); i!= t.end(); ++i){ - TNode child = *i; - Node rewrittenChild = rewrite(child); - - if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c - Rational c = rewrittenChild.getConst(); - accumulator = accumulator + c; - }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v - Node pnf = var2pnf(rewrittenChild); - pnfs.push_back(pnf); - }else{ //(+ c p_1 p_2 ... p_N) - Rational c = rewrittenChild[0].getConst(); - accumulator = accumulator + c; - int N = rewrittenChild.getNumChildren() - 1; - for(int i=1; i<=N; ++i){ - TNode pnf = rewrittenChild[i]; - pnfs.push_back(pnf); - } - } - } - sortAndCombineCoefficients(pnfs); - if(pnfs.size() == 0){ - return mkRationalNode(accumulator); - } - - // pnfs.size() >= 1 - - //Enforce not(N=1 and c=0 and p_1.d=1) - if(pnfs.size() == 1){ - Node p_1 = *(pnfs.begin()); - if(p_1[0].getConst() == d_constants->d_ONE){ - if(accumulator == d_constants->d_ZERO){ // 0 + (* 1 var) |-> var - Node var = p_1[1]; - return var; - } - } - } - - //We must be in this case - //(+ c p_1 p_2 ... p_N) | not(N=1 and c=0 and p_1.d=1) - - NodeBuilder<> nb(kind::PLUS); - nb << mkRationalNode(accumulator); - Debug("arithrewrite") << mkRationalNode(accumulator) << std::endl; - for(vector::iterator i = pnfs.begin(); i != pnfs.end(); ++i){ - nb << *i; - Debug("arithrewrite") << (*i) << std::endl; - - } - - Node normalForm = nb; - return normalForm; -} - -//Does not enforce -//5) v_i are of metakind VARIABLE, -//6) v_i are in increasing (not strict) nodeOrder, -Node toPnf(Rational& c, std::set& variables){ - NodeBuilder<> nb(kind::MULT); - nb << mkRationalNode(c); - - for(std::set::iterator i = variables.begin(); i != variables.end(); ++i){ - nb << *i; - } - Node pnf = nb; - return pnf; -} - -Node distribute(TNode n, TNode sum){ - NodeBuilder<> nb(kind::PLUS); - for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){ - Node prod = NodeManager::currentNM()->mkNode(kind::MULT, n, *i); - nb << prod; - } - return nb; -} -Node distributeSum(TNode sum, TNode distribSum){ - NodeBuilder<> nb(kind::PLUS); - for(TNode::iterator i=sum.begin(); i!=sum.end(); ++i){ - Node dist = distribute(*i, distribSum); - for(Node::iterator j=dist.begin(); j!=dist.end(); ++j){ - nb << *j; - } - } - return nb; -} - -Node ArithRewriter::rewriteMult(TNode t){ - - using namespace std; - - Rational accumulator(1,1); - set variables; - vector sums; - - //These stacks need to be kept in lock step - stack mult_iterators_nodes; - stack mult_iterators_iters; - - mult_iterators_nodes.push(t); - mult_iterators_iters.push(t.begin()); - - while(!mult_iterators_nodes.empty()){ - TNode mult = mult_iterators_nodes.top(); - TNode::const_iterator i = mult_iterators_iters.top(); - - mult_iterators_nodes.pop(); - mult_iterators_iters.pop(); - - for(; i != mult.end(); ++i){ - TNode child = *i; - if(child.getKind() == kind::MULT){ //TODO add not rewritten already checks - ++i; - mult_iterators_nodes.push(mult); - mult_iterators_iters.push(i); - - mult_iterators_nodes.push(child); - mult_iterators_iters.push(child.begin()); - break; - } - Node rewrittenChild = rewrite(child); - - if(rewrittenChild.getMetaKind() == kind::metakind::CONSTANT){//c - Rational c = rewrittenChild.getConst(); - accumulator = accumulator * c; - if(accumulator == d_constants->d_ZERO){ - return d_constants->d_ZERO_NODE; - } - }else if(rewrittenChild.getMetaKind() == kind::metakind::VARIABLE){ //v - variables.insert(rewrittenChild); - }else{ //(+ c p_1 p_2 ... p_N) - sums.push_back(rewrittenChild); - } - } - } - // accumulator * (\prod var_i) *(\prod sum_j) - - if(sums.size() == 0){ //accumulator * (\prod var_i) - if(variables.size() == 0){ //accumulator - return mkRationalNode(accumulator); - }else if(variables.size() == 1 && accumulator == d_constants->d_ONE){ // var_1 - Node var = *(variables.begin()); - return var; - }else{ - //We need to return (+ c p_1 p_2 ... p_N) - //To accomplish this: - // let pnf = pnf(accumulator * (\prod var_i)) in (+ 0 pnf) - Node pnf = toPnf(accumulator, variables); - Node normalForm = NodeManager::currentNM()->mkNode(kind::PLUS, d_constants->d_ZERO_NODE, pnf); - return normalForm; - } - }else{ - vector::iterator sum_iter = sums.begin(); - // \sum t - // t \in Q \cup A - // where A = lfp {\prod s | s \in Q \cup Variables \cup A} - Node distributed = *sum_iter; - ++sum_iter; - while(sum_iter != sums.end()){ - Node curr = *sum_iter; - distributed = distributeSum(curr, distributed); - ++sum_iter; - } - if(variables.size() >= 1){ - Node pnf = toPnf(accumulator, variables); - distributed = distribute(pnf, distributed); - }else{ - Node constant = mkRationalNode(accumulator); - distributed = distribute(constant, distributed); - } - - Node nf_distributed = rewrite(distributed); - return nf_distributed; - } -} - -Node ArithRewriter::rewriteDivByConstant(TNode t){ - Assert(t.getKind()== kind::DIVISION); - - Node left = t[0]; - Node reRight = rewrite(t[1]); - Assert(reRight.getKind()== kind::CONST_RATIONAL); - - - Rational den = reRight.getConst(); - - Assert(den != d_constants->d_ZERO); - - Rational div = den.inverse(); - - Node result = mkRationalNode(div); - - Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result); - - Node reMult = rewrite(mult); - - return reMult; -} - -Node ArithRewriter::rewriteTerm(TNode t){ - if(t.getMetaKind() == kind::metakind::CONSTANT){ - return coerceToRationalNode(t); - }else if(t.getMetaKind() == kind::metakind::VARIABLE){ - return t; - }else if(t.getKind() == kind::MULT){ - return rewriteMult(t); - }else if(t.getKind() == kind::PLUS){ - return rewritePlus(t); - }else if(t.getKind() == kind::DIVISION){ - return rewriteDivByConstant(t); - }else if(t.getKind() == kind::MINUS){ - Node sub = makeSubtractionNode(t[0],t[1]); - return rewrite(sub); - }else if(t.getKind() == kind::UMINUS){ - Node sub = makeUnaryMinusNode(t[0]); - return rewrite(sub); - }else{ - Unhandled(t); - } -} - - -/** - * Given a node in PNF pnf = (* d p_1 p_2 .. p_M) and a rational q != 0 - * constuct a node equal to q * pnf that is in pnf. - * - * The claim is that this is always okay: - * If d' = q*d, p' = (* d' p_1 p_2 .. p_M) =_{Reals} q * pnf. - */ -Node ArithRewriter::multPnfByNonZero(TNode pnf, Rational& q){ - Assert(q != d_constants->d_ZERO); - //TODO Assert(isPNF(pnf) ); - - int M = pnf.getNumChildren()-1; - Rational d = pnf[0].getConst(); - Rational new_d = d*q; - - - NodeBuilder<> mult(kind::MULT); - mult << mkRationalNode(new_d); - for(int i=1; i<=M; ++i){ - mult << pnf[i]; - } - - Node result = mult; - return result; -} - -Node ArithRewriter::makeUnaryMinusNode(TNode n){ - Node tmp = NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_NEGATIVE_ONE_NODE,n); - return tmp; -} - -Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){ - Node negR = makeUnaryMinusNode(r); - Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR); - - return diff; -} - - -Kind multKind(Kind k, int sgn){ - using namespace kind; - - if(sgn < 0){ - - switch(k){ - case LT: return GT; - case LEQ: return GEQ; - case EQUAL: return EQUAL; - case GEQ: return LEQ; - case GT: return LT; - default: - Unhandled(k); - } - return NULL_EXPR; - }else{ - return k; - } -} - -Node ArithRewriter::rewrite(TNode n){ - Debug("arithrewriter") << "Trace rewrite:" << n << std::endl; - - Node res; - - if(isRelationOperator(n.getKind())){ - res = rewriteAtom(n); - }else{ - res = rewriteTerm(n); - } - - Debug("arithrewriter") << "Trace rewrite:" << n << "|->"<< res << std::endl; - - return res; -} diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h deleted file mode 100644 index a76ee6e61..000000000 --- a/src/theory/arith/arith_rewriter.h +++ /dev/null @@ -1,123 +0,0 @@ -/********************* */ -/*! \file arith_rewriter.h - ** \verbatim - ** Original author: taking - ** Major contributors: mdeters - ** Minor contributors (to current version): none - ** This file is part of the CVC4 prototype. - ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys) - ** Courant Institute of Mathematical Sciences - ** New York University - ** See the file COPYING in the top-level source directory for licensing - ** information.\endverbatim - ** - ** \brief [[ Add one-line brief description here ]] - ** - ** [[ Add lengthier description here ]] - ** \todo document this file - **/ - - -#include "expr/node.h" -#include "util/rational.h" -#include "theory/arith/arith_constants.h" - -#ifndef __CVC4__THEORY__ARITH__REWRITER_H -#define __CVC4__THEORY__ARITH__REWRITER_H - -namespace CVC4 { -namespace theory { -namespace arith { - - -/***********************************************/ -/***************** Normal Form *****************/ -/***********************************************/ -/***********************************************/ - -/** - * Normal form for predicates: - * TRUE - * FALSE - * v |><| b - * p |><| b - * (+ p_1 .. p_N) |><| b - * where - * 1) b is of type CONST_RATIONAL - * 2) |><| is of kind <, <=, =, >= or > - * 3) p, p_i is in PNF, - * 4) p.M >= 2 - * 5) p_i's are in strictly ascending = 2, - * 7) the kind of (+ p_1 .. p_N) is an N arity PLUS, - * 8) p.d, p_1.d are 1, - * 9) v has metakind variable, and - * - * PNF(t): - * (* d v_1 v_2 ... v_M) - * where - * 1) d is of type CONST_RATIONAL, - * 2) d != 0, - * 4) M>=1, - * 5) v_i are of metakind VARIABLE, - * 6) v_i are in increasing (not strict) nodeOrder, and - * 7) the kind of t is an M+1 arity MULT. - * - *

= 1 - * 5) the kind of (+ c p_1 p_2 ... p_N) is an N+1 arity PLUS, - * 6) and p_i's are in strictly & pnfs); - - -public: - ArithRewriter(ArithConstants* ac) : - d_constants(ac) - {} - Node rewrite(TNode t); - -}; - - -}; /* namesapce arith */ -}; /* namespace theory */ -}; /* namespace CVC4 */ - -#endif /* __CVC4__THEORY__ARITH__REWRITER_H */ diff --git a/src/theory/arith/arith_utilities.h b/src/theory/arith/arith_utilities.h index fa3356c60..6706ad76a 100644 --- a/src/theory/arith/arith_utilities.h +++ b/src/theory/arith/arith_utilities.h @@ -27,7 +27,7 @@ namespace CVC4 { namespace theory { namespace arith { -inline Node mkRationalNode(Rational& q){ +inline Node mkRationalNode(const Rational& q){ return NodeManager::currentNM()->mkConst(q); } @@ -87,6 +87,21 @@ inline bool isRelationOperator(Kind k){ } } +/** is k \in {LT, LEQ, EQ, GEQ, GT} */ +inline Kind negateRelationKind(Kind k){ + using namespace kind; + + switch(k){ + case LT: return GT; + case LEQ: return GEQ; + case EQUAL: return EQUAL; + case GEQ: return LEQ; + case GT: return LT; + + default: + Unreachable(); + } +} inline bool evaluateConstantPredicate(Kind k, const Rational& left, const Rational& right){ using namespace kind; diff --git a/src/theory/arith/kinds b/src/theory/arith/kinds index 99f7258da..07d48b1f6 100644 --- a/src/theory/arith/kinds +++ b/src/theory/arith/kinds @@ -12,6 +12,8 @@ operator MINUS 2 "arithmetic binary subtraction operator" operator UMINUS 1 "arithmetic unary negation" operator DIVISION 2 "arithmetic division" +operator IDENTITY 1 "identity function" + constant CONST_RATIONAL \ ::CVC4::Rational \ ::CVC4::RationalHashStrategy \ diff --git a/src/theory/arith/next_arith_rewriter.cpp b/src/theory/arith/next_arith_rewriter.cpp new file mode 100644 index 000000000..c14f806c9 --- /dev/null +++ b/src/theory/arith/next_arith_rewriter.cpp @@ -0,0 +1,326 @@ +/********************* */ +/*! \file arith_rewriter.cpp + ** \verbatim + ** Original author: taking + ** Major contributors: none + ** Minor contributors (to current version): mdeters + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + + +#include "theory/theory.h" +#include "theory/arith/normal_form.h" +#include "theory/arith/next_arith_rewriter.h" +#include "theory/arith/arith_utilities.h" + +#include +#include +#include + + +using namespace CVC4; +using namespace CVC4::theory; +using namespace CVC4::theory::arith; + +bool isVariable(TNode t){ + return t.getMetaKind() == kind::metakind::VARIABLE; +} + +RewriteResponse NextArithRewriter::rewriteConstant(TNode t){ + Assert(t.getMetaKind() == kind::metakind::CONSTANT); + Node val = coerceToRationalNode(t); + + return RewriteComplete(val); +} + +RewriteResponse NextArithRewriter::rewriteVariable(TNode t){ + Assert(isVariable(t)); + + return RewriteComplete(t); +} + +RewriteResponse NextArithRewriter::rewriteMinus(TNode t, bool pre){ + Assert(t.getKind()== kind::MINUS); + + if(t[0] == t[1]) return RewriteComplete(d_constants->d_ZERO_NODE); + + Node noMinus = makeSubtractionNode(t[0],t[1]); + if(pre){ + return RewriteComplete(noMinus); + }else{ + return FullRewriteNeeded(noMinus); + } +} + +RewriteResponse NextArithRewriter::rewriteUMinus(TNode t, bool pre){ + Assert(t.getKind()== kind::UMINUS); + + Node noUminus = makeUnaryMinusNode(t[0]); + if(pre) + return RewriteComplete(noUminus); + else + return RewriteAgain(noUminus); +} + +RewriteResponse NextArithRewriter::preRewriteTerm(TNode t){ + if(t.getMetaKind() == kind::metakind::CONSTANT){ + return rewriteConstant(t); + }else if(isVariable(t)){ + return rewriteVariable(t); + }else if(t.getKind() == kind::MINUS){ + return rewriteMinus(t, true); + }else if(t.getKind() == kind::UMINUS){ + return rewriteUMinus(t, true); + }else if(t.getKind() == kind::DIVISION){ + if(t[0].getKind()== kind::CONST_RATIONAL){ + return rewriteDivByConstant(t, true); + }else{ + return RewriteComplete(t); + } + }else if(t.getKind() == kind::PLUS){ + return preRewritePlus(t); + }else if(t.getKind() == kind::MULT){ + return preRewriteMult(t); + }else{ + Unreachable(); + } +} +RewriteResponse NextArithRewriter::postRewriteTerm(TNode t){ + if(t.getMetaKind() == kind::metakind::CONSTANT){ + return rewriteConstant(t); + }else if(isVariable(t)){ + return rewriteVariable(t); + }else if(t.getKind() == kind::MINUS){ + return rewriteMinus(t, false); + }else if(t.getKind() == kind::UMINUS){ + return rewriteUMinus(t, false); + }else if(t.getKind() == kind::DIVISION){ + return rewriteDivByConstant(t, false); + }else if(t.getKind() == kind::PLUS){ + return postRewritePlus(t); + }else if(t.getKind() == kind::MULT){ + return postRewriteMult(t); + }else{ + Unreachable(); + } +} + +RewriteResponse NextArithRewriter::preRewriteMult(TNode t){ + Assert(t.getKind()== kind::MULT); + + // Rewrite multiplications with a 0 argument and to 0 + Integer intZero; + + for(TNode::iterator i = t.begin(); i != t.end(); ++i) { + if((*i).getKind() == kind::CONST_RATIONAL) { + if((*i).getConst() == d_constants->d_ZERO) { + return RewriteComplete(d_constants->d_ZERO_NODE); + } + } else if((*i).getKind() == kind::CONST_INTEGER) { + if((*i).getConst() == intZero) { + if(t.getType().isInteger()) { + return RewriteComplete(NodeManager::currentNM()->mkConst(intZero)); + } else { + return RewriteComplete(d_constants->d_ZERO_NODE); + } + } + } + } + return RewriteComplete(t); +} +RewriteResponse NextArithRewriter::preRewritePlus(TNode t){ + Assert(t.getKind()== kind::PLUS); + + return RewriteComplete(t); +} + +RewriteResponse NextArithRewriter::postRewritePlus(TNode t){ + Assert(t.getKind()== kind::PLUS); + + Polynomial res = Polynomial::mkZero(); + + for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){ + Node curr = *i; + Polynomial currPoly = Polynomial::parsePolynomial(curr); + + res = res + currPoly; + } + + return RewriteComplete(res.getNode()); +} + +RewriteResponse NextArithRewriter::postRewriteMult(TNode t){ + Assert(t.getKind()== kind::MULT); + + Polynomial res = Polynomial::mkOne(); + + for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){ + Node curr = *i; + Polynomial currPoly = Polynomial::parsePolynomial(curr); + + res = res * currPoly; + } + + return RewriteComplete(res.getNode()); +} + +RewriteResponse NextArithRewriter::postRewriteAtomConstantRHS(TNode t){ + TNode left = t[0]; + TNode right = t[1]; + + + Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right)); + + if(cmp.isBoolean()){ + return RewriteComplete(cmp.getNode()); + } + + if(cmp.getLeft().containsConstant()){ + Monomial constantHead = cmp.getLeft().getHead(); + Assert(constantHead.isConstant()); + + Constant constant = constantHead.getConstant(); + + Constant negativeConstantHead = -constant; + + cmp = cmp.addConstant(negativeConstantHead); + } + Assert(!cmp.getLeft().containsConstant()); + + if(!cmp.getLeft().getHead().coefficientIsOne()){ + Monomial constantHead = cmp.getLeft().getHead(); + Assert(!constantHead.isConstant()); + Constant constant = constantHead.getConstant(); + + Constant inverse = Constant::mkConstant(constant.getValue().inverse()); + + cmp = cmp.multiplyConstant(inverse); + } + Assert(cmp.getLeft().getHead().coefficientIsOne()); + + Assert(cmp.isBoolean() || cmp.isNormalForm()); + return RewriteComplete(cmp.getNode()); +} + +RewriteResponse NextArithRewriter::postRewriteAtom(TNode atom){ + // left |><| right + TNode left = atom[0]; + TNode right = atom[1]; + + if(right.getMetaKind() == kind::metakind::CONSTANT){ + return postRewriteAtomConstantRHS(atom); + }else{ + //Transform this to: (left - right) |><| 0 + Node diff = makeSubtractionNode(left, right); + Node reduction = NodeManager::currentNM()->mkNode(atom.getKind(), diff, d_constants->d_ZERO_NODE); + return FullRewriteNeeded(reduction); + } +} + +RewriteResponse NextArithRewriter::preRewriteAtom(TNode atom){ + Assert(isAtom(atom)); + NodeManager* currNM = NodeManager::currentNM(); + + if(atom.getKind() == kind::EQUAL) { + if(atom[0] == atom[1]) { + return RewriteComplete(currNM->mkConst(true)); + } + } + + Node reduction = atom; + + if(atom[1].getMetaKind() != kind::metakind::CONSTANT){ + // left |><| right + TNode left = atom[0]; + TNode right = atom[1]; + + //Transform this to: (left - right) |><| 0 + Node diff = makeSubtractionNode(left, right); + reduction = currNM->mkNode(atom.getKind(), diff, d_constants->d_ZERO_NODE); + } + + if(reduction.getKind() == kind::GT){ + Node leq = currNM->mkNode(kind::LEQ, reduction[0], reduction[1]); + reduction = currNM->mkNode(kind::NOT, leq); + }else if(reduction.getKind() == kind::LT){ + Node geq = currNM->mkNode(kind::GEQ, reduction[0], reduction[1]); + reduction = currNM->mkNode(kind::NOT, geq); + } + + return RewriteComplete(reduction); +} + +RewriteResponse NextArithRewriter::postRewrite(TNode t){ + if(isTerm(t)){ + RewriteResponse response = postRewriteTerm(t); + if(Debug.isOn("arith::rewriter") && response.isDone()) { + Polynomial::parsePolynomial(response.getNode()); + } + return response; + }else if(isAtom(t)){ + RewriteResponse response = postRewriteAtom(t); + if(Debug.isOn("arith::rewriter") && response.isDone()) { + Comparison::parseNormalForm(response.getNode()); + } + return response; + }else{ + Unreachable(); + return RewriteComplete(Node::null()); + } +} + +RewriteResponse NextArithRewriter::preRewrite(TNode t){ + if(isTerm(t)){ + return preRewriteTerm(t); + }else if(isAtom(t)){ + return preRewriteAtom(t); + }else{ + Unreachable(); + return RewriteComplete(Node::null()); + } +} + +Node NextArithRewriter::makeUnaryMinusNode(TNode n){ + return NodeManager::currentNM()->mkNode(kind::MULT,d_constants->d_NEGATIVE_ONE_NODE,n); +} + +Node NextArithRewriter::makeSubtractionNode(TNode l, TNode r){ + Node negR = makeUnaryMinusNode(r); + Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR); + + return diff; +} + +RewriteResponse NextArithRewriter::rewriteDivByConstant(TNode t, bool pre){ + Assert(t.getKind()== kind::DIVISION); + + Node left = t[0]; + Node right = t[1]; + Assert(right.getKind()== kind::CONST_RATIONAL); + + + const Rational& den = right.getConst(); + + Assert(den != d_constants->d_ZERO); + + Rational div = den.inverse(); + + Node result = mkRationalNode(div); + + Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result); + if(pre){ + return RewriteComplete(mult); + }else{ + return RewriteAgain(mult); + } +} diff --git a/src/theory/arith/next_arith_rewriter.h b/src/theory/arith/next_arith_rewriter.h new file mode 100644 index 000000000..7f1ec0fbd --- /dev/null +++ b/src/theory/arith/next_arith_rewriter.h @@ -0,0 +1,74 @@ +/********************* */ +/*! \file arith_rewriter.h + ** \verbatim + ** Original author: taking + ** Major contributors: mdeters + ** Minor contributors (to current version): none + ** This file is part of the CVC4 prototype. + ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys) + ** Courant Institute of Mathematical Sciences + ** New York University + ** See the file COPYING in the top-level source directory for licensing + ** information.\endverbatim + ** + ** \brief [[ Add one-line brief description here ]] + ** + ** [[ Add lengthier description here ]] + ** \todo document this file + **/ + +#include "theory/arith/arith_constants.h" +#include "theory/theory.h" +#include "theory/arith/normal_form.h" + +#ifndef __CVC4__THEORY__ARITH__REWRITER_NEXT_H +#define __CVC4__THEORY__ARITH__REWRITER_NEXT_H + +namespace CVC4 { +namespace theory { +namespace arith { + +class NextArithRewriter{ +private: + ArithConstants* d_constants; + + Node makeSubtractionNode(TNode l, TNode r); + Node makeUnaryMinusNode(TNode n); + + RewriteResponse preRewriteTerm(TNode t); + RewriteResponse postRewriteTerm(TNode t); + + RewriteResponse rewriteVariable(TNode t); + RewriteResponse rewriteConstant(TNode t); + RewriteResponse rewriteMinus(TNode t, bool pre); + RewriteResponse rewriteUMinus(TNode t, bool pre); + RewriteResponse rewriteDivByConstant(TNode t, bool pre); + + RewriteResponse preRewritePlus(TNode t); + RewriteResponse postRewritePlus(TNode t); + + RewriteResponse preRewriteMult(TNode t); + RewriteResponse postRewriteMult(TNode t); + + + RewriteResponse preRewriteAtom(TNode t); + RewriteResponse postRewriteAtom(TNode t); + RewriteResponse postRewriteAtomConstantRHS(TNode t); + +public: + NextArithRewriter(ArithConstants* ac) : d_constants(ac) {} + + RewriteResponse preRewrite(TNode n); + RewriteResponse postRewrite(TNode n); + +private: + bool isAtom(TNode n) const { return isRelationOperator(n.getKind()); } + bool isTerm(TNode n) const { return !isAtom(n); } +}; + + +}; /* namesapce arith */ +}; /* namespace theory */ +}; /* namespace CVC4 */ + +#endif /* __CVC4__THEORY__ARITH__REWRITER_NEXT_H */ diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp new file mode 100644 index 000000000..18e31848b --- /dev/null +++ b/src/theory/arith/normal_form.cpp @@ -0,0 +1,250 @@ + +#include "theory/arith/normal_form.h" +#include + +using namespace std; +using namespace CVC4; +using namespace CVC4::theory; +using namespace CVC4::theory::arith; + +bool VarList::isSorted(iterator start, iterator end){ + return __gnu_cxx::is_sorted(start, end); +} + +bool VarList::isMember(Node n){ + if(n.getNumChildren() == 0){ + return Variable::isMember(n); + }else if(n.getKind() == kind::MULT){ + Node::iterator curr = n.begin(), end = n.end(); + Node prev = *curr; + if(!Variable::isMember(prev)) return false; + + while( (++curr) != end){ + if(!Variable::isMember(*curr)) return false; + if(!(prev <= *curr)) return false; + prev = *curr; + } + return true; + }else{ + return false; + } +} +int VarList::cmp(const VarList& vl) const{ + int dif = this->size() - vl.size(); + if (dif == 0){ + return this->getNode().getId() - vl.getNode().getId(); + }else if(dif < 0){ + return -1; + }else{ + return 1; + } +} + +VarList VarList::parseVarList(Node n){ + if(n.getNumChildren() == 0){ + return VarList(Variable(n)); + }else{ + Assert(n.getKind() == kind::MULT); + for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i){ + Assert(Variable::isMember(*i)); + } + return VarList(n); + } +} + +VarList VarList::operator*(const VarList& vl) const{ + if(this->empty()){ + return vl; + }else if(vl.empty()){ + return *this; + }else{ + vector result; + vector thisAsVec = this->asList(); + vector vlAsVec = vl.asList(); + back_insert_iterator > bii(result); + + merge(thisAsVec.begin(), thisAsVec.end(), vlAsVec.begin(), vlAsVec.end(), bii); + + return VarList::mkVarList(result); + } +} + +std::vector VarList::asList() const { + vector res; + for(iterator i = begin(), e = end(); i != e; ++i){ + res.push_back(*i); + } + return res; +} + +Monomial Monomial::mkMonomial(const Constant& c, const VarList& vl){ + if(c.isZero() || vl.empty() ){ + return Monomial(c); + }else if(c.isOne()){ + return Monomial(vl); + }else{ + return Monomial(c, vl); + } +} +Monomial Monomial::parseMonomial(Node n){ + if(n.getKind() == kind::CONST_RATIONAL){ + return Monomial(Constant(n)); + }else if(multStructured(n)){ + return Monomial::mkMonomial(Constant(n[0]),VarList::parseVarList(n[1])); + }else{ + return Monomial(VarList::parseVarList(n)); + } +} + +Monomial Monomial::operator*(const Monomial& mono) const { + Constant newConstant = this->getConstant() * mono.getConstant(); + VarList newVL = this->getVarList() * mono.getVarList(); + + return Monomial::mkMonomial(newConstant, newVL); +} + +vector Monomial::sumLikeTerms(const vector & monos){ + Assert(isSorted(monos)); + + Debug("blah") << "start sumLikeTerms" << std::endl; + printList(monos); + vector outMonomials; + typedef vector::const_iterator iterator; + for(iterator rangeIter = monos.begin(), end=monos.end(); rangeIter != end;){ + Rational constant = (*rangeIter).getConstant().getValue(); + VarList varList = (*rangeIter).getVarList(); + ++rangeIter; + while(rangeIter != end && varList == (*rangeIter).getVarList()){ + constant += (*rangeIter).getConstant().getValue(); + ++rangeIter; + } + if(constant != 0){ + Constant asConstant = Constant::mkConstant(constant); + Monomial nonZero = Monomial::mkMonomial(asConstant, varList); + outMonomials.push_back(nonZero); + } + } + Debug("blah") << "outmonomials" << std::endl; + printList(monos); + Debug("blah") << "end sumLikeTerms" << std::endl; + + Assert(isStrictlySorted(outMonomials)); + return outMonomials; +} + +void Monomial::printList(const std::vector& monos){ + typedef std::vector::const_iterator iterator; + for(iterator i = monos.begin(), end = monos.end(); i != end; ++i){ + Debug("blah") << ((*i).getNode()) << std::endl; + } +} + +Polynomial Polynomial::operator+(const Polynomial& vl) const{ + this->printList(); + vl.printList(); + + std::vector sortedMonos; + std::back_insert_iterator > bii(sortedMonos); + std::merge(begin(), end(), vl.begin(), vl.end(), bii); + + std::vector combined = Monomial::sumLikeTerms(sortedMonos); + + Polynomial result = mkPolynomial(combined); + result.printList(); + return result; +} + +Polynomial Polynomial::operator*(const Monomial& mono) const{ + if(mono.isZero()){ + return Polynomial(mono); //Don't multiply by zero + }else{ + std::vector newMonos; + for(iterator i = this->begin(), end = this->end(); i != end; ++i){ + newMonos.push_back(mono * (*i)); + } + return Polynomial::mkPolynomial(newMonos); + } +} + +Polynomial Polynomial::operator*(const Polynomial& poly) const{ + + Polynomial res = Polynomial::mkZero(); + for(iterator i = this->begin(), end = this->end(); i != end; ++i){ + Monomial curr = *i; + Polynomial prod = poly * curr; + Polynomial sum = res + prod; + res = sum; + } + return res; +} + + +Node Comparison::toNode(Kind k, const Polynomial& l, const Constant& r){ + Assert(!l.isConstant()); + Assert(isRelationOperator(k)); + switch(k){ + case kind::GEQ: + case kind::EQUAL: + case kind::LEQ: + return NodeManager::currentNM()->mkNode(k, l.getNode(),r.getNode()); + case kind::LT: + return NodeManager::currentNM()->mkNode(kind::NOT, toNode(kind::GEQ,l,r)); + case kind::GT: + return NodeManager::currentNM()->mkNode(kind::NOT, toNode(kind::LEQ,l,r)); + default: + Unreachable(); + } +} + +Comparison Comparison::parseNormalForm(TNode n){ + if(n.getKind() == kind::CONST_BOOLEAN){ + return Comparison(n.getConst()); + }else{ + bool negated = n.getKind() == kind::NOT; + Node relation = negated ? n[0] : n; + Assert( !negated || + relation.getKind() == kind::LEQ || + relation.getKind() == kind::GEQ); + + Polynomial left = Polynomial::parsePolynomial(relation[0]); + Constant right(relation[1]); + + Kind newOperator = relation.getKind(); + if(negated){ + if(newOperator == kind::LEQ){ + newOperator = kind::GT; + }else{ + newOperator = kind::LT; + } + } + return Comparison(n, newOperator, left, right); + } +} + +Comparison Comparison::mkComparison(Kind k, const Polynomial& left, const Constant& right){ + Assert(isRelationOperator(k)); + if(left.isConstant()){ + const Rational& rConst = left.getNode().getConst(); + const Rational& lConst = right.getNode().getConst(); + bool res = evaluateConstantPredicate(k, lConst, rConst); + return Comparison(res); + }else{ + return Comparison(toNode(k, left, right), k, left, right); + } +} + +Comparison Comparison::addConstant(const Constant& constant) const{ + Assert(!isBoolean()); + Monomial mono(constant); + Polynomial constAsPoly( mono ); + Polynomial newLeft = getLeft() + constAsPoly; + Constant newRight = getRight() + constant; + return mkComparison(oper, newLeft, newRight); +} + +Comparison Comparison::multiplyConstant(const Constant& constant) const{ + Assert(!isBoolean()); + Kind newOper = (constant.getValue() < 0) ? negateRelationKind(oper) : oper; + + return mkComparison(newOper, left*Monomial(constant), right*constant); +} diff --git a/src/theory/arith/normal_form.h b/src/theory/arith/normal_form.h new file mode 100644 index 000000000..1f7bc6be3 --- /dev/null +++ b/src/theory/arith/normal_form.h @@ -0,0 +1,613 @@ + +#include "expr/node.h" +#include "util/rational.h" +#include "theory/arith/arith_constants.h" +#include "theory/arith/arith_utilities.h" + +#include +#include +#include + +#ifndef __CVC4__THEORY__ARITH__NORMAL_FORM_H +#define __CVC4__THEORY__ARITH__NORMAL_FORM_H + +namespace CVC4 { +namespace theory { +namespace arith { + +/***********************************************/ +/***************** Normal Form *****************/ +/***********************************************/ +/***********************************************/ + +/** + * Section 1: Languages + * The normal form for arithmetic nodes is defined by the language + * accepted by the following BNFs with some guard conditions. + * (The guard conditions are in Section 3 for completeness.) + * + * variable := n + * where + * n.getMetaKind() == metakind::VARIABLE + + * constant := n + * where + * n.getKind() == kind::CONST_RATIONAL + + * var_list := variable | (* [variable]) + * where + * len [variable] >= 2 + * isSorted varOrder [variable] + + * monomial := constant | var_list | (* constant' var_list') + * where + * constant' \not\in {0,1} + + * polynomial := monomial | (+ [monomial]) + * where + * len [monomial] >= 2 + * isStrictlySorted monoOrder [monomial] + + * restricted_cmp := (|><| polynomial constant) + * where + * |><| is GEQ, EQ, or EQ + * not (exists constantMonomial (monomialList polynomial)) + * monomialCoefficient (head (monomialList polynomial)) == 1 + + * comparison := TRUE | FALSE | restricted_cmp | (not restricted_cmp) + + * Normal Form for terms := polynomial + * Normal Form for atoms := comparison + */ + +/** + * Section 2: Helper Classes + * The langauges accepted by each of these defintions + * roughly corresponds to one of the following helper classes: + * Variable + * Constant + * VarList + * Monomial + * Polynomial + * Comparison + * + * Each of the classes obeys the following contracts/design decisions: + * -Calling isMember(Node node) on a node returns true iff that node is a + * a member of the language. Note: isMember is O(n). + * -Calling isNormalForm() on a helper class object returns true iff that + * helper class currently represents a normal form object. + * -If isNormalForm() is false, then this object must have been made + * using a mk*() factory function. + * -If isNormalForm() is true, calling getNode() on all of these classes + * returns a node that would be accepted by the corresponding language. + * And if isNormalForm() is false, returns Node::null(). + * -Each of the classes is immutable. + * -Public facing constuctors have a 1-to-1 correspondence with one of + * production rules in the above grammar. + * -Public facing constuctors are required to fail in debug mode when the + * guards of the production rule are not strictly met. + * For example: Monomial(Constant(1),VarList(Variable(x))) must fail. + * -When a class has a Class parseClass(Node node) function, + * if isMember(node) is true, the function is required to return an instance + * of the helper class, instance, s.t. instance.getNode() == node. + * And if isMember(node) is false, this throws an assertion failure in debug + * mode and has undefined behaviour if not in debug mode. + * -Only public facing constructors, parseClass(node), and mk*() functions are + * considered privledged functions for the helper class. + * -Only privledged functions may use private constructors, and access + * private data members. + * -All non-privledges functions are considered utility functions and + * must use a privledged function in order to create an instance of the class. + */ + +/** + * Section 3: Guard Conditions Misc. + * + * + * var_list_len vl = + * match vl with + * variable -> 1 + * | (* [variable]) -> len [variable] + * + * order res = + * match res with + * Empty -> (0,Node::null()) + * | NonEmpty(vl) -> (var_list_len vl, vl) + * + * var_listOrder a b = tuple_cmp (order a) (order b) + * + * monomialVarList monomial = + * match monomial with + * constant -> Empty + * | var_list -> NonEmpty(var_list) + * | (* constant' var_list') -> NonEmpty(var_list') + * + * monoOrder m0 m1 = var_listOrder (monomialVarList m0) (monomialVarList m1) + * + * constantMonomial monomial = + * match monomial with + * constant -> true + * | var_list -> false + * | (* constant' var_list') -> false + * + * monomialCoefficient monomial = + * match monomial with + * constant -> constant + * | var_list -> Constant(1) + * | (* constant' var_list') -> constant' + * + * monomialList polynomial = + * match polynomial with + * monomial -> monomial::[] + * | (+ [monomial]) -> [monomial] + */ + +/** + * A NodeWrapper is a class that is a thinly veiled container of a Node object. + */ +class NodeWrapper { +private: + Node node; +public: + NodeWrapper(Node n) : node(n) {} + const Node& getNode() const { return node; } +}; + +class Variable : public NodeWrapper { +public: + Variable(Node n) : NodeWrapper(n) { + Assert(isMember(getNode())); + } + + static bool isMember(Node n) { + return n.getMetaKind() == kind::metakind::VARIABLE; + } + + bool isNormalForm() { return isMember(getNode()); } + + bool operator<(const Variable& v) const{ return getNode() < v.getNode();} + bool operator==(const Variable& v) const{ return getNode() == v.getNode();} + +}; + +class Constant : public NodeWrapper { +public: + Constant(Node n) : NodeWrapper(n) { + Assert(isMember(getNode())); + } + + static bool isMember(Node n) { + return n.getKind() == kind::CONST_RATIONAL; + } + + bool isNormalForm() { return isMember(getNode()); } + + static Constant mkConstant(Node n) { + return Constant(coerceToRationalNode(n)); + } + + static Constant mkConstant(const Rational& rat){ + return Constant(mkRationalNode(rat)); + } + + const Rational& getValue() const { + return getNode().getConst(); + } + + bool isZero() const{ return getValue() == 0; } + bool isOne() const{ return getValue() == 1; } + + Constant operator*(const Constant& other) const{ + return mkConstant(getValue() * other.getValue()); + } + Constant operator+(const Constant& other) const{ + return mkConstant(getValue() + other.getValue()); + } + Constant operator-() const{ + return mkConstant(-getValue()); + } +}; + +template +inline Node makeNode(Kind k, GetNodeIterator start, GetNodeIterator end){ + NodeBuilder<> nb(k); + + while(start != end){ + nb << (*start).getNode(); + ++start; + } + return Node(nb); +} + +/** + * A VarList is a sorted list of variables representing a product. + * If the VarList is empty, it represents an empty product or 1. + * If the VarList has size 1, it represents a single variable. + * + * A non-sorted VarList can never be successfully made in debug mode. + */ +class VarList { +private: + Node backingNode; + + static Node multList(const std::vector& list){ + Assert(list.size() >= 2); + + return makeNode(kind::MULT, list.begin(), list.end()); + } + static Node makeTuple(Node n){ + return NodeManager::currentNM()->mkNode(kind::IDENTITY, n); + } + + VarList() : backingNode(Node::null()){} + + VarList(Node n){ + backingNode = (Variable::isMember(n)) ? makeTuple(n) : n; + + Assert(isSorted(begin(), end())); + } + +public: + class iterator { + private: + Node::iterator d_iter; + public: + explicit iterator(Node::iterator i) : d_iter(i) {} + + inline Variable operator*(){ + return Variable(*d_iter); + } + + bool operator==(const iterator& i){ + return d_iter == i.d_iter; + } + + bool operator!=(const iterator& i){ + return d_iter != i.d_iter; + } + + iterator operator++() { + ++d_iter; + return *this; + } + + iterator operator++(int) { + return iterator(d_iter++); + } + }; + + Node getNode() const{ + if(singleton()){ + return backingNode[0]; + }else{ + return backingNode; + } + } + + iterator begin() const{ + return iterator(backingNode.begin()); + } + iterator end() const{ + return iterator(backingNode.end()); + } + + VarList(Variable v) : backingNode(makeTuple(v.getNode())){ + Assert(isSorted(begin(), end())); + } + VarList(const std::vector& l) : backingNode(multList(l)){ + Assert(l.size() >= 2); + Assert(isSorted(begin(), end())); + } + + static bool isMember(Node n); + + bool isNormalForm() const{ + return !empty(); + } + + static VarList mkEmptyVarList(){ + return VarList(); + } + + + /** There are no restrictions on the size of l */ + static VarList mkVarList(const std::vector& l){ + if(l.size() == 0){ + return mkEmptyVarList(); + }else if(l.size() == 1){ + return VarList((*l.begin()).getNode()); + }else{ + return VarList(l); + } + } + + int size() const{ return backingNode.getNumChildren(); } + bool empty() const { return size() == 0; } + bool singleton() const { return backingNode.getKind() == kind::IDENTITY; } + + static VarList parseVarList(Node n); + + VarList operator*(const VarList& vl) const; + + int cmp(const VarList& vl) const; + + bool operator<(const VarList& vl) const{ return cmp(vl) < 0; } + + bool operator==(const VarList& vl) const{ return cmp(vl) == 0; } + + std::vector asList() const; + +private: + bool isSorted(iterator start, iterator end); +}; + +class Monomial : public NodeWrapper { +private: + Constant constant; + VarList varList; + Monomial(Node n, const Constant& c, const VarList& vl): + NodeWrapper(n), constant(c), varList(vl) + { + Assert(!c.isZero() || vl.empty() ); + Assert( c.isZero() || !vl.empty() ); + + Assert(!c.isOne() || !multStructured(n)); + } + + static Node makeMultNode(const Constant& c, const VarList& vl){ + Assert(!c.isZero()); + Assert(!c.isOne()); + Assert(!vl.empty()); + return NodeManager::currentNM()->mkNode(kind::MULT, c.getNode(), vl.getNode()); + } + + static bool multStructured(Node n){ + return n.getKind() == kind::MULT && + n[0].getKind() == kind::CONST_RATIONAL && + n.getNumChildren() == 2; + } + +public: + + Monomial(const Constant& c): + NodeWrapper(c.getNode()), constant(c), varList(VarList::mkEmptyVarList()) + { } + + Monomial(const VarList& vl): + NodeWrapper(vl.getNode()), constant(Constant::mkConstant(1)), varList(vl) + { + Assert( !varList.empty() ); + } + + Monomial(const Constant& c, const VarList& vl): + NodeWrapper(makeMultNode(c,vl)), constant(c), varList(vl) + { + Assert( !c.isZero() ); + Assert( !c.isOne() ); + Assert( !varList.empty() ); + + Assert(multStructured(getNode())); + } + + /** Makes a monomial with no restrictions on c and vl. */ + static Monomial mkMonomial(const Constant& c, const VarList& vl); + + + static Monomial parseMonomial(Node n); + + static Monomial mkZero(){ + return Monomial(Constant::mkConstant(0)); + } + static Monomial mkOne(){ + return Monomial(Constant::mkConstant(1)); + } + const Constant& getConstant() const{ return constant; } + const VarList& getVarList() const{ return varList; } + + bool isConstant() const{ + return varList.empty(); + } + + bool isZero() const{ + return constant.isZero(); + } + + bool coefficientIsOne() const { + return constant.isOne(); + } + + Monomial operator*(const Monomial& mono) const; + + + int cmp(const Monomial& mono) const{ + return getVarList().cmp(mono.getVarList()); + } + + bool operator<(const Monomial& vl) const{ + return cmp(vl) < 0; + } + + bool operator==(const Monomial& vl) const{ + return cmp(vl) == 0; + } + + static bool isSorted(const std::vector& m){ + return __gnu_cxx::is_sorted(m.begin(), m.end()); + } + + static bool isStrictlySorted(const std::vector& m){ + return isSorted(m) && std::adjacent_find(m.begin(),m.end()) == m.end(); + } + + /** + * Given a sorted list of monomials, this function transforms this + * into a strictly sorted list of monomials that does not contain zero. + */ + static std::vector sumLikeTerms(const std::vector& monos); + + static void printList(const std::vector& monos); +}; + +class Polynomial : public NodeWrapper { +private: + std::vector monos; + + Polynomial(Node n, const std::vector& m): + NodeWrapper(n), monos(m) + { + Assert( !monos.empty() ); + Assert( Monomial::isStrictlySorted(monos) ); + } + + static Node makePlusNode(const std::vector& m){ + Assert(m.size() >= 2); + + return makeNode(kind::PLUS, m.begin(), m.end()); + } + +public: + typedef std::vector::const_iterator iterator; + + iterator begin() const{ return monos.begin(); } + iterator end() const{ return monos.end(); } + + Polynomial(const Monomial& m): + NodeWrapper(m.getNode()), monos() + { + monos.push_back(m); + } + Polynomial(const std::vector& m): + NodeWrapper(makePlusNode(m)), monos(m) + { + Assert( monos.size() >= 2); + Assert( Monomial::isStrictlySorted(monos) ); + } + + + static Polynomial mkPolynomial(const std::vector& m){ + if(m.size() == 0){ + return Polynomial(Monomial::mkZero()); + }else if(m.size() == 1){ + return Polynomial((*m.begin())); + }else{ + return Polynomial(m); + } + } + + static Polynomial parsePolynomial(Node n){ + std::vector monos; + if(n.getKind() == kind::PLUS){ + for(Node::iterator i=n.begin(), end=n.end(); i != end; ++i){ + monos.push_back(Monomial::parseMonomial(*i)); + } + }else{ + monos.push_back(Monomial::parseMonomial(n)); + } + return Polynomial(n,monos); + } + + static Polynomial mkZero(){ + return Polynomial(Monomial::mkZero()); + } + static Polynomial mkOne(){ + return Polynomial(Monomial::mkOne()); + } + bool isZero() const{ + return (monos.size() == 1) && (getHead().isZero()); + } + + bool isConstant() const{ + return (monos.size() == 1) && (getHead().isConstant()); + } + + bool containsConstant() const{ + return getHead().isConstant(); + } + + Monomial getHead() const{ + return *(begin()); + } + + Polynomial getTail() const{ + Assert(monos.size() >= 1); + + iterator start = begin()+1; + std::vector subrange(start, end()); + return mkPolynomial(subrange); + } + + void printList() const{ + Debug("blah") << "start list" << std::endl; + Monomial::printList(monos); + Debug("blah") << "end list" << std::endl; + } + + Polynomial operator+(const Polynomial& vl) const; + + Polynomial operator*(const Monomial& mono) const; + + Polynomial operator*(const Polynomial& poly) const; + +}; + +class Comparison : public NodeWrapper { +private: + Kind oper; + Polynomial left; + Constant right; + + static Node toNode(Kind k, const Polynomial& l, const Constant& r); + + Comparison(TNode n, Kind k, const Polynomial& l, const Constant& r): + NodeWrapper(n), oper(k), left(l), right(r) + { } +public: + Comparison(bool val) : + NodeWrapper(NodeManager::currentNM()->mkConst(val)), + oper(kind::CONST_BOOLEAN), + left(Polynomial::mkZero()), + right(Constant::mkConstant(0)) + { } + + Comparison(Kind k, const Polynomial& l, const Constant& r): + NodeWrapper(toNode(k, l, r)), oper(k), left(l), right(r) + { + Assert(isRelationOperator(oper)); + Assert(!left.containsConstant()); + Assert(left.getHead().getConstant().isOne()); + } + + static Comparison mkComparison(Kind k, const Polynomial& left, const Constant& right); + + bool isBoolean() const{ + return (oper == kind::CONST_BOOLEAN); + } + + bool isNormalForm() const{ + if(isBoolean()){ + return true; + }else if(left.containsConstant()){ + return false; + }else if(left.getHead().getConstant().isOne()){ + return true; + }else{ + return false; + } + } + + const Polynomial& getLeft() const { return left; } + const Constant& getRight() const { return right; } + + Comparison addConstant(const Constant& constant) const; + Comparison multiplyConstant(const Constant& constant) const; + + static Comparison parseNormalForm(TNode n); +}; + + + +}; /* namesapce arith */ +}; /* namespace theory */ +}; /* namespace CVC4 */ + +#endif /* __CVC4__THEORY__ARITH__NORMAL_FORM_H */ diff --git a/src/theory/arith/tableau.h b/src/theory/arith/tableau.h index 12d93d9fe..603eb5278 100644 --- a/src/theory/arith/tableau.h +++ b/src/theory/arith/tableau.h @@ -23,7 +23,7 @@ #include "theory/arith/basic.h" #include "theory/arith/arith_activity.h" - +#include "theory/arith/normal_form.h" #include #include @@ -52,21 +52,21 @@ public: * Construct a row equal to: * basic = \sum_{x_i} c_i * x_i */ - Row(TNode basic, TNode sum): + Row(TNode basic, const Polynomial& sum): d_x_i(basic), d_coeffs(){ Assert(d_x_i.getMetaKind() == kind::metakind::VARIABLE); - Assert(sum.getKind() == kind::PLUS); - - for(TNode::iterator iter=sum.begin(); iter != sum.end(); ++iter){ - TNode pair = *iter; - Assert(pair.getKind() == kind::MULT); - Assert(pair.getNumChildren() == 2); - TNode coeff = pair[0]; - TNode var_i = pair[1]; + + for(Polynomial::iterator iter=sum.begin(), end = sum.end(); iter != end; ++iter){ + const Monomial& mono = *iter; + + Assert(!mono.isConstant()); + + TNode coeff = mono.getConstant().getNode(); + TNode var_i = mono.getVarList().getNode(); + Assert(coeff.getKind() == kind::CONST_RATIONAL); - Assert(var_i.getKind() == kind::VARIABLE); Assert(!has(var_i)); d_coeffs[var_i] = coeff.getConst(); @@ -192,14 +192,13 @@ private: public: void addRow(TNode eq){ - Assert(eq.getKind() == kind::EQUAL); - Assert(eq.getNumChildren() == 2); - TNode var = eq[0]; - TNode sum = eq[1]; + TNode sumNode = eq[1]; Assert(var.getAttribute(IsBasic())); + Polynomial sum = Polynomial::parsePolynomial(sumNode); + //The new basic variable cannot already be a basic variable Assert(!isActiveBasicVariable(var)); d_activeBasicVars.insert(var); @@ -208,13 +207,11 @@ public: //A variable in the row may have been made non-basic already. //If this is the case we fake pivoting this variable - for(TNode::iterator sumIter = sum.begin(); sumIter!=sum.end(); ++sumIter){ - TNode child = *sumIter; - Assert(child.getKind() == kind::MULT); - Assert(child.getNumChildren() == 2); - Assert(child[0].getKind() == kind::CONST_RATIONAL); - TNode c = child[1]; - Assert(var.getMetaKind() == kind::metakind::VARIABLE); + for(Polynomial::iterator sumIter = sum.begin(); sumIter!= sum.end(); ++sumIter){ + const Monomial& child = *sumIter; + + Assert(!child.isConstant()); + TNode c = child.getVarList().getNode(); if(isActiveBasicVariable(c)){ Row* row_c = lookup(c); row_var->subsitute(*row_c); diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index 157c45160..8f17b01a9 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -33,10 +33,12 @@ #include "theory/arith/basic.h" #include "theory/arith/arith_activity.h" -#include "theory/arith/arith_rewriter.h" +#include "theory/arith/next_arith_rewriter.h" #include "theory/arith/arith_propagator.h" #include "theory/arith/theory_arith.h" +#include "theory/arith/normal_form.h" + #include #include @@ -55,7 +57,7 @@ TheoryArith::TheoryArith(int id, context::Context* c, OutputChannel& out) : d_constants(NodeManager::currentNM()), d_partialModel(c), d_diseq(c), - d_rewriter(&d_constants), + d_nextRewriter(&d_constants), d_propagator(c), d_statistics() { @@ -109,22 +111,9 @@ bool isBasicSum(TNode n){ bool isNormalAtom(TNode n){ - if(!(n.getKind() == LEQ|| n.getKind() == GEQ || n.getKind() == EQUAL)){ - return false; - } - TNode left = n[0]; - TNode right = n[1]; - if(right.getKind() != CONST_RATIONAL){ - return false; - } - if(left.getMetaKind() == metakind::VARIABLE){ - return true; - }else if(isBasicSum(left)){ - return true; - }else{ - return false; - } + Comparison parse = Comparison::parseNormalForm(n); + return parse.isNormalForm(); } @@ -213,7 +202,6 @@ void TheoryArith::preRegisterTerm(TNode n) { if(left.getKind() == PLUS){ //We may need to introduce a slack variable. Assert(left.getNumChildren() >= 2); - Assert(isBasicSum(left)); if(!left.hasAttribute(Slack())){ setupSlack(left); } @@ -229,11 +217,9 @@ void TheoryArith::setupSlack(TNode left){ left.setAttribute(Slack(), slack); makeBasic(slack); - Node slackEqLeft = NodeManager::currentNM()->mkNode(EQUAL,slack,left); - - Debug("slack") << "slack " << slackEqLeft << endl; + Node eq = NodeManager::currentNM()->mkNode(kind::EQUAL, slack, left); - d_tableau.addRow(slackEqLeft); + d_tableau.addRow(eq); setupVariable(slack); } @@ -316,56 +302,9 @@ DeltaRational TheoryArith::computeRowValueUsingSavedAssignment(TNode x){ } RewriteResponse TheoryArith::preRewrite(TNode n, bool topLevel) { - // ensure a hard link to the node we're returning - Node out; - - // Look for multiplications with a 0 argument and rewrite the whole - // thing as 0 - if(n.getKind() == MULT) { - Rational ratZero; - Integer intZero; - for(TNode::iterator i = n.begin(); i != n.end(); ++i) { - if((*i).getKind() == CONST_RATIONAL) { - if((*i).getConst() == ratZero) { - out = NodeManager::currentNM()->mkConst(ratZero); - break; - } - } else if((*i).getKind() == CONST_INTEGER) { - if((*i).getConst() == intZero) { - if(n.getType().isInteger()) { - out = NodeManager::currentNM()->mkConst(intZero); - break; - } else { - out = NodeManager::currentNM()->mkConst(ratZero); - break; - } - } - } - } - } else if(n.getKind() == EQUAL) { - if(n[0] == n[1]) { - out = NodeManager::currentNM()->mkConst(true); - } - } - - if(out.isNull()) { - // no preRewrite to perform - return RewriteComplete(Node(n)); - } else { - // out is always a constant, so doesn't need to be rewritten again - return RewriteComplete(out); - } + return d_nextRewriter.preRewrite(n); } -Node TheoryArith::rewrite(TNode n){ - Debug("arith") << "rewrite(" << n << ")" << endl; - - Node result = d_rewriter.rewrite(n); - Debug("arith-rewrite") << "rewrite(" << n << ") -> " << result << endl; - return result; -} - - void TheoryArith::registerTerm(TNode tn){ Debug("arith") << "registerTerm(" << tn << ")" << endl; } diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index 7367f5726..03be7a77b 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -28,7 +28,7 @@ #include "theory/arith/delta_rational.h" #include "theory/arith/tableau.h" -#include "theory/arith/arith_rewriter.h" +#include "theory/arith/next_arith_rewriter.h" #include "theory/arith/partial_model.h" #include "theory/arith/arith_propagator.h" @@ -94,7 +94,7 @@ private: /** * The rewriter module for arithmetic. */ - ArithRewriter d_rewriter; + NextArithRewriter d_nextRewriter; ArithUnatePropagator d_propagator; @@ -102,11 +102,6 @@ public: TheoryArith(int id, context::Context* c, OutputChannel& out); ~TheoryArith(); - /** - * Rewrites a node to a unique normal form given in normal_form_notes.txt - */ - Node rewrite(TNode n); - /** * Rewriting optimizations. */ @@ -116,7 +111,7 @@ public: * Plug in old rewrite to the new (pre,post)rewrite interface. */ RewriteResponse postRewrite(TNode n, bool topLevel) { - return RewriteComplete(topLevel ? rewrite(n) : Node(n)); + return d_nextRewriter.postRewrite(n); } /** diff --git a/test/unit/theory/theory_arith_white.h b/test/unit/theory/theory_arith_white.h index ea1ee698f..763e03fdb 100644 --- a/test/unit/theory/theory_arith_white.h +++ b/test/unit/theory/theory_arith_white.h @@ -65,6 +65,32 @@ public: TheoryArithWhite() : d_level(Theory::FULL_EFFORT), d_zero(0), d_one(1), debug(false) {} + void fakeTheoryEnginePreprocess(TNode inp){ + Node rewrite = inp; //FIXME this needs to enforce that inp is fully rewritten already! + + if(debug) cout << rewrite << inp << endl; + + std::list toPreregister; + + toPreregister.push_back(rewrite); + for(std::list::iterator i = toPreregister.begin(); i != toPreregister.end(); ++i){ + Node n = *i; + preregistered->insert(n); + + for(Node::iterator citer = n.begin(); citer != n.end(); ++citer){ + Node c = *citer; + if(preregistered->find(c) == preregistered->end()){ + toPreregister.push_back(c); + } + } + } + for(std::list::reverse_iterator i = toPreregister.rbegin(); i != toPreregister.rend(); ++i){ + Node n = *i; + if(debug) cout << n.getId() << " "<< n << endl; + d_arith->preRegisterTerm(n); + } + } + void setUp() { d_ctxt = new Context; d_nm = new NodeManager(d_ctxt); @@ -92,42 +118,14 @@ public: delete d_ctxt; } - Node fakeTheoryEnginePreprocess(TNode inp){ - Node rewrite = d_arith->rewrite(inp); - - if(debug) cout << rewrite << inp << endl; - - std::list toPreregister; - - toPreregister.push_back(rewrite); - for(std::list::iterator i = toPreregister.begin(); i != toPreregister.end(); ++i){ - Node n = *i; - preregistered->insert(n); - - for(Node::iterator citer = n.begin(); citer != n.end(); ++citer){ - Node c = *citer; - if(preregistered->find(c) == preregistered->end()){ - toPreregister.push_back(c); - } - } - } - for(std::list::reverse_iterator i = toPreregister.rbegin(); i != toPreregister.rend(); ++i){ - Node n = *i; - if(debug) cout << n.getId() << " "<< n << endl; - d_arith->preRegisterTerm(n); - } - - return rewrite; - } - void testAssert() { Node x = d_nm->mkVar(*d_realType); Node c = d_nm->mkConst(d_zero); Node leq = d_nm->mkNode(LEQ, x, c); - Node rLeq = fakeTheoryEnginePreprocess(leq); + fakeTheoryEnginePreprocess(leq); - d_arith->assertFact(rLeq); + d_arith->assertFact(leq); d_arith->check(d_level); @@ -143,51 +141,19 @@ public: return dis; } - void testAssertEqualityEagerSplit() { - Node x = d_nm->mkVar(*d_realType); - Node c = d_nm->mkConst(d_zero); - - Node eq = d_nm->mkNode(EQUAL, x, c); - Node expectedDisjunct = simulateSplit(x,c); - - Node rEq = fakeTheoryEnginePreprocess(eq); - - d_arith->assertFact(rEq); - - d_arith->check(d_level); - - TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 1u); - - TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), expectedDisjunct); - TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), AUG_LEMMA); - - } - void testLtRewrite() { - Node x = d_nm->mkVar(*d_realType); - Node c = d_nm->mkConst(d_zero); - - Node lt = d_nm->mkNode(LT, x, c); - Node geq = d_nm->mkNode(GEQ, x, c); - Node expectedRewrite = d_nm->mkNode(NOT, geq); - - Node rewrite = d_arith->rewrite(lt); - - TS_ASSERT_EQUALS(expectedRewrite, rewrite); - } - void testBasicConflict() { Node x = d_nm->mkVar(*d_realType); Node c = d_nm->mkConst(d_zero); Node eq = d_nm->mkNode(EQUAL, x, c); - Node lt = d_nm->mkNode(LT, x, c); + Node lt = d_nm->mkNode(NOT, d_nm->mkNode(GEQ, x, c)); Node expectedDisjunct = simulateSplit(x,c); - Node rEq = fakeTheoryEnginePreprocess(eq); - Node rLt = fakeTheoryEnginePreprocess(lt); + fakeTheoryEnginePreprocess(eq); + fakeTheoryEnginePreprocess(lt); - d_arith->assertFact(rEq); - d_arith->assertFact(rLt); + d_arith->assertFact(eq); + d_arith->assertFact(lt); d_arith->check(d_level); @@ -198,7 +164,7 @@ public: TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), CONFLICT); - Node expectedClonflict = d_nm->mkNode(AND, rEq, rLt); + Node expectedClonflict = d_nm->mkNode(AND, eq, lt); TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), expectedClonflict); } @@ -208,13 +174,13 @@ public: Node c = d_nm->mkConst(d_zero); Node eq = d_nm->mkNode(EQUAL, x, c); - Node lt = d_nm->mkNode(LT, x, c); + Node lt = d_nm->mkNode(NOT, d_nm->mkNode(GEQ, x, c)); Node expectedDisjunct = simulateSplit(x,c); - Node rEq = fakeTheoryEnginePreprocess(eq); - Node rLt = fakeTheoryEnginePreprocess(lt); + fakeTheoryEnginePreprocess(eq); + fakeTheoryEnginePreprocess(lt); - d_arith->assertFact(rEq); + d_arith->assertFact(eq); d_arith->check(d_level); @@ -236,29 +202,30 @@ public: Node leq0 = d_nm->mkNode(LEQ, x, c0); Node leq1 = d_nm->mkNode(LEQ, x, c1); - Node lt1 = d_nm->mkNode(LT, x, c1); + Node geq1 = d_nm->mkNode(GEQ, x, c1); + Node lt1 = d_nm->mkNode(NOT, geq1); - Node rLeq0 = fakeTheoryEnginePreprocess(leq0); - Node rLt1 = fakeTheoryEnginePreprocess(lt1); - Node rLeq1 = fakeTheoryEnginePreprocess(leq1); + fakeTheoryEnginePreprocess(leq0); + fakeTheoryEnginePreprocess(leq1); + fakeTheoryEnginePreprocess(geq1); - d_arith->assertFact(rLt1); + d_arith->assertFact(lt1); d_arith->check(d_level); d_arith->propagate(d_level); #ifdef CVC4_ASSERTIONS - TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException ); - TS_ASSERT_THROWS( d_arith->explain(rLt1, d_level), AssertionException ); + TS_ASSERT_THROWS( d_arith->explain(leq0, d_level), AssertionException ); + TS_ASSERT_THROWS( d_arith->explain(lt1, d_level), AssertionException ); #endif - d_arith->explain(rLeq1, d_level); + d_arith->explain(leq1, d_level); TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 2u); TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPAGATE); TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(1), EXPLANATION); TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), leq1); - TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), lt1); } @@ -269,24 +236,25 @@ public: Node leq0 = d_nm->mkNode(LEQ, x, c0); Node leq1 = d_nm->mkNode(LEQ, x, c1); - Node lt1 = d_nm->mkNode(LT, x, c1); + Node geq1 = d_nm->mkNode(GEQ, x, c1); + Node lt1 = d_nm->mkNode(NOT, geq1); - Node rLeq0 = fakeTheoryEnginePreprocess(leq0); - Node rLt1 = fakeTheoryEnginePreprocess(lt1); - Node rLeq1 = fakeTheoryEnginePreprocess(leq1); + fakeTheoryEnginePreprocess(leq0); + fakeTheoryEnginePreprocess(leq1); + fakeTheoryEnginePreprocess(geq1); - d_arith->assertFact(rLeq0); + d_arith->assertFact(leq0); d_arith->check(d_level); d_arith->propagate(d_level); - d_arith->explain(rLt1, d_level); + d_arith->explain(lt1, d_level); #ifdef CVC4_ASSERTIONS - TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException ); + TS_ASSERT_THROWS( d_arith->explain(leq0, d_level), AssertionException ); #endif - d_arith->explain(rLeq1, d_level); + d_arith->explain(leq1, d_level); TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 4u); TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(0), PROPAGATE); @@ -294,12 +262,12 @@ public: TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(2), EXPLANATION); TS_ASSERT_EQUALS(d_outputChannel.getIthCallType(3), EXPLANATION); - TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), rLt1); - TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), rLeq1); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(1), lt1); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(0), leq1); - TS_ASSERT_EQUALS(d_outputChannel.getIthNode(2), rLeq0); - TS_ASSERT_EQUALS(d_outputChannel.getIthNode(3), rLeq0); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(2), leq0); + TS_ASSERT_EQUALS(d_outputChannel.getIthNode(3), leq0); } void testTPLeq1() { Node x = d_nm->mkVar(*d_realType); @@ -308,22 +276,23 @@ public: Node leq0 = d_nm->mkNode(LEQ, x, c0); Node leq1 = d_nm->mkNode(LEQ, x, c1); - Node lt1 = d_nm->mkNode(LT, x, c1); + Node geq1 = d_nm->mkNode(GEQ, x, c1); + Node lt1 = d_nm->mkNode(NOT, geq1); - Node rLeq0 = fakeTheoryEnginePreprocess(leq0); - Node rLt1 = fakeTheoryEnginePreprocess(lt1); - Node rLeq1 = fakeTheoryEnginePreprocess(leq1); + fakeTheoryEnginePreprocess(leq0); + fakeTheoryEnginePreprocess(leq1); + fakeTheoryEnginePreprocess(geq1); - d_arith->assertFact(rLeq1); + d_arith->assertFact(leq1); d_arith->check(d_level); d_arith->propagate(d_level); #ifdef CVC4_ASSERTIONS - TS_ASSERT_THROWS( d_arith->explain(rLeq0, d_level), AssertionException ); - TS_ASSERT_THROWS( d_arith->explain(rLeq1, d_level), AssertionException ); - TS_ASSERT_THROWS( d_arith->explain(rLt1, d_level), AssertionException ); + TS_ASSERT_THROWS( d_arith->explain(leq0, d_level), AssertionException ); + TS_ASSERT_THROWS( d_arith->explain(leq1, d_level), AssertionException ); + TS_ASSERT_THROWS( d_arith->explain(lt1, d_level), AssertionException ); #endif TS_ASSERT_EQUALS(d_outputChannel.getNumCalls(), 0u);