Updates to E-matching to avoid entailed instantiations earlier. Minor updates to...
[cvc5.git] / src / theory / quantifiers / quant_util.cpp
1 /********************* */
2 /*! \file quant_util.cpp
3 ** \verbatim
4 ** Original author: Andrew Reynolds
5 ** Major contributors: Morgan Deters
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Implementation of quantifier utilities
13 **/
14
15 #include "theory/quantifiers/quant_util.h"
16 #include "theory/quantifiers/inst_match.h"
17 #include "theory/quantifiers/term_database.h"
18 #include "theory/quantifiers_engine.h"
19
20 using namespace std;
21 using namespace CVC4;
22 using namespace CVC4::kind;
23 using namespace CVC4::context;
24 using namespace CVC4::theory;
25
26 bool QuantArith::getMonomial( Node n, Node& c, Node& v ){
27 if( n.getKind()==MULT && n.getNumChildren()==2 && n[0].isConst() ){
28 c = n[0];
29 v = n[1];
30 return true;
31 }else{
32 return false;
33 }
34 }
35 bool QuantArith::getMonomial( Node n, std::map< Node, Node >& msum ) {
36 if( n.getKind()==MULT && n.getNumChildren()==2 && n[0].isConst() ){
37 if( msum.find(n[1])==msum.end() ){
38 msum[n[1]] = n[0];
39 return true;
40 }
41 }else if( n.isConst() ){
42 if( msum.find(Node::null())==msum.end() ){
43 msum[Node::null()] = n;
44 return true;
45 }
46 }else{
47 if( msum.find(n)==msum.end() ){
48 msum[n] = Node::null();
49 return true;
50 }
51 }
52 return false;
53 }
54
55 bool QuantArith::getMonomialSum( Node n, std::map< Node, Node >& msum ) {
56 if ( n.getKind()==PLUS ){
57 for( unsigned i=0; i<n.getNumChildren(); i++) {
58 if (!getMonomial( n[i], msum )){
59 return false;
60 }
61 }
62 return true;
63 }else{
64 return getMonomial( n, msum );
65 }
66 }
67
68 bool QuantArith::getMonomialSumLit( Node lit, std::map< Node, Node >& msum ) {
69 if( lit.getKind()==GEQ || lit.getKind()==EQUAL ){
70 if( getMonomialSum( lit[0], msum ) ){
71 if( lit[1].isConst() && lit[1].getConst<Rational>().isZero() ){
72 return true;
73 }else{
74 //subtract the other side
75 std::map< Node, Node > msum2;
76 if( getMonomialSum( lit[1], msum2 ) ){
77 for( std::map< Node, Node >::iterator it = msum2.begin(); it != msum2.end(); ++it ){
78 std::map< Node, Node >::iterator it2 = msum.find( it->first );
79 if( it2!=msum.end() ){
80 Node r = NodeManager::currentNM()->mkNode( MINUS, it2->second.isNull() ? NodeManager::currentNM()->mkConst( Rational(1) ) : it2->second,
81 it->second.isNull() ? NodeManager::currentNM()->mkConst( Rational(1) ) : it->second );
82 msum[it->first] = Rewriter::rewrite( r );
83 }else{
84 msum[it->first] = it->second.isNull() ? NodeManager::currentNM()->mkConst( Rational(-1) ) : negate( it->second );
85 }
86 }
87 return true;
88 }
89 }
90 }
91 }
92 return false;
93 }
94
95 Node QuantArith::mkNode( std::map< Node, Node >& msum ) {
96 std::vector< Node > children;
97 for( std::map< Node, Node >::iterator it = msum.begin(); it != msum.end(); ++it ){
98 Node m;
99 if( !it->first.isNull() ){
100 if( !it->second.isNull() ){
101 m = NodeManager::currentNM()->mkNode( MULT, it->second, it->first );
102 }else{
103 m = it->first;
104 }
105 }else{
106 Assert( !it->second.isNull() );
107 m = it->second;
108 }
109 children.push_back(m);
110 }
111 return children.size()>1 ? NodeManager::currentNM()->mkNode( PLUS, children ) : (children.size()==1 ? children[0] : NodeManager::currentNM()->mkConst( Rational(0) ));
112 }
113
114 // given (msum <k> 0), solve (veq_c * v <k> val) or (val <k> veq_c * v), where:
115 // veq_c is either null (meaning 1), or positive.
116 // return value 1: veq_c*v is RHS, -1: veq_c*v is LHS, 0: failed.
117 int QuantArith::isolate( Node v, std::map< Node, Node >& msum, Node & veq_c, Node & val, Kind k ) {
118 std::map< Node, Node >::iterator itv = msum.find( v );
119 if( itv!=msum.end() ){
120 std::vector< Node > children;
121 Rational r = itv->second.isNull() ? Rational(1) : itv->second.getConst<Rational>();
122 if ( r.sgn()!=0 ){
123 for( std::map< Node, Node >::iterator it = msum.begin(); it != msum.end(); ++it ){
124 if( it->first!=v ){
125 Node m;
126 if( !it->first.isNull() ){
127 if ( !it->second.isNull() ){
128 m = NodeManager::currentNM()->mkNode( MULT, it->second, it->first );
129 }else{
130 m = it->first;
131 }
132 }else{
133 m = it->second;
134 }
135 children.push_back(m);
136 }
137 }
138 val = children.size()>1 ? NodeManager::currentNM()->mkNode( PLUS, children ) :
139 (children.size()==1 ? children[0] : NodeManager::currentNM()->mkConst( Rational(0) ));
140 if( !r.isOne() && !r.isNegativeOne() ){
141 if( v.getType().isInteger() ){
142 veq_c = NodeManager::currentNM()->mkConst( r.abs() );
143 }else{
144 val = NodeManager::currentNM()->mkNode( MULT, val, NodeManager::currentNM()->mkConst( Rational(1) / r.abs() ) );
145 }
146 }
147 if( r.sgn()==1 ){
148 val = negate(val);
149 }else{
150 val = Rewriter::rewrite( val );
151 }
152 return ( r.sgn()==1 || k==EQUAL ) ? 1 : -1;
153 }
154 }
155 return 0;
156 }
157
158 int QuantArith::isolate( Node v, std::map< Node, Node >& msum, Node & veq, Kind k, bool doCoeff ) {
159 Node veq_c;
160 Node val;
161 //isolate v in the (in)equality
162 int ires = isolate( v, msum, veq_c, val, k );
163 if( ires!=0 ){
164 Node vc = v;
165 if( !veq_c.isNull() ){
166 if( doCoeff ){
167 vc = NodeManager::currentNM()->mkNode( MULT, veq_c, vc );
168 }else{
169 return 0;
170 }
171 }
172 bool inOrder = ires==1;
173 veq = NodeManager::currentNM()->mkNode( k, inOrder ? vc : val, inOrder ? val : vc );
174 }
175 return ires;
176 }
177
178 Node QuantArith::solveEqualityFor( Node lit, Node v ) {
179 Assert( lit.getKind()==EQUAL || lit.getKind()==IFF );
180 //first look directly at sides
181 TypeNode tn = lit[0].getType();
182 for( unsigned r=0; r<2; r++ ){
183 if( lit[r]==v ){
184 return lit[1-r];
185 }
186 }
187 if( tn.isReal() ){
188 if( quantifiers::TermDb::containsTerm( lit, v ) ){
189 std::map< Node, Node > msum;
190 if( QuantArith::getMonomialSumLit( lit, msum ) ){
191 Node val, veqc;
192 if( QuantArith::isolate( v, msum, veqc, val, EQUAL )!=0 ){
193 if( veqc.isNull() ){
194 return val;
195 }
196 }
197 }
198 }
199 }
200 return Node::null();
201 }
202
203 Node QuantArith::negate( Node t ) {
204 Node tt = NodeManager::currentNM()->mkNode( MULT, NodeManager::currentNM()->mkConst( Rational(-1) ), t );
205 tt = Rewriter::rewrite( tt );
206 return tt;
207 }
208
209 Node QuantArith::offset( Node t, int i ) {
210 Node tt = NodeManager::currentNM()->mkNode( PLUS, NodeManager::currentNM()->mkConst( Rational(i) ), t );
211 tt = Rewriter::rewrite( tt );
212 return tt;
213 }
214
215 void QuantArith::debugPrintMonomialSum( std::map< Node, Node >& msum, const char * c ) {
216 for(std::map< Node, Node >::iterator it = msum.begin(); it != msum.end(); ++it ){
217 Trace(c) << " ";
218 if( !it->second.isNull() ){
219 Trace(c) << it->second;
220 if( !it->first.isNull() ){
221 Trace(c) << " * ";
222 }
223 }
224 if( !it->first.isNull() ){
225 Trace(c) << it->first;
226 }
227 Trace(c) << std::endl;
228 }
229 Trace(c) << std::endl;
230 }
231
232
233 void QuantRelevance::registerQuantifier( Node f ){
234 //compute symbols in f
235 std::vector< Node > syms;
236 computeSymbols( f[1], syms );
237 d_syms[f].insert( d_syms[f].begin(), syms.begin(), syms.end() );
238 //set initial relevance
239 int minRelevance = -1;
240 for( int i=0; i<(int)syms.size(); i++ ){
241 d_syms_quants[ syms[i] ].push_back( f );
242 int r = getRelevance( syms[i] );
243 if( r!=-1 && ( minRelevance==-1 || r<minRelevance ) ){
244 minRelevance = r;
245 }
246 }
247 if( minRelevance!=-1 ){
248 setRelevance( f, minRelevance+1 );
249 }
250 }
251
252
253 /** compute symbols */
254 void QuantRelevance::computeSymbols( Node n, std::vector< Node >& syms ){
255 if( n.getKind()==APPLY_UF ){
256 Node op = n.getOperator();
257 if( std::find( syms.begin(), syms.end(), op )==syms.end() ){
258 syms.push_back( op );
259 }
260 }
261 if( n.getKind()!=FORALL ){
262 for( int i=0; i<(int)n.getNumChildren(); i++ ){
263 computeSymbols( n[i], syms );
264 }
265 }
266 }
267
268 /** set relevance */
269 void QuantRelevance::setRelevance( Node s, int r ){
270 if( d_computeRel ){
271 int rOld = getRelevance( s );
272 if( rOld==-1 || r<rOld ){
273 d_relevance[s] = r;
274 if( s.getKind()==FORALL ){
275 for( int i=0; i<(int)d_syms[s].size(); i++ ){
276 setRelevance( d_syms[s][i], r );
277 }
278 }else{
279 for( int i=0; i<(int)d_syms_quants[s].size(); i++ ){
280 setRelevance( d_syms_quants[s][i], r+1 );
281 }
282 }
283 }
284 }
285 }
286
287
288 QuantPhaseReq::QuantPhaseReq( Node n, bool computeEq ){
289 initialize( n, computeEq );
290 }
291
292 void QuantPhaseReq::initialize( Node n, bool computeEq ){
293 std::map< Node, int > phaseReqs2;
294 computePhaseReqs( n, false, phaseReqs2 );
295 for( std::map< Node, int >::iterator it = phaseReqs2.begin(); it != phaseReqs2.end(); ++it ){
296 if( it->second==1 ){
297 d_phase_reqs[ it->first ] = true;
298 }else if( it->second==-1 ){
299 d_phase_reqs[ it->first ] = false;
300 }
301 }
302 Debug("inst-engine-phase-req") << "Phase requirements for " << n << ":" << std::endl;
303 //now, compute if any patterns are equality required
304 if( computeEq ){
305 for( std::map< Node, bool >::iterator it = d_phase_reqs.begin(); it != d_phase_reqs.end(); ++it ){
306 Debug("inst-engine-phase-req") << " " << it->first << " -> " << it->second << std::endl;
307 if( it->first.getKind()==EQUAL ){
308 if( quantifiers::TermDb::hasInstConstAttr(it->first[0]) ){
309 if( !quantifiers::TermDb::hasInstConstAttr(it->first[1]) ){
310 d_phase_reqs_equality_term[ it->first[0] ] = it->first[1];
311 d_phase_reqs_equality[ it->first[0] ] = it->second;
312 Debug("inst-engine-phase-req") << " " << it->first[0] << ( it->second ? " == " : " != " ) << it->first[1] << std::endl;
313 }
314 }else if( quantifiers::TermDb::hasInstConstAttr(it->first[1]) ){
315 d_phase_reqs_equality_term[ it->first[1] ] = it->first[0];
316 d_phase_reqs_equality[ it->first[1] ] = it->second;
317 Debug("inst-engine-phase-req") << " " << it->first[1] << ( it->second ? " == " : " != " ) << it->first[0] << std::endl;
318 }
319 }
320 }
321 }
322 }
323
324 void QuantPhaseReq::computePhaseReqs( Node n, bool polarity, std::map< Node, int >& phaseReqs ){
325 bool newReqPol = false;
326 bool newPolarity;
327 if( n.getKind()==NOT ){
328 newReqPol = true;
329 newPolarity = !polarity;
330 }else if( n.getKind()==OR || n.getKind()==IMPLIES ){
331 if( !polarity ){
332 newReqPol = true;
333 newPolarity = false;
334 }
335 }else if( n.getKind()==AND ){
336 if( polarity ){
337 newReqPol = true;
338 newPolarity = true;
339 }
340 }else{
341 int val = polarity ? 1 : -1;
342 if( phaseReqs.find( n )==phaseReqs.end() ){
343 phaseReqs[n] = val;
344 }else if( val!=phaseReqs[n] ){
345 phaseReqs[n] = 0;
346 }
347 }
348 if( newReqPol ){
349 for( int i=0; i<(int)n.getNumChildren(); i++ ){
350 if( n.getKind()==IMPLIES && i==0 ){
351 computePhaseReqs( n[i], !newPolarity, phaseReqs );
352 }else{
353 computePhaseReqs( n[i], newPolarity, phaseReqs );
354 }
355 }
356 }
357 }
358
359 void QuantPhaseReq::getPolarity( Node n, int child, bool hasPol, bool pol, bool& newHasPol, bool& newPol ) {
360 if( n.getKind()==AND || n.getKind()==OR ){
361 newHasPol = hasPol;
362 newPol = pol;
363 }else if( n.getKind()==IMPLIES ){
364 newHasPol = hasPol;
365 newPol = child==0 ? !pol : pol;
366 }else if( n.getKind()==NOT ){
367 newHasPol = hasPol;
368 newPol = !pol;
369 }else if( n.getKind()==ITE ){
370 newHasPol = (child!=0) && hasPol;
371 newPol = pol;
372 }else if( n.getKind()==FORALL ){
373 newHasPol = (child==1) && hasPol;
374 newPol = pol;
375 }else{
376 newHasPol = false;
377 newPol = pol;
378 }
379 }
380
381 void QuantPhaseReq::getEntailPolarity( Node n, int child, bool hasPol, bool pol, bool& newHasPol, bool& newPol ) {
382 if( n.getKind()==AND || n.getKind()==OR ){
383 newHasPol = hasPol && pol==( n.getKind()==AND );
384 newPol = pol;
385 }else if( n.getKind()==IMPLIES ){
386 newHasPol = hasPol && !pol;
387 newPol = child==0 ? !pol : pol;
388 }else if( n.getKind()==NOT ){
389 newHasPol = hasPol;
390 newPol = !pol;
391 }else{
392 newHasPol = false;
393 newPol = pol;
394 }
395 }
396