dbd1dcd16593613cbfed7871320a221d89b5750b
[cvc5.git] / src / util / sort_inference.cpp
1 /********************* */
2 /*! \file sort_inference.cpp
3 ** \verbatim
4 ** Original author: Andrew Reynolds
5 ** Major contributors: Morgan Deters
6 ** Minor contributors (to current version): Kshitij Bansal
7 ** This file is part of the CVC4 project.
8 ** Copyright (c) 2009-2014 New York University and The University of Iowa
9 ** See the file COPYING in the top-level source directory for licensing
10 ** information.\endverbatim
11 **
12 ** \brief Sort inference module
13 **
14 ** This class implements sort inference, based on a simple algorithm:
15 ** First, we assume all functions and predicates have distinct uninterpreted types.
16 ** One pass is made through the input assertions, while a union-find data structure
17 ** maintains necessary information regarding constraints on these types.
18 **/
19
20 #include <vector>
21
22 #include "util/sort_inference.h"
23 #include "theory/uf/options.h"
24 #include "smt/options.h"
25 #include "theory/rewriter.h"
26 #include "theory/quantifiers/options.h"
27
28 using namespace CVC4;
29 using namespace std;
30
31 namespace CVC4 {
32
33 void SortInference::UnionFind::print(const char * c){
34 for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
35 Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
36 }
37 for( unsigned i=0; i<d_deq.size(); i++ ){
38 Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
39 }
40 Trace(c) << std::endl;
41 }
42 void SortInference::UnionFind::set( UnionFind& c ) {
43 clear();
44 for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
45 d_eqc[ it->first ] = it->second;
46 }
47 d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
48 }
49 int SortInference::UnionFind::getRepresentative( int t ){
50 std::map< int, int >::iterator it = d_eqc.find( t );
51 if( it==d_eqc.end() || it->second==t ){
52 return t;
53 }else{
54 int rt = getRepresentative( it->second );
55 d_eqc[t] = rt;
56 return rt;
57 }
58 }
59 void SortInference::UnionFind::setEqual( int t1, int t2 ){
60 if( t1!=t2 ){
61 int rt1 = getRepresentative( t1 );
62 int rt2 = getRepresentative( t2 );
63 if( rt1>rt2 ){
64 d_eqc[rt1] = rt2;
65 }else{
66 d_eqc[rt2] = rt1;
67 }
68 }
69 }
70 bool SortInference::UnionFind::isValid() {
71 for( unsigned i=0; i<d_deq.size(); i++ ){
72 if( areEqual( d_deq[i].first, d_deq[i].second ) ){
73 return false;
74 }
75 }
76 return true;
77 }
78
79
80 void SortInference::recordSubsort( TypeNode tn, int s ){
81 s = d_type_union_find.getRepresentative( s );
82 if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){
83 d_sub_sorts.push_back( s );
84 d_type_sub_sorts[tn].push_back( s );
85 }
86 }
87
88 void SortInference::printSort( const char* c, int t ){
89 int rt = d_type_union_find.getRepresentative( t );
90 if( d_type_types.find( rt )!=d_type_types.end() ){
91 Trace(c) << d_type_types[rt];
92 }else{
93 Trace(c) << "s_" << rt;
94 }
95 }
96
97 void SortInference::reset() {
98 d_sub_sorts.clear();
99 d_non_monotonic_sorts.clear();
100 d_type_sub_sorts.clear();
101 //reset info
102 sortCount = 1;
103 d_type_union_find.clear();
104 d_type_types.clear();
105 d_id_for_types.clear();
106 d_op_return_types.clear();
107 d_op_arg_types.clear();
108 d_var_types.clear();
109 //for rewriting
110 d_symbol_map.clear();
111 d_const_map.clear();
112 }
113
114 bool SortInference::simplify( std::vector< Node >& assertions ){
115 Trace("sort-inference") << "Calculating sort inference..." << std::endl;
116 //process all assertions
117 for( unsigned i=0; i<assertions.size(); i++ ){
118 Trace("sort-inference-debug") << "Process " << assertions[i] << std::endl;
119 std::map< Node, Node > var_bound;
120 process( assertions[i], var_bound );
121 }
122 for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
123 Trace("sort-inference") << it->first << " : ";
124 TypeNode retTn = it->first.getType();
125 if( !d_op_arg_types[ it->first ].empty() ){
126 Trace("sort-inference") << "( ";
127 for( size_t i=0; i<d_op_arg_types[ it->first ].size(); i++ ){
128 recordSubsort( retTn[i], d_op_arg_types[ it->first ][i] );
129 printSort( "sort-inference", d_op_arg_types[ it->first ][i] );
130 Trace("sort-inference") << " ";
131 }
132 Trace("sort-inference") << ") -> ";
133 retTn = retTn[(int)retTn.getNumChildren()-1];
134 }
135 recordSubsort( retTn, it->second );
136 printSort( "sort-inference", it->second );
137 Trace("sort-inference") << std::endl;
138 }
139 for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){
140 Trace("sort-inference") << "Quantified formula : " << it->first << " : " << std::endl;
141 for( unsigned i=0; i<it->first[0].getNumChildren(); i++ ){
142 recordSubsort( it->first[0][i].getType(), it->second[it->first[0][i]] );
143 printSort( "sort-inference", it->second[it->first[0][i]] );
144 Trace("sort-inference") << std::endl;
145 }
146 Trace("sort-inference") << std::endl;
147 }
148
149 if( !options::ufssSymBreak() ){
150 bool rewritten = false;
151 //determine monotonicity of sorts
152 for( unsigned i=0; i<assertions.size(); i++ ){
153 Trace("sort-inference-debug") << "Process monotonicity for " << assertions[i] << std::endl;
154 std::map< Node, Node > var_bound;
155 processMonotonic( assertions[i], true, true, var_bound );
156 }
157
158 Trace("sort-inference") << "We have " << d_sub_sorts.size() << " sub-sorts : " << std::endl;
159 for( unsigned i=0; i<d_sub_sorts.size(); i++ ){
160 printSort( "sort-inference", d_sub_sorts[i] );
161 if( d_type_types.find( d_sub_sorts[i] )!=d_type_types.end() ){
162 Trace("sort-inference") << " is interpreted." << std::endl;
163 }else if( d_non_monotonic_sorts.find( d_sub_sorts[i] )==d_non_monotonic_sorts.end() ){
164 Trace("sort-inference") << " is monotonic." << std::endl;
165 }else{
166 Trace("sort-inference") << " is not monotonic." << std::endl;
167 }
168 }
169
170 //simplify all assertions by introducing new symbols wherever necessary
171 for( unsigned i=0; i<assertions.size(); i++ ){
172 Node prev = assertions[i];
173 std::map< Node, Node > var_bound;
174 Trace("sort-inference-debug") << "Rewrite " << assertions[i] << std::endl;
175 assertions[i] = simplify( assertions[i], var_bound );
176 Trace("sort-inference-debug") << "Done." << std::endl;
177 if( prev!=assertions[i] ){
178 assertions[i] = theory::Rewriter::rewrite( assertions[i] );
179 rewritten = true;
180 Trace("sort-inference-rewrite") << prev << std::endl;
181 Trace("sort-inference-rewrite") << " --> " << assertions[i] << std::endl;
182 }
183 }
184 //now, ensure constants are distinct
185 for( std::map< TypeNode, std::map< Node, Node > >::iterator it = d_const_map.begin(); it != d_const_map.end(); ++it ){
186 std::vector< Node > consts;
187 for( std::map< Node, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
188 consts.push_back( it2->second );
189 }
190 //TODO: add lemma enforcing introduced constants to be distinct
191 }
192
193 //enforce constraints based on monotonicity
194 for( std::map< TypeNode, std::vector< int > >::iterator it = d_type_sub_sorts.begin(); it != d_type_sub_sorts.end(); ++it ){
195 int nmonSort = -1;
196 for( unsigned i=0; i<it->second.size(); i++ ){
197 if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
198 nmonSort = it->second[i];
199 break;
200 }
201 }
202 if( nmonSort!=-1 ){
203 std::vector< Node > injections;
204 TypeNode base_tn = getOrCreateTypeForId( nmonSort, it->first );
205 for( unsigned i=0; i<it->second.size(); i++ ){
206 if( it->second[i]!=nmonSort ){
207 TypeNode new_tn = getOrCreateTypeForId( it->second[i], it->first );
208 //make injection to nmonSort
209 Node a1 = mkInjection( new_tn, base_tn );
210 injections.push_back( a1 );
211 if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
212 //also must make injection from nmonSort to this
213 Node a2 = mkInjection( base_tn, new_tn );
214 injections.push_back( a2 );
215 }
216 }
217 }
218 Trace("sort-inference-rewrite") << "Add the following injections for " << it->first << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
219 for( unsigned j=0; j<injections.size(); j++ ){
220 Trace("sort-inference-rewrite") << " " << injections[j] << std::endl;
221 }
222 assertions.insert( assertions.end(), injections.begin(), injections.end() );
223 if( !injections.empty() ){
224 rewritten = true;
225 }
226 }
227 }
228 //no sub-sort information is stored
229 reset();
230 return rewritten;
231 }
232 /*
233 else if( !options::ufssSymBreak() ){
234 //just add the unit lemmas between constants
235 std::map< TypeNode, std::map< int, Node > > constants;
236 for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
237 int rt = d_type_union_find.getRepresentative( it->second );
238 if( d_op_arg_types[ it->first ].empty() ){
239 TypeNode tn = it->first.getType();
240 if( constants[ tn ].find( rt )==constants[ tn ].end() ){
241 constants[ tn ][ rt ] = it->first;
242 }
243 }
244 }
245 //add unit lemmas for each constant
246 for( std::map< TypeNode, std::map< int, Node > >::iterator it = constants.begin(); it != constants.end(); ++it ){
247 Node first_const;
248 for( std::map< int, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
249 if( first_const.isNull() ){
250 first_const = it2->second;
251 }else{
252 Node eq = first_const.eqNode( it2->second );
253 //eq = Rewriter::rewrite( eq );
254 Trace("sort-inference-lemma") << "Sort inference lemma : " << eq << std::endl;
255 assertions.push_back( eq );
256 }
257 }
258 }
259 }
260 */
261 initialSortCount = sortCount;
262 return false;
263 }
264
265 void SortInference::setEqual( int t1, int t2 ){
266 if( t1!=t2 ){
267 int rt1 = d_type_union_find.getRepresentative( t1 );
268 int rt2 = d_type_union_find.getRepresentative( t2 );
269 if( rt1!=rt2 ){
270 Trace("sort-inference-debug") << "Set equal : ";
271 printSort( "sort-inference-debug", rt1 );
272 Trace("sort-inference-debug") << " ";
273 printSort( "sort-inference-debug", rt2 );
274 Trace("sort-inference-debug") << std::endl;
275 /*
276 d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
277 d_type_eq_class[rt2].clear();
278 Trace("sort-inference-debug") << "EqClass : { ";
279 for( int i=0; i<(int)d_type_eq_class[rt1].size(); i++ ){
280 Trace("sort-inference-debug") << d_type_eq_class[rt1][i] << ", ";
281 }
282 Trace("sort-inference-debug") << "}" << std::endl;
283 */
284 if( rt2>rt1 ){
285 //swap
286 int swap = rt1;
287 rt1 = rt2;
288 rt2 = swap;
289 }
290 d_type_union_find.d_eqc[rt1] = rt2;
291 std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
292 if( it1!=d_type_types.end() ){
293 Assert( d_type_types.find( rt2 )==d_type_types.end() );
294 d_type_types[rt2] = it1->second;
295 d_type_types.erase( rt1 );
296 }
297 }
298 }
299 }
300
301 int SortInference::getIdForType( TypeNode tn ){
302 //register the return type
303 std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
304 if( it==d_id_for_types.end() ){
305 int sc = sortCount;
306 d_type_types[ sortCount ] = tn;
307 d_id_for_types[ tn ] = sortCount;
308 sortCount++;
309 return sc;
310 }else{
311 return it->second;
312 }
313 }
314
315 int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
316 Trace("sort-inference-debug") << "...Process " << n << std::endl;
317 //add to variable bindings
318 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
319 if( d_var_types.find( n )!=d_var_types.end() ){
320 return getIdForType( n.getType() );
321 }else{
322 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
323 //apply sort inference to quantified variables
324 d_var_types[n][ n[0][i] ] = sortCount;
325 sortCount++;
326
327 //type of the quantified variable must be the same
328 var_bound[ n[0][i] ] = n;
329 }
330 }
331 }
332
333 //process children
334 std::vector< Node > children;
335 std::vector< int > child_types;
336 for( size_t i=0; i<n.getNumChildren(); i++ ){
337 bool processChild = true;
338 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
339 processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
340 }
341 if( processChild ){
342 children.push_back( n[i] );
343 child_types.push_back( process( n[i], var_bound ) );
344 }
345 }
346
347 //remove from variable bindings
348 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
349 //erase from variable bound
350 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
351 var_bound.erase( n[0][i] );
352 }
353 }
354
355 int retType;
356 if( n.getKind()==kind::EQUAL ){
357 //we only require that the left and right hand side must be equal
358 setEqual( child_types[0], child_types[1] );
359 //int eqType = getIdForType( n[0].getType() );
360 //setEqual( child_types[0], eqType );
361 //setEqual( child_types[1], eqType );
362 retType = getIdForType( n.getType() );
363 }else if( n.getKind()==kind::APPLY_UF ){
364 Node op = n.getOperator();
365 if( d_op_return_types.find( op )==d_op_return_types.end() ){
366 if( n.getType().isBoolean() ){
367 //use booleans
368 d_op_return_types[op] = getIdForType( n.getType() );
369 }else{
370 //assign arbitrary sort for return type
371 d_op_return_types[op] = sortCount;
372 sortCount++;
373 }
374 //d_type_eq_class[sortCount].push_back( op );
375 //assign arbitrary sort for argument types
376 for( size_t i=0; i<n.getNumChildren(); i++ ){
377 d_op_arg_types[op].push_back( sortCount );
378 sortCount++;
379 }
380 }
381 for( size_t i=0; i<n.getNumChildren(); i++ ){
382 //the argument of the operator must match the return type of the subterm
383 setEqual( child_types[i], d_op_arg_types[op][i] );
384 }
385 //return type is the return type
386 retType = d_op_return_types[op];
387 }else{
388 std::map< Node, Node >::iterator it = var_bound.find( n );
389 if( it!=var_bound.end() ){
390 Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
391 //the return type was specified while binding
392 retType = d_var_types[it->second][n];
393 }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
394 Trace("sort-inference-debug") << n << " is a variable." << std::endl;
395 if( d_op_return_types.find( n )==d_op_return_types.end() ){
396 //assign arbitrary sort
397 d_op_return_types[n] = sortCount;
398 sortCount++;
399 //d_type_eq_class[sortCount].push_back( n );
400 }
401 retType = d_op_return_types[n];
402 //}else if( n.isConst() ){
403 // Trace("sort-inference-debug") << n << " is a constant." << std::endl;
404 //can be any type we want
405 // retType = sortCount;
406 // sortCount++;
407 }else{
408 Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
409 //it is an interpretted term
410 for( size_t i=0; i<children.size(); i++ ){
411 Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
412 //must enforce the actual type of the operator on the children
413 int ct = getIdForType( children[i].getType() );
414 setEqual( child_types[i], ct );
415 }
416 //return type must be the actual return type
417 retType = getIdForType( n.getType() );
418 }
419 }
420 Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
421 printSort("sort-inference-debug", retType );
422 Trace("sort-inference-debug") << std::endl;
423 return retType;
424 }
425
426 void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound ) {
427 Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
428 if( n.getKind()==kind::FORALL ){
429 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
430 var_bound[n[0][i]] = n;
431 }
432 processMonotonic( n[1], pol, hasPol, var_bound );
433 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
434 var_bound.erase( n[0][i] );
435 }
436 }else if( n.getKind()==kind::EQUAL ){
437 if( !hasPol || pol ){
438 for( unsigned i=0; i<2; i++ ){
439 if( var_bound.find( n[i] )!=var_bound.end() ){
440 int sid = getSortId( var_bound[n[i]], n[i] );
441 d_non_monotonic_sorts[sid] = true;
442 break;
443 }
444 }
445 }
446 }
447 for( unsigned i=0; i<n.getNumChildren(); i++ ){
448 bool npol = pol;
449 bool nhasPol = hasPol;
450 if( n.getKind()==kind::NOT || ( n.getKind()==kind::IMPLIES && i==0 ) ){
451 npol = !npol;
452 }
453 if( ( n.getKind()==kind::ITE && i==0 ) || n.getKind()==kind::XOR || n.getKind()==kind::IFF ){
454 nhasPol = false;
455 }
456 processMonotonic( n[i], npol, nhasPol, var_bound );
457 }
458 }
459
460
461 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
462 int rt = d_type_union_find.getRepresentative( t );
463 if( d_type_types.find( rt )!=d_type_types.end() ){
464 return d_type_types[rt];
465 }else{
466 TypeNode retType;
467 //see if we can assign pref
468 if( !pref.isNull() && d_id_for_types.find( pref )==d_id_for_types.end() ){
469 retType = pref;
470 }else{
471 //must create new type
472 std::stringstream ss;
473 ss << "it_" << t << "_" << pref;
474 retType = NodeManager::currentNM()->mkSort( ss.str() );
475 }
476 Trace("sort-inference") << "-> Make type " << retType << " to correspond to ";
477 printSort("sort-inference", t );
478 Trace("sort-inference") << std::endl;
479 d_id_for_types[ retType ] = rt;
480 d_type_types[ rt ] = retType;
481 return retType;
482 }
483 }
484
485 TypeNode SortInference::getTypeForId( int t ){
486 int rt = d_type_union_find.getRepresentative( t );
487 if( d_type_types.find( rt )!=d_type_types.end() ){
488 return d_type_types[rt];
489 }else{
490 return TypeNode::null();
491 }
492 }
493
494 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
495 if( tn==old.getType() ){
496 return old;
497 }else if( old.isConst() ){
498 //must make constant of type tn
499 if( d_const_map[tn].find( old )==d_const_map[tn].end() ){
500 std::stringstream ss;
501 ss << "ic_" << tn << "_" << old;
502 d_const_map[tn][ old ] = NodeManager::currentNM()->mkSkolem( ss.str(), tn, "constant created during sort inference" ); //use mkConst???
503 }
504 return d_const_map[tn][ old ];
505 }else if( old.getKind()==kind::BOUND_VARIABLE ){
506 std::stringstream ss;
507 ss << "b_" << old;
508 return NodeManager::currentNM()->mkBoundVar( ss.str(), tn );
509 }else{
510 std::stringstream ss;
511 ss << "i_" << old;
512 return NodeManager::currentNM()->mkSkolem( ss.str(), tn, "created during sort inference" );
513 }
514 }
515
516 Node SortInference::simplify( Node n, std::map< Node, Node >& var_bound ){
517 Trace("sort-inference-debug2") << "Simplify " << n << std::endl;
518 std::vector< Node > children;
519 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
520 //recreate based on types of variables
521 std::vector< Node > new_children;
522 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
523 TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
524 Node v = getNewSymbol( n[0][i], tn );
525 Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
526 new_children.push_back( v );
527 var_bound[ n[0][i] ] = v;
528 }
529 children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
530 }
531
532 //process children
533 if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
534 children.push_back( n.getOperator() );
535 }
536 bool childChanged = false;
537 for( size_t i=0; i<n.getNumChildren(); i++ ){
538 bool processChild = true;
539 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
540 processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
541 }
542 if( processChild ){
543 Node nc = simplify( n[i], var_bound );
544 Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
545 children.push_back( nc );
546 childChanged = childChanged || nc!=n[i];
547 }
548 }
549
550 //remove from variable bindings
551 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
552 //erase from variable bound
553 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
554 Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
555 var_bound.erase( n[0][i] );
556 }
557 return NodeManager::currentNM()->mkNode( n.getKind(), children );
558 }else if( n.getKind()==kind::EQUAL ){
559 if( children[0].getType()!=children[1].getType() ){
560 if( children[0].isConst() ){
561 children[0] = getNewSymbol( children[0], children[1].getType() );
562 }else if( children[1].isConst() ){
563 children[1] = getNewSymbol( children[1], children[0].getType() );
564 }else{
565 Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
566 Trace("sort-inference-warn") << " Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
567 Assert( false );
568 }
569 }
570 return NodeManager::currentNM()->mkNode( kind::EQUAL, children );
571 }else if( n.getKind()==kind::APPLY_UF ){
572 Node op = n.getOperator();
573 if( d_symbol_map.find( op )==d_symbol_map.end() ){
574 //make the new operator if necessary
575 bool opChanged = false;
576 std::vector< TypeNode > argTypes;
577 for( size_t i=0; i<n.getNumChildren(); i++ ){
578 TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
579 argTypes.push_back( tn );
580 if( tn!=n[i].getType() ){
581 opChanged = true;
582 }
583 }
584 TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
585 if( retType!=n.getType() ){
586 opChanged = true;
587 }
588 if( opChanged ){
589 std::stringstream ss;
590 ss << "io_" << op;
591 TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
592 d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
593 Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
594 d_model_replace_f[op] = d_symbol_map[op];
595 }else{
596 d_symbol_map[op] = op;
597 }
598 }
599 children[0] = d_symbol_map[op];
600 //make sure all children have been taken care of
601 for( size_t i=0; i<n.getNumChildren(); i++ ){
602 TypeNode tn = children[i+1].getType();
603 TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
604 if( tn!=tna ){
605 if( n[i].isConst() ){
606 children[i+1] = getNewSymbol( n[i], tna );
607 }else{
608 Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
609 Assert( false );
610 }
611 }
612 }
613 return NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
614 }else{
615 std::map< Node, Node >::iterator it = var_bound.find( n );
616 if( it!=var_bound.end() ){
617 return it->second;
618 }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
619 if( d_symbol_map.find( n )==d_symbol_map.end() ){
620 TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
621 d_symbol_map[n] = getNewSymbol( n, tn );
622 }
623 return d_symbol_map[n];
624 }else if( n.isConst() ){
625 //just return n, we will fix at higher scope
626 return n;
627 }else{
628 if( childChanged ){
629 return NodeManager::currentNM()->mkNode( n.getKind(), children );
630 }else{
631 return n;
632 }
633 }
634 }
635
636 }
637
638 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
639 std::vector< TypeNode > tns;
640 tns.push_back( tn1 );
641 TypeNode typ = NodeManager::currentNM()->mkFunctionType( tns, tn2 );
642 Node f = NodeManager::currentNM()->mkSkolem( "inj", typ, "injection for monotonicity constraint" );
643 Trace("sort-inference") << "-> Make injection " << f << " from " << tn1 << " to " << tn2 << std::endl;
644 Node v1 = NodeManager::currentNM()->mkBoundVar( "?x", tn1 );
645 Node v2 = NodeManager::currentNM()->mkBoundVar( "?y", tn1 );
646 Node ret = NodeManager::currentNM()->mkNode( kind::FORALL,
647 NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, v1, v2 ),
648 NodeManager::currentNM()->mkNode( kind::OR,
649 NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v1 ).eqNode( NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v2 ) ).negate(),
650 v1.eqNode( v2 ) ) );
651 ret = theory::Rewriter::rewrite( ret );
652 return ret;
653 }
654
655 int SortInference::getSortId( Node n ) {
656 Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
657 if( d_op_return_types.find( op )!=d_op_return_types.end() ){
658 return d_type_union_find.getRepresentative( d_op_return_types[op] );
659 }else{
660 return 0;
661 }
662 }
663
664 int SortInference::getSortId( Node f, Node v ) {
665 if( d_var_types.find( f )!=d_var_types.end() ){
666 return d_type_union_find.getRepresentative( d_var_types[f][v] );
667 }else{
668 return 0;
669 }
670 }
671
672 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
673 Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
674 if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
675 //calculate the sort for variables if not done so already
676 std::map< Node, Node > var_bound;
677 process( f, var_bound );
678 }
679 d_op_return_types[sk] = getSortId( f, v );
680 Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
681 }
682
683 bool SortInference::isWellSortedFormula( Node n ) {
684 if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
685 for( unsigned i=0; i<n.getNumChildren(); i++ ){
686 if( !isWellSortedFormula( n[i] ) ){
687 return false;
688 }
689 }
690 return true;
691 }else{
692 return isWellSorted( n );
693 }
694 }
695
696 bool SortInference::isWellSorted( Node n ) {
697 if( getSortId( n )==0 ){
698 return false;
699 }else{
700 if( n.getKind()==kind::APPLY_UF ){
701 for( unsigned i=0; i<n.getNumChildren(); i++ ){
702 int s1 = getSortId( n[i] );
703 int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
704 if( s1!=s2 ){
705 return false;
706 }
707 if( !isWellSorted( n[i] ) ){
708 return false;
709 }
710 }
711 }
712 return true;
713 }
714 }
715
716 void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
717 if( n.getKind()==kind::APPLY_UF ){
718 for( unsigned i=0; i<n.getNumChildren(); i++ ){
719 getSortConstraints( n[i], uf );
720 uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
721 }
722 }
723 }
724
725 }/* CVC4 namespace */