simplify mkSkolem naming system: don't use $$
[cvc5.git] / src / util / sort_inference.cpp
1 /********************* */
2 /*! \file sort_inference.cpp
3 ** \verbatim
4 ** Original author: Andrew Reynolds
5 ** Major contributors: Morgan Deters
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2013 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Sort inference module
13 **
14 ** This class implements sort inference, based on a simple algorithm:
15 ** First, we assume all functions and predicates have distinct uninterpreted types.
16 ** One pass is made through the input assertions, while a union-find data structure
17 ** maintains necessary information regarding constraints on these types.
18 **/
19
#include <algorithm>
#include <sstream>
#include <vector>

#include "util/sort_inference.h"
#include "theory/uf/options.h"
#include "smt/options.h"
//#include "theory/rewriter.h"
26
27 using namespace CVC4;
28 using namespace std;
29
30 namespace CVC4 {
31
32 void SortInference::UnionFind::print(const char * c){
33 for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
34 Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
35 }
36 for( unsigned i=0; i<d_deq.size(); i++ ){
37 Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
38 }
39 Trace(c) << std::endl;
40 }
41 void SortInference::UnionFind::set( UnionFind& c ) {
42 clear();
43 for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
44 d_eqc[ it->first ] = it->second;
45 }
46 d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
47 }
48 int SortInference::UnionFind::getRepresentative( int t ){
49 std::map< int, int >::iterator it = d_eqc.find( t );
50 if( it==d_eqc.end() || it->second==t ){
51 return t;
52 }else{
53 int rt = getRepresentative( it->second );
54 d_eqc[t] = rt;
55 return rt;
56 }
57 }
58 void SortInference::UnionFind::setEqual( int t1, int t2 ){
59 if( t1!=t2 ){
60 int rt1 = getRepresentative( t1 );
61 int rt2 = getRepresentative( t2 );
62 if( rt1>rt2 ){
63 d_eqc[rt1] = rt2;
64 }else{
65 d_eqc[rt2] = rt1;
66 }
67 }
68 }
69 bool SortInference::UnionFind::isValid() {
70 for( unsigned i=0; i<d_deq.size(); i++ ){
71 if( areEqual( d_deq[i].first, d_deq[i].second ) ){
72 return false;
73 }
74 }
75 return true;
76 }
77
78
79 void SortInference::recordSubsort( TypeNode tn, int s ){
80 s = d_type_union_find.getRepresentative( s );
81 if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){
82 d_sub_sorts.push_back( s );
83 d_type_sub_sorts[tn].push_back( s );
84 }
85 }
86
87 void SortInference::printSort( const char* c, int t ){
88 int rt = d_type_union_find.getRepresentative( t );
89 if( d_type_types.find( rt )!=d_type_types.end() ){
90 Trace(c) << d_type_types[rt];
91 }else{
92 Trace(c) << "s_" << rt;
93 }
94 }
95
96 void SortInference::reset() {
97 d_sub_sorts.clear();
98 d_non_monotonic_sorts.clear();
99 d_type_sub_sorts.clear();
100 //reset info
101 sortCount = 1;
102 d_type_union_find.clear();
103 d_type_types.clear();
104 d_id_for_types.clear();
105 d_op_return_types.clear();
106 d_op_arg_types.clear();
107 d_var_types.clear();
108 //for rewriting
109 d_symbol_map.clear();
110 d_const_map.clear();
111 }
112
113 bool SortInference::simplify( std::vector< Node >& assertions ){
114 Trace("sort-inference") << "Calculating sort inference..." << std::endl;
115 //process all assertions
116 for( unsigned i=0; i<assertions.size(); i++ ){
117 Trace("sort-inference-debug") << "Process " << assertions[i] << std::endl;
118 std::map< Node, Node > var_bound;
119 process( assertions[i], var_bound );
120 }
121 for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
122 Trace("sort-inference") << it->first << " : ";
123 TypeNode retTn = it->first.getType();
124 if( !d_op_arg_types[ it->first ].empty() ){
125 Trace("sort-inference") << "( ";
126 for( size_t i=0; i<d_op_arg_types[ it->first ].size(); i++ ){
127 recordSubsort( retTn[i], d_op_arg_types[ it->first ][i] );
128 printSort( "sort-inference", d_op_arg_types[ it->first ][i] );
129 Trace("sort-inference") << " ";
130 }
131 Trace("sort-inference") << ") -> ";
132 retTn = retTn[(int)retTn.getNumChildren()-1];
133 }
134 recordSubsort( retTn, it->second );
135 printSort( "sort-inference", it->second );
136 Trace("sort-inference") << std::endl;
137 }
138 for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){
139 Trace("sort-inference") << "Quantified formula : " << it->first << " : " << std::endl;
140 for( unsigned i=0; i<it->first[0].getNumChildren(); i++ ){
141 recordSubsort( it->first[0][i].getType(), it->second[it->first[0][i]] );
142 printSort( "sort-inference", it->second[it->first[0][i]] );
143 Trace("sort-inference") << std::endl;
144 }
145 Trace("sort-inference") << std::endl;
146 }
147
148 if( !options::ufssSymBreak() ){
149 bool rewritten = false;
150 //determine monotonicity of sorts
151 for( unsigned i=0; i<assertions.size(); i++ ){
152 Trace("sort-inference-debug") << "Process monotonicity for " << assertions[i] << std::endl;
153 std::map< Node, Node > var_bound;
154 processMonotonic( assertions[i], true, true, var_bound );
155 }
156
157 Trace("sort-inference") << "We have " << d_sub_sorts.size() << " sub-sorts : " << std::endl;
158 for( unsigned i=0; i<d_sub_sorts.size(); i++ ){
159 printSort( "sort-inference", d_sub_sorts[i] );
160 if( d_type_types.find( d_sub_sorts[i] )!=d_type_types.end() ){
161 Trace("sort-inference") << " is interpreted." << std::endl;
162 }else if( d_non_monotonic_sorts.find( d_sub_sorts[i] )==d_non_monotonic_sorts.end() ){
163 Trace("sort-inference") << " is monotonic." << std::endl;
164 }else{
165 Trace("sort-inference") << " is not monotonic." << std::endl;
166 }
167 }
168
169 //simplify all assertions by introducing new symbols wherever necessary
170 for( unsigned i=0; i<assertions.size(); i++ ){
171 Node prev = assertions[i];
172 std::map< Node, Node > var_bound;
173 assertions[i] = simplify( assertions[i], var_bound );
174 if( prev!=assertions[i] ){
175 rewritten = true;
176 Trace("sort-inference-rewrite") << prev << std::endl;
177 Trace("sort-inference-rewrite") << " --> " << assertions[i] << std::endl;
178 }
179 }
180 //now, ensure constants are distinct
181 for( std::map< TypeNode, std::map< Node, Node > >::iterator it = d_const_map.begin(); it != d_const_map.end(); ++it ){
182 std::vector< Node > consts;
183 for( std::map< Node, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
184 consts.push_back( it2->second );
185 }
186 //TODO: add lemma enforcing introduced constants to be distinct
187 }
188
189 //enforce constraints based on monotonicity
190 for( std::map< TypeNode, std::vector< int > >::iterator it = d_type_sub_sorts.begin(); it != d_type_sub_sorts.end(); ++it ){
191 int nmonSort = -1;
192 for( unsigned i=0; i<it->second.size(); i++ ){
193 if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
194 nmonSort = it->second[i];
195 break;
196 }
197 }
198 if( nmonSort!=-1 ){
199 std::vector< Node > injections;
200 TypeNode base_tn = getOrCreateTypeForId( nmonSort, it->first );
201 for( unsigned i=0; i<it->second.size(); i++ ){
202 if( it->second[i]!=nmonSort ){
203 TypeNode new_tn = getOrCreateTypeForId( it->second[i], it->first );
204 //make injection to nmonSort
205 Node a1 = mkInjection( new_tn, base_tn );
206 injections.push_back( a1 );
207 if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
208 //also must make injection from nmonSort to this
209 Node a2 = mkInjection( base_tn, new_tn );
210 injections.push_back( a2 );
211 }
212 }
213 }
214 Trace("sort-inference-rewrite") << "Add the following injections for " << it->first << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
215 for( unsigned j=0; j<injections.size(); j++ ){
216 Trace("sort-inference-rewrite") << " " << injections[j] << std::endl;
217 }
218 assertions.insert( assertions.end(), injections.begin(), injections.end() );
219 if( !injections.empty() ){
220 rewritten = true;
221 }
222 }
223 }
224 //no sub-sort information is stored
225 reset();
226 return rewritten;
227 }
228 /*
229 else if( !options::ufssSymBreak() ){
230 //just add the unit lemmas between constants
231 std::map< TypeNode, std::map< int, Node > > constants;
232 for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
233 int rt = d_type_union_find.getRepresentative( it->second );
234 if( d_op_arg_types[ it->first ].empty() ){
235 TypeNode tn = it->first.getType();
236 if( constants[ tn ].find( rt )==constants[ tn ].end() ){
237 constants[ tn ][ rt ] = it->first;
238 }
239 }
240 }
241 //add unit lemmas for each constant
242 for( std::map< TypeNode, std::map< int, Node > >::iterator it = constants.begin(); it != constants.end(); ++it ){
243 Node first_const;
244 for( std::map< int, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
245 if( first_const.isNull() ){
246 first_const = it2->second;
247 }else{
248 Node eq = first_const.eqNode( it2->second );
249 //eq = Rewriter::rewrite( eq );
250 Trace("sort-inference-lemma") << "Sort inference lemma : " << eq << std::endl;
251 assertions.push_back( eq );
252 }
253 }
254 }
255 }
256 */
257 initialSortCount = sortCount;
258 return false;
259 }
260
261 void SortInference::setEqual( int t1, int t2 ){
262 if( t1!=t2 ){
263 int rt1 = d_type_union_find.getRepresentative( t1 );
264 int rt2 = d_type_union_find.getRepresentative( t2 );
265 if( rt1!=rt2 ){
266 Trace("sort-inference-debug") << "Set equal : ";
267 printSort( "sort-inference-debug", rt1 );
268 Trace("sort-inference-debug") << " ";
269 printSort( "sort-inference-debug", rt2 );
270 Trace("sort-inference-debug") << std::endl;
271 /*
272 d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
273 d_type_eq_class[rt2].clear();
274 Trace("sort-inference-debug") << "EqClass : { ";
275 for( int i=0; i<(int)d_type_eq_class[rt1].size(); i++ ){
276 Trace("sort-inference-debug") << d_type_eq_class[rt1][i] << ", ";
277 }
278 Trace("sort-inference-debug") << "}" << std::endl;
279 */
280 if( rt2>rt1 ){
281 //swap
282 int swap = rt1;
283 rt1 = rt2;
284 rt2 = swap;
285 }
286 d_type_union_find.d_eqc[rt1] = rt2;
287 std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
288 if( it1!=d_type_types.end() ){
289 Assert( d_type_types.find( rt2 )==d_type_types.end() );
290 d_type_types[rt2] = it1->second;
291 d_type_types.erase( rt1 );
292 }
293 }
294 }
295 }
296
297 int SortInference::getIdForType( TypeNode tn ){
298 //register the return type
299 std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
300 if( it==d_id_for_types.end() ){
301 int sc = sortCount;
302 d_type_types[ sortCount ] = tn;
303 d_id_for_types[ tn ] = sortCount;
304 sortCount++;
305 return sc;
306 }else{
307 return it->second;
308 }
309 }
310
311 int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
312 Trace("sort-inference-debug") << "...Process " << n << std::endl;
313 //add to variable bindings
314 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
315 if( d_var_types.find( n )!=d_var_types.end() ){
316 return getIdForType( n.getType() );
317 }else{
318 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
319 //apply sort inference to quantified variables
320 d_var_types[n][ n[0][i] ] = sortCount;
321 sortCount++;
322
323 //type of the quantified variable must be the same
324 var_bound[ n[0][i] ] = n;
325 }
326 }
327 }
328
329 //process children
330 std::vector< Node > children;
331 std::vector< int > child_types;
332 for( size_t i=0; i<n.getNumChildren(); i++ ){
333 bool processChild = true;
334 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
335 processChild = i==1;
336 }
337 if( processChild ){
338 children.push_back( n[i] );
339 child_types.push_back( process( n[i], var_bound ) );
340 }
341 }
342
343 //remove from variable bindings
344 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
345 //erase from variable bound
346 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
347 var_bound.erase( n[0][i] );
348 }
349 }
350
351 int retType;
352 if( n.getKind()==kind::EQUAL ){
353 //we only require that the left and right hand side must be equal
354 setEqual( child_types[0], child_types[1] );
355 //int eqType = getIdForType( n[0].getType() );
356 //setEqual( child_types[0], eqType );
357 //setEqual( child_types[1], eqType );
358 retType = getIdForType( n.getType() );
359 }else if( n.getKind()==kind::APPLY_UF ){
360 Node op = n.getOperator();
361 if( d_op_return_types.find( op )==d_op_return_types.end() ){
362 if( n.getType().isBoolean() ){
363 //use booleans
364 d_op_return_types[op] = getIdForType( n.getType() );
365 }else{
366 //assign arbitrary sort for return type
367 d_op_return_types[op] = sortCount;
368 sortCount++;
369 }
370 //d_type_eq_class[sortCount].push_back( op );
371 //assign arbitrary sort for argument types
372 for( size_t i=0; i<n.getNumChildren(); i++ ){
373 d_op_arg_types[op].push_back( sortCount );
374 sortCount++;
375 }
376 }
377 for( size_t i=0; i<n.getNumChildren(); i++ ){
378 //the argument of the operator must match the return type of the subterm
379 setEqual( child_types[i], d_op_arg_types[op][i] );
380 }
381 //return type is the return type
382 retType = d_op_return_types[op];
383 }else{
384 std::map< Node, Node >::iterator it = var_bound.find( n );
385 if( it!=var_bound.end() ){
386 Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
387 //the return type was specified while binding
388 retType = d_var_types[it->second][n];
389 }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
390 Trace("sort-inference-debug") << n << " is a variable." << std::endl;
391 if( d_op_return_types.find( n )==d_op_return_types.end() ){
392 //assign arbitrary sort
393 d_op_return_types[n] = sortCount;
394 sortCount++;
395 //d_type_eq_class[sortCount].push_back( n );
396 }
397 retType = d_op_return_types[n];
398 //}else if( n.isConst() ){
399 // Trace("sort-inference-debug") << n << " is a constant." << std::endl;
400 //can be any type we want
401 // retType = sortCount;
402 // sortCount++;
403 }else{
404 Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
405 //it is an interpretted term
406 for( size_t i=0; i<children.size(); i++ ){
407 Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
408 //must enforce the actual type of the operator on the children
409 int ct = getIdForType( children[i].getType() );
410 setEqual( child_types[i], ct );
411 }
412 //return type must be the actual return type
413 retType = getIdForType( n.getType() );
414 }
415 }
416 Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
417 printSort("sort-inference-debug", retType );
418 Trace("sort-inference-debug") << std::endl;
419 return retType;
420 }
421
422 void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound ) {
423 Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
424 if( n.getKind()==kind::FORALL ){
425 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
426 var_bound[n[0][i]] = n;
427 }
428 processMonotonic( n[1], pol, hasPol, var_bound );
429 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
430 var_bound.erase( n[0][i] );
431 }
432 }else if( n.getKind()==kind::EQUAL ){
433 if( !hasPol || pol ){
434 for( unsigned i=0; i<2; i++ ){
435 if( var_bound.find( n[i] )!=var_bound.end() ){
436 int sid = getSortId( var_bound[n[i]], n[i] );
437 d_non_monotonic_sorts[sid] = true;
438 break;
439 }
440 }
441 }
442 }
443 for( unsigned i=0; i<n.getNumChildren(); i++ ){
444 bool npol = pol;
445 bool nhasPol = hasPol;
446 if( n.getKind()==kind::NOT || ( n.getKind()==kind::IMPLIES && i==0 ) ){
447 npol = !npol;
448 }
449 if( ( n.getKind()==kind::ITE && i==0 ) || n.getKind()==kind::XOR || n.getKind()==kind::IFF ){
450 nhasPol = false;
451 }
452 processMonotonic( n[i], npol, nhasPol, var_bound );
453 }
454 }
455
456
457 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
458 int rt = d_type_union_find.getRepresentative( t );
459 if( d_type_types.find( rt )!=d_type_types.end() ){
460 return d_type_types[rt];
461 }else{
462 TypeNode retType;
463 //see if we can assign pref
464 if( !pref.isNull() && d_id_for_types.find( pref )==d_id_for_types.end() ){
465 retType = pref;
466 }else{
467 //must create new type
468 std::stringstream ss;
469 ss << "it_" << t << "_" << pref;
470 retType = NodeManager::currentNM()->mkSort( ss.str() );
471 }
472 Trace("sort-inference") << "-> Make type " << retType << " to correspond to ";
473 printSort("sort-inference", t );
474 Trace("sort-inference") << std::endl;
475 d_id_for_types[ retType ] = rt;
476 d_type_types[ rt ] = retType;
477 return retType;
478 }
479 }
480
481 TypeNode SortInference::getTypeForId( int t ){
482 int rt = d_type_union_find.getRepresentative( t );
483 if( d_type_types.find( rt )!=d_type_types.end() ){
484 return d_type_types[rt];
485 }else{
486 return TypeNode::null();
487 }
488 }
489
490 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
491 if( tn==old.getType() ){
492 return old;
493 }else if( old.isConst() ){
494 //must make constant of type tn
495 if( d_const_map[tn].find( old )==d_const_map[tn].end() ){
496 std::stringstream ss;
497 ss << "ic_" << tn << "_" << old;
498 d_const_map[tn][ old ] = NodeManager::currentNM()->mkSkolem( ss.str(), tn, "constant created during sort inference" ); //use mkConst???
499 }
500 return d_const_map[tn][ old ];
501 }else if( old.getKind()==kind::BOUND_VARIABLE ){
502 std::stringstream ss;
503 ss << "b_" << old;
504 return NodeManager::currentNM()->mkBoundVar( ss.str(), tn );
505 }else{
506 std::stringstream ss;
507 ss << "i_" << old;
508 return NodeManager::currentNM()->mkSkolem( ss.str(), tn, "created during sort inference" );
509 }
510 }
511
512 Node SortInference::simplify( Node n, std::map< Node, Node >& var_bound ){
513 std::vector< Node > children;
514 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
515 //recreate based on types of variables
516 std::vector< Node > new_children;
517 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
518 TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
519 Node v = getNewSymbol( n[0][i], tn );
520 new_children.push_back( v );
521 var_bound[ n[0][i] ] = v;
522 }
523 children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
524 }
525
526 //process children
527 if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
528 children.push_back( n.getOperator() );
529 }
530 for( size_t i=0; i<n.getNumChildren(); i++ ){
531 bool processChild = true;
532 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
533 processChild = i>=1;
534 }
535 if( processChild ){
536 children.push_back( simplify( n[i], var_bound ) );
537 }
538 }
539
540 //remove from variable bindings
541 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
542 //erase from variable bound
543 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
544 var_bound.erase( n[0][i] );
545 }
546 return NodeManager::currentNM()->mkNode( n.getKind(), children );
547 }else if( n.getKind()==kind::EQUAL ){
548 if( children[0].getType()!=children[1].getType() ){
549 if( children[0].isConst() ){
550 children[0] = getNewSymbol( children[0], children[1].getType() );
551 }else if( children[1].isConst() ){
552 children[1] = getNewSymbol( children[1], children[0].getType() );
553 }else{
554 Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
555 Trace("sort-inference-warn") << " Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
556 Assert( false );
557 }
558 }
559 return NodeManager::currentNM()->mkNode( kind::EQUAL, children );
560 }else if( n.getKind()==kind::APPLY_UF ){
561 Node op = n.getOperator();
562 if( d_symbol_map.find( op )==d_symbol_map.end() ){
563 //make the new operator if necessary
564 bool opChanged = false;
565 std::vector< TypeNode > argTypes;
566 for( size_t i=0; i<n.getNumChildren(); i++ ){
567 TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
568 argTypes.push_back( tn );
569 if( tn!=n[i].getType() ){
570 opChanged = true;
571 }
572 }
573 TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
574 if( retType!=n.getType() ){
575 opChanged = true;
576 }
577 if( opChanged ){
578 std::stringstream ss;
579 ss << "io_" << op;
580 TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
581 d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
582 }else{
583 d_symbol_map[op] = op;
584 }
585 }
586 children[0] = d_symbol_map[op];
587 //make sure all children have been taken care of
588 for( size_t i=0; i<n.getNumChildren(); i++ ){
589 TypeNode tn = children[i+1].getType();
590 TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
591 if( tn!=tna ){
592 if( n[i].isConst() ){
593 children[i+1] = getNewSymbol( n[i], tna );
594 }else{
595 Trace("sort-inference-warn") << "Sort inference created bad child: " << n[i] << " " << tn << " " << tna << std::endl;
596 Assert( false );
597 }
598 }
599 }
600 return NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
601 }else{
602 std::map< Node, Node >::iterator it = var_bound.find( n );
603 if( it!=var_bound.end() ){
604 return it->second;
605 }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
606 if( d_symbol_map.find( n )==d_symbol_map.end() ){
607 TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
608 d_symbol_map[n] = getNewSymbol( n, tn );
609 }
610 return d_symbol_map[n];
611 }else if( n.isConst() ){
612 //just return n, we will fix at higher scope
613 return n;
614 }else{
615 return NodeManager::currentNM()->mkNode( n.getKind(), children );
616 }
617 }
618
619 }
620
621 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
622 std::vector< TypeNode > tns;
623 tns.push_back( tn1 );
624 TypeNode typ = NodeManager::currentNM()->mkFunctionType( tns, tn2 );
625 Node f = NodeManager::currentNM()->mkSkolem( "inj", typ, "injection for monotonicity constraint" );
626 Trace("sort-inference") << "-> Make injection " << f << " from " << tn1 << " to " << tn2 << std::endl;
627 Node v1 = NodeManager::currentNM()->mkBoundVar( "?x", tn1 );
628 Node v2 = NodeManager::currentNM()->mkBoundVar( "?y", tn1 );
629 return NodeManager::currentNM()->mkNode( kind::FORALL,
630 NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, v1, v2 ),
631 NodeManager::currentNM()->mkNode( kind::IMPLIES,
632 NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v1 ).eqNode( NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v2 ) ),
633 v1.eqNode( v2 ) ) );
634 }
635
636 int SortInference::getSortId( Node n ) {
637 Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
638 if( d_op_return_types.find( op )!=d_op_return_types.end() ){
639 return d_type_union_find.getRepresentative( d_op_return_types[op] );
640 }else{
641 return 0;
642 }
643 }
644
645 int SortInference::getSortId( Node f, Node v ) {
646 if( d_var_types.find( f )!=d_var_types.end() ){
647 return d_type_union_find.getRepresentative( d_var_types[f][v] );
648 }else{
649 return 0;
650 }
651 }
652
653 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
654 Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
655 if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
656 //calculate the sort for variables if not done so already
657 std::map< Node, Node > var_bound;
658 process( f, var_bound );
659 }
660 d_op_return_types[sk] = getSortId( f, v );
661 Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
662 }
663
664 bool SortInference::isWellSortedFormula( Node n ) {
665 if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
666 for( unsigned i=0; i<n.getNumChildren(); i++ ){
667 if( !isWellSortedFormula( n[i] ) ){
668 return false;
669 }
670 }
671 return true;
672 }else{
673 return isWellSorted( n );
674 }
675 }
676
677 bool SortInference::isWellSorted( Node n ) {
678 if( getSortId( n )==0 ){
679 return false;
680 }else{
681 if( n.getKind()==kind::APPLY_UF ){
682 for( unsigned i=0; i<n.getNumChildren(); i++ ){
683 int s1 = getSortId( n[i] );
684 int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
685 if( s1!=s2 ){
686 return false;
687 }
688 if( !isWellSorted( n[i] ) ){
689 return false;
690 }
691 }
692 }
693 return true;
694 }
695 }
696
697 void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
698 if( n.getKind()==kind::APPLY_UF ){
699 for( unsigned i=0; i<n.getNumChildren(); i++ ){
700 getSortConstraints( n[i], uf );
701 uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
702 }
703 }
704 }
705
706 }/* CVC4 namespace */