Refactoring Options Handler & Library Cycle Breaking
[cvc5.git] / src / theory / sort_inference.cpp
1 /********************* */
2 /*! \file sort_inference.cpp
3 ** \verbatim
4 ** Original author: Andrew Reynolds
5 ** Major contributors: Morgan Deters
6 ** Minor contributors (to current version): Kshitij Bansal
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Sort inference module
13 **
14 ** This class implements sort inference, based on a simple algorithm:
15 ** First, we assume all functions and predicates have distinct uninterpreted types.
16 ** One pass is made through the input assertions, while a union-find data structure
17 ** maintains necessary information regarding constraints on these types.
18 **/
19
#include "theory/sort_inference.h"

#include <algorithm>
#include <sstream>
#include <vector>

#include "options/quantifiers_options.h"
#include "options/smt_options.h"
#include "options/uf_options.h"
#include "proof/proof_manager.h"
#include "theory/rewriter.h"
29
30 using namespace std;
31
32 namespace CVC4 {
33
34 void SortInference::UnionFind::print(const char * c){
35 for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
36 Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
37 }
38 for( unsigned i=0; i<d_deq.size(); i++ ){
39 Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
40 }
41 Trace(c) << std::endl;
42 }
43
44 void SortInference::UnionFind::set( UnionFind& c ) {
45 clear();
46 for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
47 d_eqc[ it->first ] = it->second;
48 }
49 d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
50 }
51
52 int SortInference::UnionFind::getRepresentative( int t ){
53 std::map< int, int >::iterator it = d_eqc.find( t );
54 if( it==d_eqc.end() || it->second==t ){
55 return t;
56 }else{
57 int rt = getRepresentative( it->second );
58 d_eqc[t] = rt;
59 return rt;
60 }
61 }
62
63 void SortInference::UnionFind::setEqual( int t1, int t2 ){
64 if( t1!=t2 ){
65 int rt1 = getRepresentative( t1 );
66 int rt2 = getRepresentative( t2 );
67 if( rt1>rt2 ){
68 d_eqc[rt1] = rt2;
69 }else{
70 d_eqc[rt2] = rt1;
71 }
72 }
73 }
74
75 bool SortInference::UnionFind::isValid() {
76 for( unsigned i=0; i<d_deq.size(); i++ ){
77 if( areEqual( d_deq[i].first, d_deq[i].second ) ){
78 return false;
79 }
80 }
81 return true;
82 }
83
84
85 void SortInference::recordSubsort( TypeNode tn, int s ){
86 s = d_type_union_find.getRepresentative( s );
87 if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){
88 d_sub_sorts.push_back( s );
89 d_type_sub_sorts[tn].push_back( s );
90 }
91 }
92
93 void SortInference::printSort( const char* c, int t ){
94 int rt = d_type_union_find.getRepresentative( t );
95 if( d_type_types.find( rt )!=d_type_types.end() ){
96 Trace(c) << d_type_types[rt];
97 }else{
98 Trace(c) << "s_" << rt;
99 }
100 }
101
102 void SortInference::reset() {
103 d_sub_sorts.clear();
104 d_non_monotonic_sorts.clear();
105 d_type_sub_sorts.clear();
106 //reset info
107 sortCount = 1;
108 d_type_union_find.clear();
109 d_type_types.clear();
110 d_id_for_types.clear();
111 d_op_return_types.clear();
112 d_op_arg_types.clear();
113 d_var_types.clear();
114 //for rewriting
115 d_symbol_map.clear();
116 d_const_map.clear();
117 }
118
119 bool SortInference::simplify( std::vector< Node >& assertions ){
120 Trace("sort-inference") << "Calculating sort inference..." << std::endl;
121 //process all assertions
122 for( unsigned i=0; i<assertions.size(); i++ ){
123 Trace("sort-inference-debug") << "Process " << assertions[i] << std::endl;
124 std::map< Node, Node > var_bound;
125 process( assertions[i], var_bound );
126 }
127 for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
128 Trace("sort-inference") << it->first << " : ";
129 TypeNode retTn = it->first.getType();
130 if( !d_op_arg_types[ it->first ].empty() ){
131 Trace("sort-inference") << "( ";
132 for( size_t i=0; i<d_op_arg_types[ it->first ].size(); i++ ){
133 recordSubsort( retTn[i], d_op_arg_types[ it->first ][i] );
134 printSort( "sort-inference", d_op_arg_types[ it->first ][i] );
135 Trace("sort-inference") << " ";
136 }
137 Trace("sort-inference") << ") -> ";
138 retTn = retTn[(int)retTn.getNumChildren()-1];
139 }
140 recordSubsort( retTn, it->second );
141 printSort( "sort-inference", it->second );
142 Trace("sort-inference") << std::endl;
143 }
144 for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){
145 Trace("sort-inference") << "Quantified formula : " << it->first << " : " << std::endl;
146 for( unsigned i=0; i<it->first[0].getNumChildren(); i++ ){
147 recordSubsort( it->first[0][i].getType(), it->second[it->first[0][i]] );
148 printSort( "sort-inference", it->second[it->first[0][i]] );
149 Trace("sort-inference") << std::endl;
150 }
151 Trace("sort-inference") << std::endl;
152 }
153
154 if( !options::ufssSymBreak() ){
155 bool rewritten = false;
156 //determine monotonicity of sorts
157 for( unsigned i=0; i<assertions.size(); i++ ){
158 Trace("sort-inference-debug") << "Process monotonicity for " << assertions[i] << std::endl;
159 std::map< Node, Node > var_bound;
160 processMonotonic( assertions[i], true, true, var_bound );
161 }
162
163 Trace("sort-inference") << "We have " << d_sub_sorts.size() << " sub-sorts : " << std::endl;
164 for( unsigned i=0; i<d_sub_sorts.size(); i++ ){
165 printSort( "sort-inference", d_sub_sorts[i] );
166 if( d_type_types.find( d_sub_sorts[i] )!=d_type_types.end() ){
167 Trace("sort-inference") << " is interpreted." << std::endl;
168 }else if( d_non_monotonic_sorts.find( d_sub_sorts[i] )==d_non_monotonic_sorts.end() ){
169 Trace("sort-inference") << " is monotonic." << std::endl;
170 }else{
171 Trace("sort-inference") << " is not monotonic." << std::endl;
172 }
173 }
174
175 //simplify all assertions by introducing new symbols wherever necessary
176 for( unsigned i=0; i<assertions.size(); i++ ){
177 Node prev = assertions[i];
178 std::map< Node, Node > var_bound;
179 Trace("sort-inference-debug") << "Rewrite " << assertions[i] << std::endl;
180 Node curr = simplify( assertions[i], var_bound );
181 Trace("sort-inference-debug") << "Done." << std::endl;
182 if( curr!=assertions[i] ){
183 curr = theory::Rewriter::rewrite( curr );
184 rewritten = true;
185 Trace("sort-inference-rewrite") << assertions << std::endl;
186 Trace("sort-inference-rewrite") << " --> " << curr << std::endl;
187 PROOF( ProofManager::currentPM()->addDependence(curr, assertions[i]); );
188 assertions[i] = curr;
189 }
190 }
191 //now, ensure constants are distinct
192 for( std::map< TypeNode, std::map< Node, Node > >::iterator it = d_const_map.begin(); it != d_const_map.end(); ++it ){
193 std::vector< Node > consts;
194 for( std::map< Node, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
195 consts.push_back( it2->second );
196 }
197 //TODO: add lemma enforcing introduced constants to be distinct
198 }
199
200 //enforce constraints based on monotonicity
201 for( std::map< TypeNode, std::vector< int > >::iterator it = d_type_sub_sorts.begin(); it != d_type_sub_sorts.end(); ++it ){
202 int nmonSort = -1;
203 for( unsigned i=0; i<it->second.size(); i++ ){
204 if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
205 nmonSort = it->second[i];
206 break;
207 }
208 }
209 if( nmonSort!=-1 ){
210 std::vector< Node > injections;
211 TypeNode base_tn = getOrCreateTypeForId( nmonSort, it->first );
212 for( unsigned i=0; i<it->second.size(); i++ ){
213 if( it->second[i]!=nmonSort ){
214 TypeNode new_tn = getOrCreateTypeForId( it->second[i], it->first );
215 //make injection to nmonSort
216 Node a1 = mkInjection( new_tn, base_tn );
217 injections.push_back( a1 );
218 if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
219 //also must make injection from nmonSort to this
220 Node a2 = mkInjection( base_tn, new_tn );
221 injections.push_back( a2 );
222 }
223 }
224 }
225 Trace("sort-inference-rewrite") << "Add the following injections for " << it->first << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
226 for( unsigned j=0; j<injections.size(); j++ ){
227 Trace("sort-inference-rewrite") << " " << injections[j] << std::endl;
228 }
229 assertions.insert( assertions.end(), injections.begin(), injections.end() );
230 if( !injections.empty() ){
231 rewritten = true;
232 }
233 }
234 }
235 //no sub-sort information is stored
236 reset();
237 return rewritten;
238 }
239 /*
240 else if( !options::ufssSymBreak() ){
241 //just add the unit lemmas between constants
242 std::map< TypeNode, std::map< int, Node > > constants;
243 for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
244 int rt = d_type_union_find.getRepresentative( it->second );
245 if( d_op_arg_types[ it->first ].empty() ){
246 TypeNode tn = it->first.getType();
247 if( constants[ tn ].find( rt )==constants[ tn ].end() ){
248 constants[ tn ][ rt ] = it->first;
249 }
250 }
251 }
252 //add unit lemmas for each constant
253 for( std::map< TypeNode, std::map< int, Node > >::iterator it = constants.begin(); it != constants.end(); ++it ){
254 Node first_const;
255 for( std::map< int, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
256 if( first_const.isNull() ){
257 first_const = it2->second;
258 }else{
259 Node eq = first_const.eqNode( it2->second );
260 //eq = Rewriter::rewrite( eq );
261 Trace("sort-inference-lemma") << "Sort inference lemma : " << eq << std::endl;
262 assertions.push_back( eq );
263 }
264 }
265 }
266 }
267 */
268 initialSortCount = sortCount;
269 return false;
270 }
271
272 void SortInference::setEqual( int t1, int t2 ){
273 if( t1!=t2 ){
274 int rt1 = d_type_union_find.getRepresentative( t1 );
275 int rt2 = d_type_union_find.getRepresentative( t2 );
276 if( rt1!=rt2 ){
277 Trace("sort-inference-debug") << "Set equal : ";
278 printSort( "sort-inference-debug", rt1 );
279 Trace("sort-inference-debug") << " ";
280 printSort( "sort-inference-debug", rt2 );
281 Trace("sort-inference-debug") << std::endl;
282 /*
283 d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
284 d_type_eq_class[rt2].clear();
285 Trace("sort-inference-debug") << "EqClass : { ";
286 for( int i=0; i<(int)d_type_eq_class[rt1].size(); i++ ){
287 Trace("sort-inference-debug") << d_type_eq_class[rt1][i] << ", ";
288 }
289 Trace("sort-inference-debug") << "}" << std::endl;
290 */
291 if( rt2>rt1 ){
292 //swap
293 int swap = rt1;
294 rt1 = rt2;
295 rt2 = swap;
296 }
297 std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
298 if( it1!=d_type_types.end() ){
299 if( d_type_types.find( rt2 )==d_type_types.end() ){
300 d_type_types[rt2] = it1->second;
301 d_type_types.erase( rt1 );
302 }else{
303 Trace("sort-inference-debug") << "...fail : associated with types " << d_type_types[rt1] << " and " << d_type_types[rt2] << std::endl;
304 return;
305 }
306 }
307 d_type_union_find.d_eqc[rt1] = rt2;
308 }
309 }
310 }
311
312 int SortInference::getIdForType( TypeNode tn ){
313 //register the return type
314 std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
315 if( it==d_id_for_types.end() ){
316 int sc = sortCount;
317 d_type_types[ sortCount ] = tn;
318 d_id_for_types[ tn ] = sortCount;
319 sortCount++;
320 return sc;
321 }else{
322 return it->second;
323 }
324 }
325
326 int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
327 //add to variable bindings
328 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
329 if( d_var_types.find( n )!=d_var_types.end() ){
330 return getIdForType( n.getType() );
331 }else{
332 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
333 //apply sort inference to quantified variables
334 d_var_types[n][ n[0][i] ] = sortCount;
335 sortCount++;
336
337 //type of the quantified variable must be the same
338 var_bound[ n[0][i] ] = n;
339 }
340 }
341 }
342
343 //process children
344 std::vector< Node > children;
345 std::vector< int > child_types;
346 for( size_t i=0; i<n.getNumChildren(); i++ ){
347 bool processChild = true;
348 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
349 processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
350 }
351 if( processChild ){
352 children.push_back( n[i] );
353 child_types.push_back( process( n[i], var_bound ) );
354 }
355 }
356
357 //remove from variable bindings
358 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
359 //erase from variable bound
360 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
361 var_bound.erase( n[0][i] );
362 }
363 }
364 Trace("sort-inference-debug") << "...Process " << n << std::endl;
365
366 int retType;
367 if( n.getKind()==kind::EQUAL ){
368 Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
369 //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
370 if( n[0].getType()!=n[1].getType() ){
371 //for now, assume the original types
372 for( unsigned i=0; i<2; i++ ){
373 int ct = getIdForType( n[i].getType() );
374 setEqual( child_types[i], ct );
375 }
376 }else{
377 //we only require that the left and right hand side must be equal
378 setEqual( child_types[0], child_types[1] );
379 }
380 //int eqType = getIdForType( n[0].getType() );
381 //setEqual( child_types[0], eqType );
382 //setEqual( child_types[1], eqType );
383 retType = getIdForType( n.getType() );
384 }else if( n.getKind()==kind::APPLY_UF ){
385 Node op = n.getOperator();
386 TypeNode tn_op = op.getType();
387 if( d_op_return_types.find( op )==d_op_return_types.end() ){
388 if( n.getType().isBoolean() ){
389 //use booleans
390 d_op_return_types[op] = getIdForType( n.getType() );
391 }else{
392 //assign arbitrary sort for return type
393 d_op_return_types[op] = sortCount;
394 sortCount++;
395 }
396 //d_type_eq_class[sortCount].push_back( op );
397 //assign arbitrary sort for argument types
398 for( size_t i=0; i<n.getNumChildren(); i++ ){
399 d_op_arg_types[op].push_back( sortCount );
400 sortCount++;
401 }
402 }
403 for( size_t i=0; i<n.getNumChildren(); i++ ){
404 //the argument of the operator must match the return type of the subterm
405 if( n[i].getType()!=tn_op[i] ){
406 //if type mismatch, assume original types
407 Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
408 Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
409 int ct1 = getIdForType( n[i].getType() );
410 setEqual( child_types[i], ct1 );
411 int ct2 = getIdForType( tn_op[i] );
412 setEqual( d_op_arg_types[op][i], ct2 );
413 }else{
414 setEqual( child_types[i], d_op_arg_types[op][i] );
415 }
416 }
417 //return type is the return type
418 retType = d_op_return_types[op];
419 }else{
420 std::map< Node, Node >::iterator it = var_bound.find( n );
421 if( it!=var_bound.end() ){
422 Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
423 //the return type was specified while binding
424 retType = d_var_types[it->second][n];
425 }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
426 Trace("sort-inference-debug") << n << " is a variable." << std::endl;
427 if( d_op_return_types.find( n )==d_op_return_types.end() ){
428 //assign arbitrary sort
429 d_op_return_types[n] = sortCount;
430 sortCount++;
431 //d_type_eq_class[sortCount].push_back( n );
432 }
433 retType = d_op_return_types[n];
434 //}else if( n.isConst() ){
435 // Trace("sort-inference-debug") << n << " is a constant." << std::endl;
436 //can be any type we want
437 // retType = sortCount;
438 // sortCount++;
439 }else{
440 Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
441 //it is an interpretted term
442 for( size_t i=0; i<children.size(); i++ ){
443 Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
444 //must enforce the actual type of the operator on the children
445 int ct = getIdForType( children[i].getType() );
446 setEqual( child_types[i], ct );
447 }
448 //return type must be the actual return type
449 retType = getIdForType( n.getType() );
450 }
451 }
452 Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
453 printSort("sort-inference-debug", retType );
454 Trace("sort-inference-debug") << std::endl;
455 return retType;
456 }
457
458 void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound ) {
459 Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
460 if( n.getKind()==kind::FORALL ){
461 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
462 var_bound[n[0][i]] = n;
463 }
464 processMonotonic( n[1], pol, hasPol, var_bound );
465 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
466 var_bound.erase( n[0][i] );
467 }
468 }else if( n.getKind()==kind::EQUAL ){
469 if( !hasPol || pol ){
470 for( unsigned i=0; i<2; i++ ){
471 if( var_bound.find( n[i] )!=var_bound.end() ){
472 int sid = getSortId( var_bound[n[i]], n[i] );
473 d_non_monotonic_sorts[sid] = true;
474 break;
475 }
476 }
477 }
478 }
479 for( unsigned i=0; i<n.getNumChildren(); i++ ){
480 bool npol = pol;
481 bool nhasPol = hasPol;
482 if( n.getKind()==kind::NOT || ( n.getKind()==kind::IMPLIES && i==0 ) ){
483 npol = !npol;
484 }
485 if( ( n.getKind()==kind::ITE && i==0 ) || n.getKind()==kind::XOR || n.getKind()==kind::IFF ){
486 nhasPol = false;
487 }
488 processMonotonic( n[i], npol, nhasPol, var_bound );
489 }
490 }
491
492
493 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
494 int rt = d_type_union_find.getRepresentative( t );
495 if( d_type_types.find( rt )!=d_type_types.end() ){
496 return d_type_types[rt];
497 }else{
498 TypeNode retType;
499 //see if we can assign pref
500 if( !pref.isNull() && d_id_for_types.find( pref )==d_id_for_types.end() ){
501 retType = pref;
502 }else{
503 //must create new type
504 std::stringstream ss;
505 ss << "it_" << t << "_" << pref;
506 retType = NodeManager::currentNM()->mkSort( ss.str() );
507 }
508 Trace("sort-inference") << "-> Make type " << retType << " to correspond to ";
509 printSort("sort-inference", t );
510 Trace("sort-inference") << std::endl;
511 d_id_for_types[ retType ] = rt;
512 d_type_types[ rt ] = retType;
513 return retType;
514 }
515 }
516
517 TypeNode SortInference::getTypeForId( int t ){
518 int rt = d_type_union_find.getRepresentative( t );
519 if( d_type_types.find( rt )!=d_type_types.end() ){
520 return d_type_types[rt];
521 }else{
522 return TypeNode::null();
523 }
524 }
525
526 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
527 if( tn==old.getType() ){
528 return old;
529 }else if( old.isConst() ){
530 //must make constant of type tn
531 if( d_const_map[tn].find( old )==d_const_map[tn].end() ){
532 std::stringstream ss;
533 ss << "ic_" << tn << "_" << old;
534 d_const_map[tn][ old ] = NodeManager::currentNM()->mkSkolem( ss.str(), tn, "constant created during sort inference" ); //use mkConst???
535 }
536 return d_const_map[tn][ old ];
537 }else if( old.getKind()==kind::BOUND_VARIABLE ){
538 std::stringstream ss;
539 ss << "b_" << old;
540 return NodeManager::currentNM()->mkBoundVar( ss.str(), tn );
541 }else{
542 std::stringstream ss;
543 ss << "i_" << old;
544 return NodeManager::currentNM()->mkSkolem( ss.str(), tn, "created during sort inference" );
545 }
546 }
547
548 Node SortInference::simplify( Node n, std::map< Node, Node >& var_bound ){
549 Trace("sort-inference-debug2") << "Simplify " << n << std::endl;
550 std::vector< Node > children;
551 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
552 //recreate based on types of variables
553 std::vector< Node > new_children;
554 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
555 TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
556 Node v = getNewSymbol( n[0][i], tn );
557 Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
558 new_children.push_back( v );
559 var_bound[ n[0][i] ] = v;
560 }
561 children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
562 }
563
564 //process children
565 if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
566 children.push_back( n.getOperator() );
567 }
568 bool childChanged = false;
569 for( size_t i=0; i<n.getNumChildren(); i++ ){
570 bool processChild = true;
571 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
572 processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
573 }
574 if( processChild ){
575 Node nc = simplify( n[i], var_bound );
576 Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
577 children.push_back( nc );
578 childChanged = childChanged || nc!=n[i];
579 }
580 }
581
582 //remove from variable bindings
583 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
584 //erase from variable bound
585 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
586 Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
587 var_bound.erase( n[0][i] );
588 }
589 return NodeManager::currentNM()->mkNode( n.getKind(), children );
590 }else if( n.getKind()==kind::EQUAL ){
591 TypeNode tn1 = children[0].getType();
592 TypeNode tn2 = children[1].getType();
593 if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
594 if( children[0].isConst() ){
595 children[0] = getNewSymbol( children[0], children[1].getType() );
596 }else if( children[1].isConst() ){
597 children[1] = getNewSymbol( children[1], children[0].getType() );
598 }else{
599 Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
600 Trace("sort-inference-warn") << " Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
601 Assert( false );
602 }
603 }
604 return NodeManager::currentNM()->mkNode( kind::EQUAL, children );
605 }else if( n.getKind()==kind::APPLY_UF ){
606 Node op = n.getOperator();
607 if( d_symbol_map.find( op )==d_symbol_map.end() ){
608 //make the new operator if necessary
609 bool opChanged = false;
610 std::vector< TypeNode > argTypes;
611 for( size_t i=0; i<n.getNumChildren(); i++ ){
612 TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
613 argTypes.push_back( tn );
614 if( tn!=n[i].getType() ){
615 opChanged = true;
616 }
617 }
618 TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
619 if( retType!=n.getType() ){
620 opChanged = true;
621 }
622 if( opChanged ){
623 std::stringstream ss;
624 ss << "io_" << op;
625 TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
626 d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
627 Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
628 d_model_replace_f[op] = d_symbol_map[op];
629 }else{
630 d_symbol_map[op] = op;
631 }
632 }
633 children[0] = d_symbol_map[op];
634 //make sure all children have been taken care of
635 for( size_t i=0; i<n.getNumChildren(); i++ ){
636 TypeNode tn = children[i+1].getType();
637 TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
638 if( tn!=tna ){
639 if( n[i].isConst() ){
640 children[i+1] = getNewSymbol( n[i], tna );
641 }else{
642 Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
643 Assert( false );
644 }
645 }
646 }
647 return NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
648 }else{
649 std::map< Node, Node >::iterator it = var_bound.find( n );
650 if( it!=var_bound.end() ){
651 return it->second;
652 }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
653 if( d_symbol_map.find( n )==d_symbol_map.end() ){
654 TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
655 d_symbol_map[n] = getNewSymbol( n, tn );
656 }
657 return d_symbol_map[n];
658 }else if( n.isConst() ){
659 //just return n, we will fix at higher scope
660 return n;
661 }else{
662 if( childChanged ){
663 return NodeManager::currentNM()->mkNode( n.getKind(), children );
664 }else{
665 return n;
666 }
667 }
668 }
669
670 }
671
672 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
673 std::vector< TypeNode > tns;
674 tns.push_back( tn1 );
675 TypeNode typ = NodeManager::currentNM()->mkFunctionType( tns, tn2 );
676 Node f = NodeManager::currentNM()->mkSkolem( "inj", typ, "injection for monotonicity constraint" );
677 Trace("sort-inference") << "-> Make injection " << f << " from " << tn1 << " to " << tn2 << std::endl;
678 Node v1 = NodeManager::currentNM()->mkBoundVar( "?x", tn1 );
679 Node v2 = NodeManager::currentNM()->mkBoundVar( "?y", tn1 );
680 Node ret = NodeManager::currentNM()->mkNode( kind::FORALL,
681 NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, v1, v2 ),
682 NodeManager::currentNM()->mkNode( kind::OR,
683 NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v1 ).eqNode( NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v2 ) ).negate(),
684 v1.eqNode( v2 ) ) );
685 ret = theory::Rewriter::rewrite( ret );
686 return ret;
687 }
688
689 int SortInference::getSortId( Node n ) {
690 Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
691 if( d_op_return_types.find( op )!=d_op_return_types.end() ){
692 return d_type_union_find.getRepresentative( d_op_return_types[op] );
693 }else{
694 return 0;
695 }
696 }
697
698 int SortInference::getSortId( Node f, Node v ) {
699 if( d_var_types.find( f )!=d_var_types.end() ){
700 return d_type_union_find.getRepresentative( d_var_types[f][v] );
701 }else{
702 return 0;
703 }
704 }
705
706 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
707 Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
708 if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
709 //calculate the sort for variables if not done so already
710 std::map< Node, Node > var_bound;
711 process( f, var_bound );
712 }
713 d_op_return_types[sk] = getSortId( f, v );
714 Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
715 }
716
717 bool SortInference::isWellSortedFormula( Node n ) {
718 if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
719 for( unsigned i=0; i<n.getNumChildren(); i++ ){
720 if( !isWellSortedFormula( n[i] ) ){
721 return false;
722 }
723 }
724 return true;
725 }else{
726 return isWellSorted( n );
727 }
728 }
729
730 bool SortInference::isWellSorted( Node n ) {
731 if( getSortId( n )==0 ){
732 return false;
733 }else{
734 if( n.getKind()==kind::APPLY_UF ){
735 for( unsigned i=0; i<n.getNumChildren(); i++ ){
736 int s1 = getSortId( n[i] );
737 int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
738 if( s1!=s2 ){
739 return false;
740 }
741 if( !isWellSorted( n[i] ) ){
742 return false;
743 }
744 }
745 }
746 return true;
747 }
748 }
749
750 void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
751 if( n.getKind()==kind::APPLY_UF ){
752 for( unsigned i=0; i<n.getNumChildren(); i++ ){
753 getSortConstraints( n[i], uf );
754 uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
755 }
756 }
757 }
758
759 }/* CVC4 namespace */