From 6c6f4e23aea405a812b1c6a3dd4d80696eb34741 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Thu, 16 Nov 2017 14:09:30 -0600 Subject: [PATCH] (Refactor) Arithmetic monomial sum (#1381) * Add arithmetic monomial sum utility. * Clang format * Fix * Address review. * Fix missed comment. * Format * Fix --- src/Makefile.am | 2 + src/smt/smt_engine.cpp | 4 +- src/theory/arith/arith_msum.cpp | 324 ++++++++++++++++++ src/theory/arith/arith_msum.h | 188 ++++++++++ src/theory/quantifiers/bounded_integers.cpp | 14 +- .../quantifiers/ce_guided_single_inv.cpp | 8 +- .../quantifiers/ce_guided_single_inv_sol.cpp | 10 - src/theory/quantifiers/equality_infer.cpp | 22 +- src/theory/quantifiers/extended_rewrite.cpp | 6 +- .../quantifiers/quantifiers_rewriter.cpp | 10 +- src/theory/quantifiers/relevant_domain.cpp | 17 +- src/theory/quantifiers/relevant_domain.h | 1 + .../quantifiers/term_database_sygus.cpp | 9 +- src/theory/quantifiers/term_util.cpp | 8 +- 14 files changed, 579 insertions(+), 44 deletions(-) create mode 100644 src/theory/arith/arith_msum.cpp create mode 100644 src/theory/arith/arith_msum.h diff --git a/src/Makefile.am b/src/Makefile.am index 7dcf73652..75fdd32ae 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -190,6 +190,8 @@ libcvc4_la_SOURCES = \ theory/arith/approx_simplex.h \ theory/arith/arith_ite_utils.cpp \ theory/arith/arith_ite_utils.h \ + theory/arith/arith_msum.cpp \ + theory/arith/arith_msum.h \ theory/arith/arith_rewriter.cpp \ theory/arith/arith_rewriter.h \ theory/arith/arith_static_learner.cpp \ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index e4ec57bb4..3e82a337c 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -83,6 +83,7 @@ #include "smt_util/boolean_simplification.h" #include "smt_util/nary_builder.h" #include "smt_util/node_visitor.h" +#include "theory/arith/arith_msum.h" #include "theory/arith/pseudoboolean_proc.h" #include "theory/booleans/circuit_propagator.h" #include "theory/bv/bvintropow2.h" @@ -2735,7 +2736,8 @@ Node SmtEnginePrivate::realToInt(TNode n, NodeMap& cache, std::vector< Node >& v Node ret_lit = ret.getKind()==kind::NOT ? ret[0] : ret; bool ret_pol = ret.getKind()!=kind::NOT; std::map< Node, Node > msum; - if( QuantArith::getMonomialSumLit( ret_lit, msum ) ){ + if (ArithMSum::getMonomialSumLit(ret_lit, msum)) + { //get common coefficient std::vector< Node > coeffs; for( std::map< Node, Node >::iterator itm = msum.begin(); itm != msum.end(); ++itm ){ diff --git a/src/theory/arith/arith_msum.cpp b/src/theory/arith/arith_msum.cpp new file mode 100644 index 000000000..46ee1cad5 --- /dev/null +++ b/src/theory/arith/arith_msum.cpp @@ -0,0 +1,324 @@ +/********************* */ +/*! \file arith_msum.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of arith_msum + **/ + +#include "theory/arith/arith_msum.h" + +#include "theory/rewriter.h" + +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { + +bool ArithMSum::getMonomial(Node n, Node& c, Node& v) +{ + if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst()) + { + c = n[0]; + v = n[1]; + return true; + } + return false; +} + +bool ArithMSum::getMonomial(Node n, std::map& msum) +{ + if (n.isConst()) + { + if (msum.find(Node::null()) == msum.end()) + { + msum[Node::null()] = n; + return true; + } + } + else if (n.getKind() == MULT && n.getNumChildren() == 2 && n[0].isConst()) + { + if (msum.find(n[1]) == msum.end()) + { + msum[n[1]] = n[0]; + return true; + } + } + else + { + if (msum.find(n) == msum.end()) + { + msum[n] = Node::null(); + return true; + } + } + return false; +} + +bool ArithMSum::getMonomialSum(Node n, std::map& msum) +{ + if (n.getKind() == PLUS) + { + for (Node nc : n) + { + if (!getMonomial(nc, msum)) + { + return false; + } + } + return true; + } + return getMonomial(n, msum); +} + +bool ArithMSum::getMonomialSumLit(Node lit, std::map& msum) +{ + if (lit.getKind() == GEQ || lit.getKind() == EQUAL) + { + if (getMonomialSum(lit[0], msum)) + { + if (lit[1].isConst() && lit[1].getConst().isZero()) + { + return true; + } + else + { + // subtract the other side + std::map msum2; + NodeManager* nm = NodeManager::currentNM(); + if (getMonomialSum(lit[1], msum2)) + { + for (std::map::iterator it = msum2.begin(); + it != msum2.end(); + ++it) + { + std::map::iterator it2 = msum.find(it->first); + if (it2 != msum.end()) + { + Node r = nm->mkNode( + MINUS, + it2->second.isNull() ? nm->mkConst(Rational(1)) : it2->second, + it->second.isNull() ? nm->mkConst(Rational(1)) : it->second); + msum[it->first] = Rewriter::rewrite(r); + } + else + { + msum[it->first] = it->second.isNull() ? nm->mkConst(Rational(-1)) + : negate(it->second); + } + } + return true; + } + } + } + } + return false; +} + +Node ArithMSum::mkNode(const std::map& msum) +{ + NodeManager* nm = NodeManager::currentNM(); + std::vector children; + for (std::map::const_iterator it = msum.begin(); it != msum.end(); + ++it) + { + Node m; + if (!it->first.isNull()) + { + m = mkCoeffTerm(it->second, it->first); + } + else + { + Assert(!it->second.isNull()); + m = it->second; + } + children.push_back(m); + } + return children.size() > 1 + ? nm->mkNode(PLUS, children) + : (children.size() == 1 ? children[0] : nm->mkConst(Rational(0))); +} + +int ArithMSum::isolate( + Node v, const std::map& msum, Node& veq_c, Node& val, Kind k) +{ + std::map::const_iterator itv = msum.find(v); + if (itv != msum.end()) + { + std::vector children; + Rational r = + itv->second.isNull() ? Rational(1) : itv->second.getConst(); + if (r.sgn() != 0) + { + for (std::map::const_iterator it = msum.begin(); + it != msum.end(); + ++it) + { + if (it->first != v) + { + Node m; + if (!it->first.isNull()) + { + m = mkCoeffTerm(it->second, it->first); + } + else + { + m = it->second; + } + children.push_back(m); + } + } + val = children.size() > 1 + ? NodeManager::currentNM()->mkNode(PLUS, children) + : (children.size() == 1 + ? children[0] + : NodeManager::currentNM()->mkConst(Rational(0))); + if (!r.isOne() && !r.isNegativeOne()) + { + if (v.getType().isInteger()) + { + veq_c = NodeManager::currentNM()->mkConst(r.abs()); + } + else + { + val = NodeManager::currentNM()->mkNode( + MULT, + val, + NodeManager::currentNM()->mkConst(Rational(1) / r.abs())); + } + } + val = r.sgn() == 1 ? negate(val) : Rewriter::rewrite(val); + return (r.sgn() == 1 || k == EQUAL) ? 1 : -1; + } + } + return 0; +} + +int ArithMSum::isolate( + Node v, const std::map& msum, Node& veq, Kind k, bool doCoeff) +{ + Node veq_c; + Node val; + // isolate v in the (in)equality + int ires = isolate(v, msum, veq_c, val, k); + if (ires != 0) + { + Node vc = v; + if (!veq_c.isNull()) + { + if (doCoeff) + { + vc = NodeManager::currentNM()->mkNode(MULT, veq_c, vc); + } + else + { + return 0; + } + } + bool inOrder = ires == 1; + veq = NodeManager::currentNM()->mkNode( + k, inOrder ? vc : val, inOrder ? val : vc); + } + return ires; +} + +Node ArithMSum::solveEqualityFor(Node lit, Node v) +{ + Assert(lit.getKind() == EQUAL); + // first look directly at sides + TypeNode tn = lit[0].getType(); + for (unsigned r = 0; r < 2; r++) + { + if (lit[r] == v) + { + return lit[1 - r]; + } + } + if (tn.isReal()) + { + std::map msum; + if (ArithMSum::getMonomialSumLit(lit, msum)) + { + Node val, veqc; + if (ArithMSum::isolate(v, msum, veqc, val, EQUAL) != 0) + { + if (veqc.isNull()) + { + // in this case, we have an integer equality with a coefficient + // on the variable we solved for that could not be eliminated, + // hence we fail. + return val; + } + } + } + } + return Node::null(); +} + +bool ArithMSum::decompose(Node n, Node v, Node& coeff, Node& rem) +{ + std::map msum; + if (getMonomialSum(n, msum)) + { + std::map::iterator it = msum.find(v); + if (it == msum.end()) + { + return false; + } + else + { + coeff = it->second; + msum.erase(v); + rem = mkNode(msum); + return true; + } + } + return false; +} + +Node ArithMSum::negate(Node t) +{ + Node tt = NodeManager::currentNM()->mkNode( + MULT, NodeManager::currentNM()->mkConst(Rational(-1)), t); + tt = Rewriter::rewrite(tt); + return tt; +} + +Node ArithMSum::offset(Node t, int i) +{ + Node tt = NodeManager::currentNM()->mkNode( + PLUS, NodeManager::currentNM()->mkConst(Rational(i)), t); + tt = Rewriter::rewrite(tt); + return tt; +} + +void ArithMSum::debugPrintMonomialSum(std::map& msum, const char* c) +{ + for (std::map::iterator it = msum.begin(); it != msum.end(); ++it) + { + Trace(c) << " "; + if (!it->second.isNull()) + { + Trace(c) << it->second; + if (!it->first.isNull()) + { + Trace(c) << " * "; + } + } + if (!it->first.isNull()) + { + Trace(c) << it->first; + } + Trace(c) << std::endl; + } + Trace(c) << std::endl; +} + +} /* CVC4::theory namespace */ +} /* CVC4 namespace */ diff --git a/src/theory/arith/arith_msum.h b/src/theory/arith/arith_msum.h new file mode 100644 index 000000000..652a395cc --- /dev/null +++ b/src/theory/arith/arith_msum.h @@ -0,0 +1,188 @@ +/********************* */ +/*! \file arith_msum.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief arith_msum + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__ARITH__MSUM_H +#define __CVC4__THEORY__ARITH__MSUM_H + +#include + +#include "expr/node.h" + +namespace CVC4 { +namespace theory { + +/** Arithmetic utilities regarding monomial sums. + * + * Note the following terminology: + * + * We say Node c is a {monomial constant} (or m-constant) if either: + * (a) c is a constant Rational, or + * (b) c is null. + * + * We say Node v is a {monomial variable} (or m-variable) if either: + * (a) v.getType().isReal() and v is not a constant, or + * (b) v is null. + * + * For m-constant or m-variable t, we write [t] to denote 1 if t.isNull() and + * t otherwise. + * + * A monomial m is a pair ( mvariable, mconstant ) of the form ( v, c ), which + * is interpreted as [c]*[v]. + * + * A {monmoial sum} msum is represented by a std::map< Node, Node > having + * key-value pairs of the form ( mvariable, mconstant ). + * It is interpreted as: + * [msum] = sum_{( v, c ) \in msum } [c]*[v] + * It is critical that this map is ordered so that operations like adding + * two monomial sums can be done efficiently. The ordering itself is not + * important, and currently corresponds to the default ordering on Nodes. + * + * The following has utilities involving monmoial sums. + * + */ +class ArithMSum +{ + public: + /** get monomial + * + * If n = n[0]*n[1] where n[0] is constant and n[1] is not, + * this function returns true, sets c to n[0] and v to n[1]. + */ + static bool getMonomial(Node n, Node& c, Node& v); + + /** get monomial + * + * If this function returns true, it adds the ( m-constant, m-variable ) + * pair corresponding to the monomial representation of n to the + * monomial sum msum. + * + * This function returns false if the m-variable of n is already + * present in n. + */ + static bool getMonomial(Node n, std::map& msum); + + /** get monomial sum for real-valued term n + * + * If this function returns true, it sets msum to a monmoial sum such that + * [msum] is equivalent to n + * + * This function may return false if n is not a sum of monomials + * whose variables are pairwise unique. + * If term n is in rewritten form, this function should always return true. + */ + static bool getMonomialSum(Node n, std::map& msum); + + /** get monmoial sum literal for literal lit + * + * If this function returns true, it sets msum to a monmoial sum such that + * [msum] 0 is equivalent to lit[0] lit[1] + * where k is the Kind of lit, one of { EQUAL, GEQ }. + * + * This function may return false if either side of lit is not a sum + * of monomials whose variables are pairwise unique on that side. + * If literal lit is in rewritten form, this function should always return + * true. + */ + static bool getMonomialSumLit(Node lit, std::map& msum); + + /** make node for monomial sum + * + * Make the Node corresponding to the interpretation of msum, [msum], where: + * [msum] = sum_{( v, c ) \in msum } [c]*[v] + */ + static Node mkNode(const std::map& msum); + + /** make coefficent term + * + * Input c is a m-constant. + * Returns the term t if c.isNull() or c*t otherwise. + */ + static inline Node mkCoeffTerm(Node c, Node t) + { + return c.isNull() ? t : NodeManager::currentNM()->mkNode(kind::MULT, c, t); + } + + /** isolate variable v in constraint ([msum] 0) + * + * If this function returns a value ret where ret != 0, then + * veq_c is set to m-constant, and val is set to a term such that: + * If ret=1, then ([veq_c] * v val) is equivalent to [msum] 0. + * If ret=-1, then (val [veq_c] * v) is equivalent to [msum] 0. + * If veq_c is non-null, then it is a positive constant Rational. + * The returned value of veq_c is only non-null if v has integer type. + * + * This function returns 0, indicating a failure, if msum does not contain + * a (non-zero) monomial having mvariable v. + */ + static int isolate( + Node v, const std::map& msum, Node& veq_c, Node& val, Kind k); + + /** isolate variable v in constraint ([msum] 0) + * + * If this function returns a value ret where ret != 0, then veq + * is set to a literal that is equivalent to ([msum] 0), and: + * If ret=1, then veq is of the form ( v val) if veq_c.isNull(), + * or ([veq_c] * v val) if !veq_c.isNull(). + * If ret=-1, then veq is of the form ( val v) if veq_c.isNull(), + * or (val [veq_c] * v) if !veq_c.isNull(). + * If doCoeff = false or v does not have Integer type, then veq_c is null. + * + * This function returns 0 indicating a failure if msum does not contain + * a (non-zero) monomial having variable v, or if veq_c must be non-null + * for an integer constraint and doCoeff is false. + */ + static int isolate(Node v, + const std::map& msum, + Node& veq, + Kind k, + bool doCoeff = false); + + /** solve equality lit for variable + * + * If return value ret is non-null, then: + * v = ret is equivalent to lit. + * + * This function may return false if lit does not contain v, + * or if lit is an integer equality with a coefficent on v, + * e.g. 3*v = 7. + */ + static Node solveEqualityFor(Node lit, Node v); + + /** decompose real-valued term n + * + * If this function returns true, then + * ([coeff]*v + rem) is equivalent to n + * where coeff is non-zero m-constant. + * + * This function will return false if n is not a monomial sum containing + * a monomial with factor v. + */ + static bool decompose(Node n, Node v, Node& coeff, Node& rem); + + /** return the rewritten form of (UMINUS t) */ + static Node negate(Node t); + + /** return the rewritten form of (PLUS t (CONST_RATIONAL i)) */ + static Node offset(Node t, int i); + + /** debug print for a monmoial sum, prints to Trace(c) */ + static void debugPrintMonomialSum(std::map& msum, const char* c); +}; + +} /* CVC4::theory namespace */ +} /* CVC4 namespace */ + +#endif /* __CVC4__THEORY__ARITH__MSUM_H */ diff --git a/src/theory/quantifiers/bounded_integers.cpp b/src/theory/quantifiers/bounded_integers.cpp index f99d0b080..963aba612 100644 --- a/src/theory/quantifiers/bounded_integers.cpp +++ b/src/theory/quantifiers/bounded_integers.cpp @@ -16,9 +16,9 @@ #include "theory/quantifiers/bounded_integers.h" #include "options/quantifiers_options.h" +#include "theory/arith/arith_msum.h" #include "theory/quantifiers/first_order_model.h" #include "theory/quantifiers/model_engine.h" -#include "theory/quantifiers/quant_util.h" #include "theory/quantifiers/term_enumeration.h" #include "theory/quantifiers/term_util.h" #include "theory/theory_engine.h" @@ -288,15 +288,17 @@ void BoundedIntegers::process( Node q, Node n, bool pol, }else if( n.getKind()==GEQ ){ if( n[0].getType().isInteger() ){ std::map< Node, Node > msum; - if( QuantArith::getMonomialSumLit( n, msum ) ){ + if (ArithMSum::getMonomialSumLit(n, msum)) + { Trace("bound-int-debug") << "literal (polarity = " << pol << ") " << n << " is monomial sum : " << std::endl; - QuantArith::debugPrintMonomialSum( msum, "bound-int-debug" ); + ArithMSum::debugPrintMonomialSum(msum, "bound-int-debug"); for( std::map< Node, Node >::iterator it = msum.begin(); it != msum.end(); ++it ){ if ( !it->first.isNull() && it->first.getKind()==BOUND_VARIABLE && !isBound( q, it->first ) ){ //if not bound in another way if( bound_lit_type_map.find( it->first )==bound_lit_type_map.end() || bound_lit_type_map[it->first] == BOUND_INT_RANGE ){ Node veq; - if( QuantArith::isolate( it->first, msum, veq, GEQ )!=0 ){ + if (ArithMSum::isolate(it->first, msum, veq, GEQ) != 0) + { Node n1 = veq[0]; Node n2 = veq[1]; if(pol){ @@ -304,9 +306,9 @@ void BoundedIntegers::process( Node q, Node n, bool pol, n1 = veq[1]; n2 = veq[0]; if( n1.getKind()==BOUND_VARIABLE ){ - n2 = QuantArith::offset( n2, 1 ); + n2 = ArithMSum::offset(n2, 1); }else{ - n1 = QuantArith::offset( n1, -1 ); + n1 = ArithMSum::offset(n1, -1); } veq = NodeManager::currentNM()->mkNode( GEQ, n1, n2 ); } diff --git a/src/theory/quantifiers/ce_guided_single_inv.cpp b/src/theory/quantifiers/ce_guided_single_inv.cpp index c1b6c82ad..c810ed5cf 100644 --- a/src/theory/quantifiers/ce_guided_single_inv.cpp +++ b/src/theory/quantifiers/ce_guided_single_inv.cpp @@ -16,9 +16,9 @@ #include "expr/datatype.h" #include "options/quantifiers_options.h" +#include "theory/arith/arith_msum.h" #include "theory/quantifiers/ce_guided_instantiation.h" #include "theory/quantifiers/first_order_model.h" -#include "theory/quantifiers/quant_util.h" #include "theory/quantifiers/term_database_sygus.h" #include "theory/quantifiers/term_enumeration.h" #include "theory/quantifiers/term_util.h" @@ -1129,12 +1129,14 @@ void TransitionInference::getConstantSubstitution( std::vector< Node >& vars, st if( v.isNull() ){ //solve for var std::map< Node, Node > msum; - if( QuantArith::getMonomialSumLit( slit, msum ) ){ + if (ArithMSum::getMonomialSumLit(slit, msum)) + { for( std::map< Node, Node >::iterator itm = msum.begin(); itm != msum.end(); ++itm ){ if( std::find( vars.begin(), vars.end(), itm->first )!=vars.end() ){ Node veq_c; Node val; - int ires = QuantArith::isolate( itm->first, msum, veq_c, val, EQUAL ); + int ires = + ArithMSum::isolate(itm->first, msum, veq_c, val, EQUAL); if( ires!=0 && veq_c.isNull() && !TermUtil::containsTerm( val, itm->first ) ){ v = itm->first; s = val; diff --git a/src/theory/quantifiers/ce_guided_single_inv_sol.cpp b/src/theory/quantifiers/ce_guided_single_inv_sol.cpp index e21535bef..91c6e3089 100644 --- a/src/theory/quantifiers/ce_guided_single_inv_sol.cpp +++ b/src/theory/quantifiers/ce_guided_single_inv_sol.cpp @@ -19,7 +19,6 @@ #include "theory/quantifiers/ce_guided_instantiation.h" #include "theory/quantifiers/ce_guided_single_inv.h" #include "theory/quantifiers/first_order_model.h" -#include "theory/quantifiers/quant_util.h" #include "theory/quantifiers/term_database_sygus.h" #include "theory/quantifiers/term_enumeration.h" #include "theory/quantifiers/term_util.h" @@ -296,15 +295,6 @@ bool CegConjectureSingleInvSol::getAssignEquality( Node eq, std::vector< Node >& } } } - /* - TypeNode tn = eq[0].getType(); - if( tn.isInteger() || tn.isReal() ){ - std::map< Node, Node > msum; - if( QuantArith::getMonomialSumLit( eq, msum ) ){ - - } - } - */ return false; } diff --git a/src/theory/quantifiers/equality_infer.cpp b/src/theory/quantifiers/equality_infer.cpp index 66ca38e8c..a5cbef746 100644 --- a/src/theory/quantifiers/equality_infer.cpp +++ b/src/theory/quantifiers/equality_infer.cpp @@ -15,8 +15,10 @@ **/ #include "theory/quantifiers/equality_infer.h" -#include "theory/quantifiers/quant_util.h" + #include "context/context_mm.h" +#include "theory/rewriter.h" +#include "theory/arith/arith_msum.h" using namespace CVC4; using namespace CVC4::kind; @@ -144,7 +146,8 @@ void EqualityInference::eqNotifyNewClass(TNode t) { //somewhat strange: t may not be in rewritten form Node r = Rewriter::rewrite( t ); std::map< Node, Node > msum; - if( QuantArith::getMonomialSum( r, msum ) ){ + if (ArithMSum::getMonomialSum(r, msum)) + { Trace("eq-infer-debug2") << "...process monomial sum, size = " << msum.size() << std::endl; eqci->d_valid = true; bool changed = false; @@ -185,7 +188,8 @@ void EqualityInference::eqNotifyNewClass(TNode t) { Trace("eq-infer-debug2") << "...pre-rewrite : " << r << std::endl; r = Rewriter::rewrite( r ); msum.clear(); - if( !QuantArith::getMonomialSum( r, msum ) ){ + if (!ArithMSum::getMonomialSum(r, msum)) + { mvalid = false; } } @@ -285,7 +289,8 @@ void EqualityInference::eqNotifyMerge(TNode t1, TNode t2) { if( tr1!=tr2 ){ Node eq = tr1.eqNode( tr2 ); std::map< Node, Node > msum; - if( QuantArith::getMonomialSumLit( eq, msum ) ){ + if (ArithMSum::getMonomialSumLit(eq, msum)) + { Node v_solve; //solve for variables with no coefficient if( Trace.isOn("eq-infer-debug2") ){ @@ -315,7 +320,8 @@ void EqualityInference::eqNotifyMerge(TNode t1, TNode t2) { if( !v_solve.isNull() ){ //solve for v_solve Node veq; - if( QuantArith::isolate( v_solve, msum, veq, kind::EQUAL, true )==1 ){ + if (ArithMSum::isolate(v_solve, msum, veq, kind::EQUAL, true) == 1) + { Node v_value = veq[1]; Trace("eq-infer") << "...solved " << v_solve << " == " << v_value << std::endl; Assert( d_elim_vars.find( v_solve )==d_elim_vars.end() ); @@ -375,7 +381,9 @@ void EqualityInference::eqNotifyMerge(TNode t1, TNode t2) { std::map< Node, EqcInfo * >::iterator itt = d_eqci.find( r ); if( itt!=d_eqci.end() && ( itt->second->d_valid || itt->second->d_solved ) ){ std::map< Node, Node > msum2; - if( QuantArith::getMonomialSum( itt->second->d_rep.get(), msum2 ) ){ + if (ArithMSum::getMonomialSum(itt->second->d_rep.get(), + msum2)) + { std::map< Node, Node >::iterator itm = msum2.find( v_solve ); if( itm!=msum2.end() ){ //substitute in solved form @@ -387,7 +395,7 @@ void EqualityInference::eqNotifyMerge(TNode t1, TNode t2) { itm->second.isNull() ? d_one : itm->second ); } msum2.erase( itm ); - Node rr = QuantArith::mkNode( msum2 ); + Node rr = ArithMSum::mkNode(msum2); rr = Rewriter::rewrite( rr ); Trace("eq-infer") << "......update " << itt->first << " => " << rr << std::endl; setEqcRep( itt->first, rr, exp, itt->second ); diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp index 7d3f9afab..b463a319a 100644 --- a/src/theory/quantifiers/extended_rewrite.cpp +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -14,8 +14,8 @@ #include "theory/quantifiers/extended_rewrite.h" +#include "theory/arith/arith_msum.h" #include "theory/datatypes/datatypes_rewriter.h" -#include "theory/quantifiers/quant_util.h" // for QuantArith #include "theory/rewriter.h" #include "theory/strings/theory_strings_rewriter.h" @@ -258,7 +258,7 @@ Node ExtendedRewriter::extendedRewrite(Node n) << std::endl; // compute monomial sum std::map msum; - if (QuantArith::getMonomialSumLit(ret_atom, msum)) + if (ArithMSum::getMonomialSumLit(ret_atom, msum)) { for (std::map::iterator itm = msum.begin(); itm != msum.end(); @@ -270,7 +270,7 @@ Node ExtendedRewriter::extendedRewrite(Node n) if (v.getKind() == ITE) { Node veq; - int res = QuantArith::isolate(v, msum, veq, ret_atom.getKind()); + int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind()); if (res != 0) { Trace("q-ext-rewrite-debug") diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp index 45c418996..511e8f051 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.cpp +++ b/src/theory/quantifiers/quantifiers_rewriter.cpp @@ -15,6 +15,7 @@ #include "theory/quantifiers/quantifiers_rewriter.h" #include "options/quantifiers_options.h" +#include "theory/arith/arith_msum.h" #include "theory/quantifiers/quantifiers_attributes.h" #include "theory/quantifiers/skolemize.h" #include "theory/quantifiers/term_database.h" @@ -942,7 +943,8 @@ bool QuantifiersRewriter::computeVariableElimLit( Node lit, bool pol, std::vecto ( ( lit.getKind()==GEQ || lit.getKind()==GT ) && options::varIneqElimQuant() ) ){ //for arithmetic, solve the (in)equality std::map< Node, Node > msum; - if( QuantArith::getMonomialSumLit( lit, msum ) ){ + if (ArithMSum::getMonomialSumLit(lit, msum)) + { for( std::map< Node, Node >::iterator itm = msum.begin(); itm != msum.end(); ++itm ){ if( !itm->first.isNull() ){ std::vector< Node >::iterator ita = std::find( args.begin(), args.end(), itm->first ); @@ -951,7 +953,8 @@ bool QuantifiersRewriter::computeVariableElimLit( Node lit, bool pol, std::vecto Assert( pol ); Node veq_c; Node val; - int ires = QuantArith::isolate( itm->first, msum, veq_c, val, EQUAL ); + int ires = + ArithMSum::isolate(itm->first, msum, veq_c, val, EQUAL); if( ires!=0 && veq_c.isNull() && isVariableElim( itm->first, val ) ){ Trace("var-elim-quant") << "Variable eliminate based on solved equality : " << itm->first << " -> " << val << std::endl; vars.push_back( itm->first ); @@ -964,7 +967,8 @@ bool QuantifiersRewriter::computeVariableElimLit( Node lit, bool pol, std::vecto //store that this literal is upper/lower bound for itm->first Node veq_c; Node val; - int ires = QuantArith::isolate( itm->first, msum, veq_c, val, lit.getKind() ); + int ires = ArithMSum::isolate( + itm->first, msum, veq_c, val, lit.getKind()); if( ires!=0 && veq_c.isNull() ){ bool is_upper = pol!=( ires==1 ); Trace("var-elim-ineq-debug") << lit << " is a " << ( is_upper ? "upper" : "lower" ) << " bound for " << itm->first << std::endl; diff --git a/src/theory/quantifiers/relevant_domain.cpp b/src/theory/quantifiers/relevant_domain.cpp index dcd24b433..e38f76994 100644 --- a/src/theory/quantifiers/relevant_domain.cpp +++ b/src/theory/quantifiers/relevant_domain.cpp @@ -12,11 +12,12 @@ ** \brief Implementation of relevant domain class **/ -#include "theory/quantifiers_engine.h" #include "theory/quantifiers/relevant_domain.h" +#include "theory/arith/arith_msum.h" +#include "theory/quantifiers/first_order_model.h" #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_util.h" -#include "theory/quantifiers/first_order_model.h" +#include "theory/quantifiers_engine.h" using namespace std; using namespace CVC4; @@ -245,7 +246,8 @@ void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, No //solve the inequality for one/two variables, if possible if( n[0].getType().isReal() ){ std::map< Node, Node > msum; - if( QuantArith::getMonomialSumLit( n, msum ) ){ + if (ArithMSum::getMonomialSumLit(n, msum)) + { Node var; Node var2; bool hasNonVar = false; @@ -267,7 +269,8 @@ void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, No //single variable solve Node veq_c; Node val; - int ires = QuantArith::isolate( var, msum, veq_c, val, n.getKind() ); + int ires = + ArithMSum::isolate(var, msum, veq_c, val, n.getKind()); if( ires!=0 ){ if( veq_c.isNull() ){ r_add = val; @@ -302,10 +305,12 @@ void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, No if( ( !hasPol || pol ) && n[0].getType().isInteger() ){ if( n.getKind()==EQUAL ){ for( unsigned i=0; i<2; i++ ){ - d_rel_dom_lit[hasPol][pol][n].d_val.push_back( QuantArith::offset( r_add, i==0 ? 1 : -1 ) ); + d_rel_dom_lit[hasPol][pol][n].d_val.push_back( + ArithMSum::offset(r_add, i == 0 ? 1 : -1)); } }else if( n.getKind()==GEQ ){ - d_rel_dom_lit[hasPol][pol][n].d_val.push_back( QuantArith::offset( r_add, varLhs ? 1 : -1 ) ); + d_rel_dom_lit[hasPol][pol][n].d_val.push_back( + ArithMSum::offset(r_add, varLhs ? 1 : -1)); } } }else{ diff --git a/src/theory/quantifiers/relevant_domain.h b/src/theory/quantifiers/relevant_domain.h index fbf3520c6..112530788 100644 --- a/src/theory/quantifiers/relevant_domain.h +++ b/src/theory/quantifiers/relevant_domain.h @@ -18,6 +18,7 @@ #define __CVC4__THEORY__QUANTIFIERS__RELEVANT_DOMAIN_H #include "theory/quantifiers/first_order_model.h" +#include "theory/quantifiers/quant_util.h" namespace CVC4 { namespace theory { diff --git a/src/theory/quantifiers/term_database_sygus.cpp b/src/theory/quantifiers/term_database_sygus.cpp index 45e3d7593..e212e0dfb 100644 --- a/src/theory/quantifiers/term_database_sygus.cpp +++ b/src/theory/quantifiers/term_database_sygus.cpp @@ -19,6 +19,7 @@ #include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "smt/smt_engine.h" +#include "theory/arith/arith_msum.h" #include "theory/quantifiers/ce_guided_conjecture.h" #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_util.h" @@ -1410,7 +1411,9 @@ Node TermDbSygus::minimizeBuiltinTerm( Node n ) { Node nc; if( n[r].getKind()==PLUS ){ for( unsigned i=0; i().isNegativeOne() ){ + if (ArithMSum::getMonomial(n[r][i], c, nc) + && c.getConst().isNegativeOne()) + { mon[ro].push_back( nc ); changed = true; }else{ @@ -1420,7 +1423,9 @@ Node TermDbSygus::minimizeBuiltinTerm( Node n ) { } } }else{ - if( QuantArith::getMonomial( n[r], c, nc ) && c.getConst().isNegativeOne() ){ + if (ArithMSum::getMonomial(n[r], c, nc) + && c.getConst().isNegativeOne()) + { mon[ro].push_back( nc ); changed = true; }else{ diff --git a/src/theory/quantifiers/term_util.cpp b/src/theory/quantifiers/term_util.cpp index 2183db5f1..4e75f7df8 100644 --- a/src/theory/quantifiers/term_util.cpp +++ b/src/theory/quantifiers/term_util.cpp @@ -19,6 +19,7 @@ #include "options/datatypes_options.h" #include "options/quantifiers_options.h" #include "options/uf_options.h" +#include "theory/arith/arith_msum.h" #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_enumeration.h" #include "theory/quantifiers_engine.h" @@ -675,16 +676,17 @@ Node TermUtil::rewriteVtsSymbols( Node n ) { } if( !rew_vts_inf.isNull() || rew_delta ){ std::map< Node, Node > msum; - if( QuantArith::getMonomialSumLit( n, msum ) ){ + if (ArithMSum::getMonomialSumLit(n, msum)) + { if( Trace.isOn("quant-vts-debug") ){ Trace("quant-vts-debug") << "VTS got monomial sum : " << std::endl; - QuantArith::debugPrintMonomialSum( msum, "quant-vts-debug" ); + ArithMSum::debugPrintMonomialSum(msum, "quant-vts-debug"); } Node vts_sym = !rew_vts_inf.isNull() ? rew_vts_inf : d_vts_delta; Assert( !vts_sym.isNull() ); Node iso_n; Node nlit; - int res = QuantArith::isolate( vts_sym, msum, iso_n, n.getKind(), true ); + int res = ArithMSum::isolate(vts_sym, msum, iso_n, n.getKind(), true); if( res!=0 ){ Trace("quant-vts-debug") << "VTS isolated : -> " << iso_n << ", res = " << res << std::endl; Node slv = iso_n[res==1 ? 1 : 0]; -- 2.30.2