7321e22b4ab8d427b662118e1f540a2b5ab8e3b5
[cvc5.git] / src / theory / quantifiers / macros.cpp
1 /********************* */
2 /*! \file macros.cpp
3 ** \verbatim
4 ** Original author: Andrew Reynolds
5 ** Major contributors: Morgan Deters
6 ** Minor contributors (to current version): Kshitij Bansal
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Quantifier macros module
13 **
14 ** This class implements quantifiers macro definitions.
15 **/
16
17 #include <vector>
18
19 #include "theory/quantifiers/macros.h"
20 #include "theory/rewriter.h"
21
22 using namespace CVC4;
23 using namespace std;
24 using namespace CVC4::theory;
25 using namespace CVC4::theory::quantifiers;
26 using namespace CVC4::kind;
27 using namespace CVC4::context;
28
29
30 bool QuantifierMacros::simplify( std::vector< Node >& assertions, bool doRewrite ){
31 Trace("macros") << "Find macros..." << std::endl;
32 //first, collect macro definitions
33 for( size_t i=0; i<assertions.size(); i++ ){
34 processAssertion( assertions[i] );
35 }
36 bool retVal = false;
37 if( doRewrite && !d_macro_defs.empty() ){
38 //now, rewrite based on macro definitions
39 for( size_t i=0; i<assertions.size(); i++ ){
40 Node prev = assertions[i];
41 assertions[i] = simplify( assertions[i] );
42 if( prev!=assertions[i] ){
43 assertions[i] = Rewriter::rewrite( assertions[i] );
44 Trace("macros-rewrite") << "Rewrite " << prev << " to " << assertions[i] << std::endl;
45 retVal = true;
46 }
47 }
48 }
49 return retVal;
50 }
51
52 void QuantifierMacros::processAssertion( Node n ) {
53 if( n.getKind()==AND ){
54 for( unsigned i=0; i<n.getNumChildren(); i++ ){
55 processAssertion( n[i] );
56 }
57 }else if( n.getKind()==FORALL ){
58 std::vector< Node > args;
59 for( size_t j=0; j<n[0].getNumChildren(); j++ ){
60 args.push_back( n[0][j] );
61 }
62 //look at the body of the quantifier for macro definition
63 process( n[1], true, args, n );
64 }
65 }
66
67 bool QuantifierMacros::contains( Node n, Node n_s ){
68 if( n==n_s ){
69 return true;
70 }else{
71 for( size_t i=0; i<n.getNumChildren(); i++ ){
72 if( contains( n[i], n_s ) ){
73 return true;
74 }
75 }
76 return false;
77 }
78 }
79
80 bool QuantifierMacros::containsBadOp( Node n, Node op ){
81 if( n.getKind()==APPLY_UF ){
82 Node nop = n.getOperator();
83 if( nop==op || d_macro_defs.find( nop )!=d_macro_defs.end() ){
84 return true;
85 }
86 }
87 for( size_t i=0; i<n.getNumChildren(); i++ ){
88 if( containsBadOp( n[i], op ) ){
89 return true;
90 }
91 }
92 return false;
93 }
94
95 bool QuantifierMacros::isMacroLiteral( Node n, bool pol ){
96 return pol && ( n.getKind()==EQUAL || n.getKind()==IFF );
97 }
98
99 bool QuantifierMacros::isBoundVarApplyUf( Node n ) {
100 Assert( n.getKind()==APPLY_UF );
101 TypeNode tn = n.getOperator().getType();
102 for( unsigned i=0; i<n.getNumChildren(); i++ ){
103 if( n[i].getKind()!=BOUND_VARIABLE ){
104 return false;
105 }
106 if( n[i].getType()!=tn[i] ){
107 return false;
108 }
109 for( unsigned j=0; j<i; j++ ){
110 if( n[j]==n[i] ){
111 return false;
112 }
113 }
114 }
115 return true;
116 }
117
118 void QuantifierMacros::getMacroCandidates( Node n, std::vector< Node >& candidates ){
119 if( n.getKind()==APPLY_UF ){
120 if( isBoundVarApplyUf( n ) ){
121 candidates.push_back( n );
122 }
123 }else if( n.getKind()==PLUS ){
124 for( size_t i=0; i<n.getNumChildren(); i++ ){
125 getMacroCandidates( n[i], candidates );
126 }
127 }else if( n.getKind()==MULT ){
128 //if the LHS is a constant
129 if( n.getNumChildren()==2 && n[0].isConst() ){
130 getMacroCandidates( n[1], candidates );
131 }
132 }
133 }
134
135 Node QuantifierMacros::solveInEquality( Node n, Node lit ){
136 if( lit.getKind()==IFF || lit.getKind()==EQUAL ){
137 //return the opposite side of the equality if defined that way
138 for( int i=0; i<2; i++ ){
139 if( lit[i]==n ){
140 return lit[ i==0 ? 1 : 0];
141 }
142 }
143 //must solve for term n in the literal lit
144 if( lit[0].getType().isInteger() || lit[0].getType().isReal() ){
145 Node coeff;
146 Node term;
147 //could be solved for on LHS
148 if( lit[0].getKind()==MULT && lit[0][1]==n ){
149 Assert( lit[0][0].isConst() );
150 term = lit[1];
151 coeff = lit[0][0];
152 }else{
153 Assert( lit[1].getKind()==PLUS );
154 std::vector< Node > plus_children;
155 //find monomial with n
156 for( size_t j=0; j<lit[1].getNumChildren(); j++ ){
157 if( lit[1][j]==n ){
158 Assert( coeff.isNull() );
159 coeff = NodeManager::currentNM()->mkConst( Rational(1) );
160 }else if( lit[1][j].getKind()==MULT && lit[1][j][1]==n ){
161 Assert( coeff.isNull() );
162 Assert( lit[1][j][0].isConst() );
163 coeff = lit[1][j][0];
164 }else{
165 plus_children.push_back( lit[1][j] );
166 }
167 }
168 if( !coeff.isNull() ){
169 term = NodeManager::currentNM()->mkNode( PLUS, plus_children );
170 term = NodeManager::currentNM()->mkNode( MINUS, lit[0], term );
171 }
172 }
173 if( !coeff.isNull() ){
174 coeff = NodeManager::currentNM()->mkConst( Rational(1) / coeff.getConst<Rational>() );
175 term = NodeManager::currentNM()->mkNode( MULT, coeff, term );
176 term = Rewriter::rewrite( term );
177 return term;
178 }
179 }
180 }
181 Trace("macros-debug") << "Cannot find for " << lit << " " << n << std::endl;
182 return Node::null();
183 }
184
185 bool QuantifierMacros::getFreeVariables( Node n, std::vector< Node >& v_quant, std::vector< Node >& vars, bool retOnly ){
186 if( std::find( v_quant.begin(), v_quant.end(), n )!=v_quant.end() ){
187 if( std::find( vars.begin(), vars.end(), n )==vars.end() ){
188 if( retOnly ){
189 return true;
190 }else{
191 vars.push_back( n );
192 }
193 }
194 }
195 for( size_t i=0; i<n.getNumChildren(); i++ ){
196 if( getFreeVariables( n[i], v_quant, vars, retOnly ) ){
197 return true;
198 }
199 }
200 return false;
201 }
202
203 bool QuantifierMacros::getSubstitution( std::vector< Node >& v_quant, std::map< Node, Node >& solved,
204 std::vector< Node >& vars, std::vector< Node >& subs, bool reqComplete ){
205 bool success = true;
206 for( size_t a=0; a<v_quant.size(); a++ ){
207 if( !solved[ v_quant[a] ].isNull() ){
208 vars.push_back( v_quant[a] );
209 subs.push_back( solved[ v_quant[a] ] );
210 }else{
211 if( reqComplete ){
212 success = false;
213 break;
214 }
215 }
216 }
217 return success;
218 }
219
220 void QuantifierMacros::process( Node n, bool pol, std::vector< Node >& args, Node f ){
221 if( n.getKind()==NOT ){
222 process( n[0], !pol, args, f );
223 }else if( n.getKind()==AND || n.getKind()==OR ){
224 //bool favorPol = (n.getKind()==AND)==pol;
225 //conditional?
226 }else if( n.getKind()==ITE ){
227 //can not do anything
228 }else if( n.getKind()==APPLY_UF ){
229 //predicate case
230 if( isBoundVarApplyUf( n ) ){
231 Node n_def = NodeManager::currentNM()->mkConst( pol );
232 Trace("macros-quant") << "Macro found for " << f << std::endl;
233 Trace("macros") << "* " << n_def << " is a macro for " << n.getOperator() << std::endl;
234 d_macro_defs[ n.getOperator() ] = n_def;
235 }
236 }else{
237 //literal case
238 if( isMacroLiteral( n, pol ) ){
239 std::vector< Node > candidates;
240 for( size_t i=0; i<n.getNumChildren(); i++ ){
241 getMacroCandidates( n[i], candidates );
242 }
243 for( size_t i=0; i<candidates.size(); i++ ){
244 Node m = candidates[i];
245 Node op = m.getOperator();
246 if( d_macro_defs.find( op )==d_macro_defs.end() ){
247 std::vector< Node > fvs;
248 getFreeVariables( m, args, fvs, false );
249 //get definition and condition
250 Node n_def = solveInEquality( m, n ); //definition for the macro
251 //definition must exist and not contain any free variables apart from fvs
252 if( !n_def.isNull() && !getFreeVariables( n_def, args, fvs, true ) && !containsBadOp( n_def, op ) ){
253 Node n_cond; //condition when this definition holds
254 //conditional must not contain any free variables apart from fvs
255 if( n_cond.isNull() || !getFreeVariables( n_cond, args, fvs, true ) ){
256 Trace("macros-debug") << m << " is possible macro in " << f << std::endl;
257 //now we must rewrite candidates[i] to a term of form g( x1, ..., xn ) where
258 // x1 ... xn are distinct variables
259 if( d_macro_basis[op].empty() ){
260 for( size_t a=0; a<m.getNumChildren(); a++ ){
261 std::stringstream ss;
262 ss << "mda_" << op << "";
263 Node v = NodeManager::currentNM()->mkSkolem( ss.str(), m[a].getType(), "created during macro definition recognition" );
264 d_macro_basis[op].push_back( v );
265 }
266 }
267 std::map< Node, Node > solved;
268 for( size_t a=0; a<m.getNumChildren(); a++ ){
269 solved[m[a]] = d_macro_basis[op][a];
270 }
271 std::vector< Node > vars;
272 std::vector< Node > subs;
273 if( getSubstitution( fvs, solved, vars, subs, true ) ){
274 n_def = n_def.substitute( vars.begin(), vars.end(), subs.begin(), subs.end() );
275 Trace("macros-quant") << "Macro found for " << f << std::endl;
276 Trace("macros") << "* " << n_def << " is a macro for " << op << std::endl;
277 d_macro_defs[op] = n_def;
278 return;
279 }
280 }
281 }
282 }
283 }
284 }
285 }
286 }
287
288 Node QuantifierMacros::simplify( Node n ){
289 Trace("macros-debug") << "simplify " << n << std::endl;
290 std::vector< Node > children;
291 bool childChanged = false;
292 for( size_t i=0; i<n.getNumChildren(); i++ ){
293 Node nn = simplify( n[i] );
294 children.push_back( nn );
295 childChanged = childChanged || nn!=n[i];
296 }
297 if( n.getKind()==APPLY_UF ){
298 Node op = n.getOperator();
299 if( d_macro_defs.find( op )!=d_macro_defs.end() && !d_macro_defs[op].isNull() ){
300 //do substitution if necessary
301 std::map< Node, std::vector< Node > >::iterator it = d_macro_basis.find( op );
302 Node ret = d_macro_defs[op];
303 if( it!=d_macro_basis.end() ){
304 ret = ret.substitute( it->second.begin(), it->second.end(), children.begin(), children.end() );
305 }
306 return ret;
307 }
308 }
309 if( childChanged ){
310 if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
311 children.insert( children.begin(), n.getOperator() );
312 }
313 return NodeManager::currentNM()->mkNode( n.getKind(), children );
314 }else{
315 return n;
316 }
317 }