1 /********************* */
2 /*! \file quant_util.cpp
4 ** Original author: Andrew Reynolds
5 ** Major contributors: Morgan Deters
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
12 ** \brief Implementation of quantifier utilities
15 #include "theory/quantifiers/quant_util.h"
16 #include "theory/quantifiers/inst_match.h"
17 #include "theory/quantifiers/term_database.h"
18 #include "theory/quantifiers_engine.h"
22 using namespace CVC4::kind
;
23 using namespace CVC4::context
;
24 using namespace CVC4::theory
;
26 bool QuantArith::getMonomial( Node n
, Node
& c
, Node
& v
){
27 if( n
.getKind()==MULT
&& n
.getNumChildren()==2 && n
[0].isConst() ){
35 bool QuantArith::getMonomial( Node n
, std::map
< Node
, Node
>& msum
) {
36 if( n
.getKind()==MULT
&& n
.getNumChildren()==2 && n
[0].isConst() ){
37 if( msum
.find(n
[1])==msum
.end() ){
41 }else if( n
.isConst() ){
42 if( msum
.find(Node::null())==msum
.end() ){
43 msum
[Node::null()] = n
;
47 if( msum
.find(n
)==msum
.end() ){
48 msum
[n
] = Node::null();
55 bool QuantArith::getMonomialSum( Node n
, std::map
< Node
, Node
>& msum
) {
56 if ( n
.getKind()==PLUS
){
57 for( unsigned i
=0; i
<n
.getNumChildren(); i
++) {
58 if (!getMonomial( n
[i
], msum
)){
64 return getMonomial( n
, msum
);
68 bool QuantArith::getMonomialSumLit( Node lit
, std::map
< Node
, Node
>& msum
) {
69 if( lit
.getKind()==GEQ
|| lit
.getKind()==EQUAL
){
70 if( getMonomialSum( lit
[0], msum
) ){
71 if( lit
[1].isConst() && lit
[1].getConst
<Rational
>().isZero() ){
74 //subtract the other side
75 std::map
< Node
, Node
> msum2
;
76 if( getMonomialSum( lit
[1], msum2
) ){
77 for( std::map
< Node
, Node
>::iterator it
= msum2
.begin(); it
!= msum2
.end(); ++it
){
78 std::map
< Node
, Node
>::iterator it2
= msum
.find( it
->first
);
79 if( it2
!=msum
.end() ){
80 Node r
= NodeManager::currentNM()->mkNode( MINUS
, it2
->second
.isNull() ? NodeManager::currentNM()->mkConst( Rational(1) ) : it2
->second
,
81 it
->second
.isNull() ? NodeManager::currentNM()->mkConst( Rational(1) ) : it
->second
);
82 msum
[it
->first
] = Rewriter::rewrite( r
);
84 msum
[it
->first
] = it
->second
.isNull() ? NodeManager::currentNM()->mkConst( Rational(-1) ) : negate( it
->second
);
95 Node
QuantArith::mkNode( std::map
< Node
, Node
>& msum
) {
96 std::vector
< Node
> children
;
97 for( std::map
< Node
, Node
>::iterator it
= msum
.begin(); it
!= msum
.end(); ++it
){
99 if( !it
->first
.isNull() ){
100 if( !it
->second
.isNull() ){
101 m
= NodeManager::currentNM()->mkNode( MULT
, it
->second
, it
->first
);
106 Assert( !it
->second
.isNull() );
109 children
.push_back(m
);
111 return children
.size()>1 ? NodeManager::currentNM()->mkNode( PLUS
, children
) : (children
.size()==1 ? children
[0] : NodeManager::currentNM()->mkConst( Rational(0) ));
114 // given (msum <k> 0), solve (veq_c * v <k> val) or (val <k> veq_c * v), where:
115 // veq_c is either null (meaning 1), or positive.
116 // return value 1: veq_c*v is RHS, -1: veq_c*v is LHS, 0: failed.
117 int QuantArith::isolate( Node v
, std::map
< Node
, Node
>& msum
, Node
& veq_c
, Node
& val
, Kind k
) {
118 std::map
< Node
, Node
>::iterator itv
= msum
.find( v
);
119 if( itv
!=msum
.end() ){
120 std::vector
< Node
> children
;
121 Rational r
= itv
->second
.isNull() ? Rational(1) : itv
->second
.getConst
<Rational
>();
123 for( std::map
< Node
, Node
>::iterator it
= msum
.begin(); it
!= msum
.end(); ++it
){
126 if( !it
->first
.isNull() ){
127 if ( !it
->second
.isNull() ){
128 m
= NodeManager::currentNM()->mkNode( MULT
, it
->second
, it
->first
);
135 children
.push_back(m
);
138 val
= children
.size()>1 ? NodeManager::currentNM()->mkNode( PLUS
, children
) :
139 (children
.size()==1 ? children
[0] : NodeManager::currentNM()->mkConst( Rational(0) ));
140 if( !r
.isOne() && !r
.isNegativeOne() ){
141 if( v
.getType().isInteger() ){
142 veq_c
= NodeManager::currentNM()->mkConst( r
.abs() );
144 val
= NodeManager::currentNM()->mkNode( MULT
, val
, NodeManager::currentNM()->mkConst( Rational(1) / r
.abs() ) );
150 val
= Rewriter::rewrite( val
);
152 return ( r
.sgn()==1 || k
==EQUAL
) ? 1 : -1;
158 int QuantArith::isolate( Node v
, std::map
< Node
, Node
>& msum
, Node
& veq
, Kind k
, bool doCoeff
) {
161 //isolate v in the (in)equality
162 int ires
= isolate( v
, msum
, veq_c
, val
, k
);
165 if( !veq_c
.isNull() ){
167 vc
= NodeManager::currentNM()->mkNode( MULT
, veq_c
, vc
);
172 bool inOrder
= ires
==1;
173 veq
= NodeManager::currentNM()->mkNode( k
, inOrder
? vc
: val
, inOrder
? val
: vc
);
178 Node
QuantArith::solveEqualityFor( Node lit
, Node v
) {
179 Assert( lit
.getKind()==EQUAL
|| lit
.getKind()==IFF
);
180 //first look directly at sides
181 TypeNode tn
= lit
[0].getType();
182 for( unsigned r
=0; r
<2; r
++ ){
188 if( quantifiers::TermDb::containsTerm( lit
, v
) ){
189 std::map
< Node
, Node
> msum
;
190 if( QuantArith::getMonomialSumLit( lit
, msum
) ){
192 if( QuantArith::isolate( v
, msum
, veqc
, val
, EQUAL
)!=0 ){
203 Node
QuantArith::negate( Node t
) {
204 Node tt
= NodeManager::currentNM()->mkNode( MULT
, NodeManager::currentNM()->mkConst( Rational(-1) ), t
);
205 tt
= Rewriter::rewrite( tt
);
209 Node
QuantArith::offset( Node t
, int i
) {
210 Node tt
= NodeManager::currentNM()->mkNode( PLUS
, NodeManager::currentNM()->mkConst( Rational(i
) ), t
);
211 tt
= Rewriter::rewrite( tt
);
215 void QuantArith::debugPrintMonomialSum( std::map
< Node
, Node
>& msum
, const char * c
) {
216 for(std::map
< Node
, Node
>::iterator it
= msum
.begin(); it
!= msum
.end(); ++it
){
218 if( !it
->second
.isNull() ){
219 Trace(c
) << it
->second
;
220 if( !it
->first
.isNull() ){
224 if( !it
->first
.isNull() ){
225 Trace(c
) << it
->first
;
227 Trace(c
) << std::endl
;
229 Trace(c
) << std::endl
;
233 void QuantRelevance::registerQuantifier( Node f
){
234 //compute symbols in f
235 std::vector
< Node
> syms
;
236 computeSymbols( f
[1], syms
);
237 d_syms
[f
].insert( d_syms
[f
].begin(), syms
.begin(), syms
.end() );
238 //set initial relevance
239 int minRelevance
= -1;
240 for( int i
=0; i
<(int)syms
.size(); i
++ ){
241 d_syms_quants
[ syms
[i
] ].push_back( f
);
242 int r
= getRelevance( syms
[i
] );
243 if( r
!=-1 && ( minRelevance
==-1 || r
<minRelevance
) ){
247 if( minRelevance
!=-1 ){
248 setRelevance( f
, minRelevance
+1 );
253 /** compute symbols */
254 void QuantRelevance::computeSymbols( Node n
, std::vector
< Node
>& syms
){
255 if( n
.getKind()==APPLY_UF
){
256 Node op
= n
.getOperator();
257 if( std::find( syms
.begin(), syms
.end(), op
)==syms
.end() ){
258 syms
.push_back( op
);
261 if( n
.getKind()!=FORALL
){
262 for( int i
=0; i
<(int)n
.getNumChildren(); i
++ ){
263 computeSymbols( n
[i
], syms
);
269 void QuantRelevance::setRelevance( Node s
, int r
){
271 int rOld
= getRelevance( s
);
272 if( rOld
==-1 || r
<rOld
){
274 if( s
.getKind()==FORALL
){
275 for( int i
=0; i
<(int)d_syms
[s
].size(); i
++ ){
276 setRelevance( d_syms
[s
][i
], r
);
279 for( int i
=0; i
<(int)d_syms_quants
[s
].size(); i
++ ){
280 setRelevance( d_syms_quants
[s
][i
], r
+1 );
288 QuantPhaseReq::QuantPhaseReq( Node n
, bool computeEq
){
289 initialize( n
, computeEq
);
292 void QuantPhaseReq::initialize( Node n
, bool computeEq
){
293 std::map
< Node
, int > phaseReqs2
;
294 computePhaseReqs( n
, false, phaseReqs2
);
295 for( std::map
< Node
, int >::iterator it
= phaseReqs2
.begin(); it
!= phaseReqs2
.end(); ++it
){
297 d_phase_reqs
[ it
->first
] = true;
298 }else if( it
->second
==-1 ){
299 d_phase_reqs
[ it
->first
] = false;
302 Debug("inst-engine-phase-req") << "Phase requirements for " << n
<< ":" << std::endl
;
303 //now, compute if any patterns are equality required
305 for( std::map
< Node
, bool >::iterator it
= d_phase_reqs
.begin(); it
!= d_phase_reqs
.end(); ++it
){
306 Debug("inst-engine-phase-req") << " " << it
->first
<< " -> " << it
->second
<< std::endl
;
307 if( it
->first
.getKind()==EQUAL
){
308 if( quantifiers::TermDb::hasInstConstAttr(it
->first
[0]) ){
309 if( !quantifiers::TermDb::hasInstConstAttr(it
->first
[1]) ){
310 d_phase_reqs_equality_term
[ it
->first
[0] ] = it
->first
[1];
311 d_phase_reqs_equality
[ it
->first
[0] ] = it
->second
;
312 Debug("inst-engine-phase-req") << " " << it
->first
[0] << ( it
->second
? " == " : " != " ) << it
->first
[1] << std::endl
;
314 }else if( quantifiers::TermDb::hasInstConstAttr(it
->first
[1]) ){
315 d_phase_reqs_equality_term
[ it
->first
[1] ] = it
->first
[0];
316 d_phase_reqs_equality
[ it
->first
[1] ] = it
->second
;
317 Debug("inst-engine-phase-req") << " " << it
->first
[1] << ( it
->second
? " == " : " != " ) << it
->first
[0] << std::endl
;
324 void QuantPhaseReq::computePhaseReqs( Node n
, bool polarity
, std::map
< Node
, int >& phaseReqs
){
325 bool newReqPol
= false;
327 if( n
.getKind()==NOT
){
329 newPolarity
= !polarity
;
330 }else if( n
.getKind()==OR
|| n
.getKind()==IMPLIES
){
335 }else if( n
.getKind()==AND
){
341 int val
= polarity
? 1 : -1;
342 if( phaseReqs
.find( n
)==phaseReqs
.end() ){
344 }else if( val
!=phaseReqs
[n
] ){
349 for( int i
=0; i
<(int)n
.getNumChildren(); i
++ ){
350 if( n
.getKind()==IMPLIES
&& i
==0 ){
351 computePhaseReqs( n
[i
], !newPolarity
, phaseReqs
);
353 computePhaseReqs( n
[i
], newPolarity
, phaseReqs
);
359 void QuantPhaseReq::getPolarity( Node n
, int child
, bool hasPol
, bool pol
, bool& newHasPol
, bool& newPol
) {
360 if( n
.getKind()==AND
|| n
.getKind()==OR
){
363 }else if( n
.getKind()==IMPLIES
){
365 newPol
= child
==0 ? !pol
: pol
;
366 }else if( n
.getKind()==NOT
){
369 }else if( n
.getKind()==ITE
){
370 newHasPol
= (child
!=0) && hasPol
;
372 }else if( n
.getKind()==FORALL
){
373 newHasPol
= (child
==1) && hasPol
;
381 void QuantPhaseReq::getEntailPolarity( Node n
, int child
, bool hasPol
, bool pol
, bool& newHasPol
, bool& newPol
) {
382 if( n
.getKind()==AND
|| n
.getKind()==OR
){
383 newHasPol
= hasPol
&& pol
==( n
.getKind()==AND
);
385 }else if( n
.getKind()==IMPLIES
){
386 newHasPol
= hasPol
&& !pol
;
387 newPol
= child
==0 ? !pol
: pol
;
388 }else if( n
.getKind()==NOT
){