Update copyright headers.
[cvc5.git] / src / theory / sort_inference.cpp
1 /********************* */
2 /*! \file sort_inference.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Paul Meng, Morgan Deters
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS
8 ** (in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Sort inference module
13 **
14 ** This class implements sort inference, based on a simple algorithm:
15 ** First, we assume all functions and predicates have distinct uninterpreted types.
16 ** One pass is made through the input assertions, while a union-find data structure
17 ** maintains necessary information regarding constraints on these types.
18 **/
19
20 #include "theory/sort_inference.h"
21
22 #include <vector>
23
24 #include "options/quantifiers_options.h"
25 #include "options/smt_options.h"
26 #include "options/uf_options.h"
27 #include "proof/proof_manager.h"
28 #include "theory/rewriter.h"
29 #include "theory/quantifiers/quant_util.h"
30
31 using namespace CVC4;
32 using namespace std;
33
34 namespace CVC4 {
35
36 void SortInference::UnionFind::print(const char * c){
37 for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
38 Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
39 }
40 for( unsigned i=0; i<d_deq.size(); i++ ){
41 Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
42 }
43 Trace(c) << std::endl;
44 }
45 void SortInference::UnionFind::set( UnionFind& c ) {
46 clear();
47 for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
48 d_eqc[ it->first ] = it->second;
49 }
50 d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
51 }
52 int SortInference::UnionFind::getRepresentative( int t ){
53 std::map< int, int >::iterator it = d_eqc.find( t );
54 if( it==d_eqc.end() || it->second==t ){
55 return t;
56 }else{
57 int rt = getRepresentative( it->second );
58 d_eqc[t] = rt;
59 return rt;
60 }
61 }
62 void SortInference::UnionFind::setEqual( int t1, int t2 ){
63 if( t1!=t2 ){
64 int rt1 = getRepresentative( t1 );
65 int rt2 = getRepresentative( t2 );
66 if( rt1>rt2 ){
67 d_eqc[rt1] = rt2;
68 }else{
69 d_eqc[rt2] = rt1;
70 }
71 }
72 }
73 bool SortInference::UnionFind::isValid() {
74 for( unsigned i=0; i<d_deq.size(); i++ ){
75 if( areEqual( d_deq[i].first, d_deq[i].second ) ){
76 return false;
77 }
78 }
79 return true;
80 }
81
82
83 void SortInference::recordSubsort( TypeNode tn, int s ){
84 s = d_type_union_find.getRepresentative( s );
85 if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){
86 d_sub_sorts.push_back( s );
87 d_type_sub_sorts[tn].push_back( s );
88 }
89 }
90
91 void SortInference::printSort( const char* c, int t ){
92 int rt = d_type_union_find.getRepresentative( t );
93 if( d_type_types.find( rt )!=d_type_types.end() ){
94 Trace(c) << d_type_types[rt];
95 }else{
96 Trace(c) << "s_" << rt;
97 }
98 }
99
100 void SortInference::reset() {
101 d_sub_sorts.clear();
102 d_non_monotonic_sorts.clear();
103 d_type_sub_sorts.clear();
104 //reset info
105 sortCount = 1;
106 d_type_union_find.clear();
107 d_type_types.clear();
108 d_id_for_types.clear();
109 d_op_return_types.clear();
110 d_op_arg_types.clear();
111 d_var_types.clear();
112 //for rewriting
113 d_symbol_map.clear();
114 d_const_map.clear();
115 }
116
117 void SortInference::simplify( std::vector< Node >& assertions, bool doSortInference, bool doMonotonicyInference ){
118 if( doSortInference ){
119 Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
120 //process all assertions
121 std::map< Node, int > visited;
122 for( unsigned i=0; i<assertions.size(); i++ ){
123 Trace("sort-inference-debug") << "Process " << assertions[i] << std::endl;
124 std::map< Node, Node > var_bound;
125 process( assertions[i], var_bound, visited );
126 }
127 Trace("sort-inference-proc") << "...done" << std::endl;
128 for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
129 Trace("sort-inference") << it->first << " : ";
130 TypeNode retTn = it->first.getType();
131 if( !d_op_arg_types[ it->first ].empty() ){
132 Trace("sort-inference") << "( ";
133 for( size_t i=0; i<d_op_arg_types[ it->first ].size(); i++ ){
134 recordSubsort( retTn[i], d_op_arg_types[ it->first ][i] );
135 printSort( "sort-inference", d_op_arg_types[ it->first ][i] );
136 Trace("sort-inference") << " ";
137 }
138 Trace("sort-inference") << ") -> ";
139 retTn = retTn[(int)retTn.getNumChildren()-1];
140 }
141 recordSubsort( retTn, it->second );
142 printSort( "sort-inference", it->second );
143 Trace("sort-inference") << std::endl;
144 }
145 for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){
146 Trace("sort-inference") << "Quantified formula : " << it->first << " : " << std::endl;
147 for( unsigned i=0; i<it->first[0].getNumChildren(); i++ ){
148 recordSubsort( it->first[0][i].getType(), it->second[it->first[0][i]] );
149 printSort( "sort-inference", it->second[it->first[0][i]] );
150 Trace("sort-inference") << std::endl;
151 }
152 Trace("sort-inference") << std::endl;
153 }
154
155 if( !options::ufssSymBreak() ){
156 bool rewritten = false;
157 //determine monotonicity of sorts
158 Trace("sort-inference-proc") << "Calculating monotonicty for subsorts..." << std::endl;
159 std::map< Node, std::map< int, bool > > visited;
160 for( unsigned i=0; i<assertions.size(); i++ ){
161 Trace("sort-inference-debug") << "Process monotonicity for " << assertions[i] << std::endl;
162 std::map< Node, Node > var_bound;
163 processMonotonic( assertions[i], true, true, var_bound, visited );
164 }
165 Trace("sort-inference-proc") << "...done" << std::endl;
166
167 Trace("sort-inference") << "We have " << d_sub_sorts.size() << " sub-sorts : " << std::endl;
168 for( unsigned i=0; i<d_sub_sorts.size(); i++ ){
169 printSort( "sort-inference", d_sub_sorts[i] );
170 if( d_type_types.find( d_sub_sorts[i] )!=d_type_types.end() ){
171 Trace("sort-inference") << " is interpreted." << std::endl;
172 }else if( d_non_monotonic_sorts.find( d_sub_sorts[i] )==d_non_monotonic_sorts.end() ){
173 Trace("sort-inference") << " is monotonic." << std::endl;
174 }else{
175 Trace("sort-inference") << " is not monotonic." << std::endl;
176 }
177 }
178
179 //simplify all assertions by introducing new symbols wherever necessary
180 Trace("sort-inference-proc") << "Perform simplification..." << std::endl;
181 std::map< Node, std::map< TypeNode, Node > > visited2;
182 for( unsigned i=0; i<assertions.size(); i++ ){
183 Node prev = assertions[i];
184 std::map< Node, Node > var_bound;
185 Trace("sort-inference-debug") << "Simplify " << assertions[i] << std::endl;
186 TypeNode tnn;
187 Node curr = simplifyNode( assertions[i], var_bound, tnn, visited2 );
188 Trace("sort-inference-debug") << "Done." << std::endl;
189 if( curr!=assertions[i] ){
190 Trace("sort-inference-debug") << "Rewrite " << curr << std::endl;
191 curr = theory::Rewriter::rewrite( curr );
192 rewritten = true;
193 Trace("sort-inference-rewrite") << assertions << std::endl;
194 Trace("sort-inference-rewrite") << " --> " << curr << std::endl;
195 PROOF( ProofManager::currentPM()->addDependence(curr, assertions[i]); );
196 assertions[i] = curr;
197 }
198 }
199 Trace("sort-inference-proc") << "...done" << std::endl;
200 //now, ensure constants are distinct
201 for( std::map< TypeNode, std::map< Node, Node > >::iterator it = d_const_map.begin(); it != d_const_map.end(); ++it ){
202 std::vector< Node > consts;
203 for( std::map< Node, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
204 Assert( it2->first.isConst() );
205 consts.push_back( it2->second );
206 }
207 //add lemma enforcing introduced constants to be distinct
208 if( consts.size()>1 ){
209 Node distinct_const = NodeManager::currentNM()->mkNode( kind::DISTINCT, consts );
210 Trace("sort-inference-rewrite") << "Add the constant distinctness lemma: " << std::endl;
211 Trace("sort-inference-rewrite") << " " << distinct_const << std::endl;
212 assertions.push_back( distinct_const );
213 rewritten = true;
214 }
215 }
216
217 //enforce constraints based on monotonicity
218 Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
219 for( std::map< TypeNode, std::vector< int > >::iterator it = d_type_sub_sorts.begin(); it != d_type_sub_sorts.end(); ++it ){
220 int nmonSort = -1;
221 for( unsigned i=0; i<it->second.size(); i++ ){
222 if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
223 nmonSort = it->second[i];
224 break;
225 }
226 }
227 if( nmonSort!=-1 ){
228 std::vector< Node > injections;
229 TypeNode base_tn = getOrCreateTypeForId( nmonSort, it->first );
230 for( unsigned i=0; i<it->second.size(); i++ ){
231 if( it->second[i]!=nmonSort ){
232 TypeNode new_tn = getOrCreateTypeForId( it->second[i], it->first );
233 //make injection to nmonSort
234 Node a1 = mkInjection( new_tn, base_tn );
235 injections.push_back( a1 );
236 if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
237 //also must make injection from nmonSort to this
238 Node a2 = mkInjection( base_tn, new_tn );
239 injections.push_back( a2 );
240 }
241 }
242 }
243 Trace("sort-inference-rewrite") << "Add the following injections for " << it->first << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
244 for( unsigned j=0; j<injections.size(); j++ ){
245 Trace("sort-inference-rewrite") << " " << injections[j] << std::endl;
246 }
247 assertions.insert( assertions.end(), injections.begin(), injections.end() );
248 if( !injections.empty() ){
249 rewritten = true;
250 }
251 }
252 }
253 Trace("sort-inference-proc") << "...done" << std::endl;
254 //no sub-sort information is stored
255 reset();
256 Trace("sort-inference-debug") << "Finished sort inference, rewritten = " << rewritten << std::endl;
257 }
258 initialSortCount = sortCount;
259 }
260 if( doMonotonicyInference ){
261 std::map< Node, std::map< int, bool > > visited;
262 Trace("sort-inference-proc") << "Calculating monotonicty for types..." << std::endl;
263 for( unsigned i=0; i<assertions.size(); i++ ){
264 Trace("sort-inference-debug") << "Process type monotonicity for " << assertions[i] << std::endl;
265 std::map< Node, Node > var_bound;
266 processMonotonic( assertions[i], true, true, var_bound, visited, true );
267 }
268 Trace("sort-inference-proc") << "...done" << std::endl;
269 }
270 }
271
272 void SortInference::setEqual( int t1, int t2 ){
273 if( t1!=t2 ){
274 int rt1 = d_type_union_find.getRepresentative( t1 );
275 int rt2 = d_type_union_find.getRepresentative( t2 );
276 if( rt1!=rt2 ){
277 Trace("sort-inference-debug") << "Set equal : ";
278 printSort( "sort-inference-debug", rt1 );
279 Trace("sort-inference-debug") << " ";
280 printSort( "sort-inference-debug", rt2 );
281 Trace("sort-inference-debug") << std::endl;
282 /*
283 d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
284 d_type_eq_class[rt2].clear();
285 Trace("sort-inference-debug") << "EqClass : { ";
286 for( int i=0; i<(int)d_type_eq_class[rt1].size(); i++ ){
287 Trace("sort-inference-debug") << d_type_eq_class[rt1][i] << ", ";
288 }
289 Trace("sort-inference-debug") << "}" << std::endl;
290 */
291 if( rt2>rt1 ){
292 //swap
293 int swap = rt1;
294 rt1 = rt2;
295 rt2 = swap;
296 }
297 std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
298 if( it1!=d_type_types.end() ){
299 if( d_type_types.find( rt2 )==d_type_types.end() ){
300 d_type_types[rt2] = it1->second;
301 d_type_types.erase( rt1 );
302 }else{
303 Trace("sort-inference-debug") << "...fail : associated with types " << d_type_types[rt1] << " and " << d_type_types[rt2] << std::endl;
304 return;
305 }
306 }
307 d_type_union_find.d_eqc[rt1] = rt2;
308 }
309 }
310 }
311
312 int SortInference::getIdForType( TypeNode tn ){
313 //register the return type
314 std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
315 if( it==d_id_for_types.end() ){
316 int sc = sortCount;
317 d_type_types[ sortCount ] = tn;
318 d_id_for_types[ tn ] = sortCount;
319 sortCount++;
320 return sc;
321 }else{
322 return it->second;
323 }
324 }
325
326 int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ){
327 std::map< Node, int >::iterator itv = visited.find( n );
328 if( itv!=visited.end() ){
329 return itv->second;
330 }else{
331 //add to variable bindings
332 bool use_new_visited = false;
333 std::map< Node, int > new_visited;
334 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
335 if( d_var_types.find( n )!=d_var_types.end() ){
336 return getIdForType( n.getType() );
337 }else{
338 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
339 //apply sort inference to quantified variables
340 d_var_types[n][ n[0][i] ] = sortCount;
341 sortCount++;
342
343 //type of the quantified variable must be the same
344 var_bound[ n[0][i] ] = n;
345 }
346 }
347 use_new_visited = true;
348 }
349
350 //process children
351 std::vector< Node > children;
352 std::vector< int > child_types;
353 for( size_t i=0; i<n.getNumChildren(); i++ ){
354 bool processChild = true;
355 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
356 processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
357 }
358 if( processChild ){
359 children.push_back( n[i] );
360 child_types.push_back( process( n[i], var_bound, use_new_visited ? new_visited : visited ) );
361 }
362 }
363
364 //remove from variable bindings
365 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
366 //erase from variable bound
367 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
368 var_bound.erase( n[0][i] );
369 }
370 }
371 Trace("sort-inference-debug") << "...Process " << n << std::endl;
372
373 int retType;
374 if( n.getKind()==kind::EQUAL && !n[0].getType().isBoolean() ){
375 Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
376 //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
377 if( n[0].getType()!=n[1].getType() ){
378 //for now, assume the original types
379 for( unsigned i=0; i<2; i++ ){
380 int ct = getIdForType( n[i].getType() );
381 setEqual( child_types[i], ct );
382 }
383 }else{
384 //we only require that the left and right hand side must be equal
385 setEqual( child_types[0], child_types[1] );
386 }
387 d_equality_types[n] = child_types[0];
388 retType = getIdForType( n.getType() );
389 }else if( n.getKind()==kind::APPLY_UF ){
390 Node op = n.getOperator();
391 TypeNode tn_op = op.getType();
392 if( d_op_return_types.find( op )==d_op_return_types.end() ){
393 if( n.getType().isBoolean() ){
394 //use booleans
395 d_op_return_types[op] = getIdForType( n.getType() );
396 }else{
397 //assign arbitrary sort for return type
398 d_op_return_types[op] = sortCount;
399 sortCount++;
400 }
401 //d_type_eq_class[sortCount].push_back( op );
402 //assign arbitrary sort for argument types
403 for( size_t i=0; i<n.getNumChildren(); i++ ){
404 d_op_arg_types[op].push_back( sortCount );
405 sortCount++;
406 }
407 }
408 for( size_t i=0; i<n.getNumChildren(); i++ ){
409 //the argument of the operator must match the return type of the subterm
410 if( n[i].getType()!=tn_op[i] ){
411 //if type mismatch, assume original types
412 Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
413 Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
414 int ct1 = getIdForType( n[i].getType() );
415 setEqual( child_types[i], ct1 );
416 int ct2 = getIdForType( tn_op[i] );
417 setEqual( d_op_arg_types[op][i], ct2 );
418 }else{
419 setEqual( child_types[i], d_op_arg_types[op][i] );
420 }
421 }
422 //return type is the return type
423 retType = d_op_return_types[op];
424 }else{
425 std::map< Node, Node >::iterator it = var_bound.find( n );
426 if( it!=var_bound.end() ){
427 Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
428 //the return type was specified while binding
429 retType = d_var_types[it->second][n];
430 }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
431 Trace("sort-inference-debug") << n << " is a variable." << std::endl;
432 if( d_op_return_types.find( n )==d_op_return_types.end() ){
433 //assign arbitrary sort
434 d_op_return_types[n] = sortCount;
435 sortCount++;
436 //d_type_eq_class[sortCount].push_back( n );
437 }
438 retType = d_op_return_types[n];
439 }else if( n.isConst() ){
440 Trace("sort-inference-debug") << n << " is a constant." << std::endl;
441 //can be any type we want
442 retType = sortCount;
443 sortCount++;
444 }else{
445 Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
446 //it is an interpreted term
447 for( size_t i=0; i<children.size(); i++ ){
448 Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
449 //must enforce the actual type of the operator on the children
450 int ct = getIdForType( children[i].getType() );
451 setEqual( child_types[i], ct );
452 }
453 //return type must be the actual return type
454 retType = getIdForType( n.getType() );
455 }
456 }
457 Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
458 printSort("sort-inference-debug", retType );
459 Trace("sort-inference-debug") << std::endl;
460 visited[n] = retType;
461 return retType;
462 }
463 }
464
465 void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode ) {
466 int pindex = hasPol ? ( pol ? 1 : -1 ) : 0;
467 if( visited[n].find( pindex )==visited[n].end() ){
468 visited[n][pindex] = true;
469 Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
470 if( n.getKind()==kind::FORALL ){
471 //only consider variables universally if it is possible this quantified formula is asserted positively
472 if( !hasPol || pol ){
473 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
474 var_bound[n[0][i]] = n;
475 }
476 }
477 processMonotonic( n[1], pol, hasPol, var_bound, visited, typeMode );
478 if( !hasPol || pol ){
479 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
480 var_bound.erase( n[0][i] );
481 }
482 }
483 return;
484 }else if( n.getKind()==kind::EQUAL ){
485 if( !hasPol || pol ){
486 for( unsigned i=0; i<2; i++ ){
487 if( var_bound.find( n[i] )!=var_bound.end() ){
488 if( !typeMode ){
489 int sid = getSortId( var_bound[n[i]], n[i] );
490 d_non_monotonic_sorts[sid] = true;
491 }else{
492 d_non_monotonic_sorts_orig[n[i].getType()] = true;
493 }
494 break;
495 }
496 }
497 }
498 }
499 for( unsigned i=0; i<n.getNumChildren(); i++ ){
500 bool npol;
501 bool nhasPol;
502 theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
503 processMonotonic( n[i], npol, nhasPol, var_bound, visited, typeMode );
504 }
505 }
506 }
507
508
509 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
510 int rt = d_type_union_find.getRepresentative( t );
511 if( d_type_types.find( rt )!=d_type_types.end() ){
512 return d_type_types[rt];
513 }else{
514 TypeNode retType;
515 //see if we can assign pref
516 if( !pref.isNull() && d_id_for_types.find( pref )==d_id_for_types.end() ){
517 retType = pref;
518 }else{
519 //must create new type
520 std::stringstream ss;
521 ss << "it_" << t << "_" << pref;
522 retType = NodeManager::currentNM()->mkSort( ss.str() );
523 }
524 Trace("sort-inference") << "-> Make type " << retType << " to correspond to ";
525 printSort("sort-inference", t );
526 Trace("sort-inference") << std::endl;
527 d_id_for_types[ retType ] = rt;
528 d_type_types[ rt ] = retType;
529 return retType;
530 }
531 }
532
533 TypeNode SortInference::getTypeForId( int t ){
534 int rt = d_type_union_find.getRepresentative( t );
535 if( d_type_types.find( rt )!=d_type_types.end() ){
536 return d_type_types[rt];
537 }else{
538 return TypeNode::null();
539 }
540 }
541
542 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
543 if( tn.isNull() || tn==old.getType() ){
544 return old;
545 }else if( old.isConst() ){
546 //must make constant of type tn
547 if( d_const_map[tn].find( old )==d_const_map[tn].end() ){
548 std::stringstream ss;
549 ss << "ic_" << tn << "_" << old;
550 d_const_map[tn][ old ] = NodeManager::currentNM()->mkSkolem( ss.str(), tn, "constant created during sort inference" ); //use mkConst???
551 }
552 return d_const_map[tn][ old ];
553 }else if( old.getKind()==kind::BOUND_VARIABLE ){
554 std::stringstream ss;
555 ss << "b_" << old;
556 return NodeManager::currentNM()->mkBoundVar( ss.str(), tn );
557 }else{
558 std::stringstream ss;
559 ss << "i_" << old;
560 return NodeManager::currentNM()->mkSkolem( ss.str(), tn, "created during sort inference" );
561 }
562 }
563
564 Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited ){
565 std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
566 if( itv!=visited[n].end() ){
567 return itv->second;
568 }else{
569 Trace("sort-inference-debug2") << "Simplify " << n << ", type context=" << tnn << std::endl;
570 std::vector< Node > children;
571 std::map< Node, std::map< TypeNode, Node > > new_visited;
572 bool use_new_visited = false;
573 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
574 //recreate based on types of variables
575 std::vector< Node > new_children;
576 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
577 TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
578 Node v = getNewSymbol( n[0][i], tn );
579 Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
580 new_children.push_back( v );
581 var_bound[ n[0][i] ] = v;
582 }
583 children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
584 use_new_visited = true;
585 }
586
587 //process children
588 if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
589 children.push_back( n.getOperator() );
590 }
591 Node op;
592 if( n.hasOperator() ){
593 op = n.getOperator();
594 }
595 bool childChanged = false;
596 TypeNode tnnc;
597 for( size_t i=0; i<n.getNumChildren(); i++ ){
598 bool processChild = true;
599 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
600 processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
601 }
602 if( processChild ){
603 if( n.getKind()==kind::APPLY_UF ){
604 Assert( d_op_arg_types.find( op )!=d_op_arg_types.end() );
605 tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
606 Assert( !tnnc.isNull() );
607 }else if( n.getKind()==kind::EQUAL && i==0 ){
608 Assert( d_equality_types.find( n )!=d_equality_types.end() );
609 tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
610 Assert( !tnnc.isNull() );
611 }
612 Node nc = simplifyNode( n[i], var_bound, tnnc, use_new_visited ? new_visited : visited );
613 Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
614 children.push_back( nc );
615 childChanged = childChanged || nc!=n[i];
616 }
617 }
618
619 //remove from variable bindings
620 Node ret;
621 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
622 //erase from variable bound
623 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
624 Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
625 var_bound.erase( n[0][i] );
626 }
627 ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
628 }else if( n.getKind()==kind::EQUAL ){
629 TypeNode tn1 = children[0].getType();
630 TypeNode tn2 = children[1].getType();
631 if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
632 Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
633 Trace("sort-inference-warn") << " Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
634 Assert( false );
635 }
636 ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
637 }else if( n.getKind()==kind::APPLY_UF ){
638 if( d_symbol_map.find( op )==d_symbol_map.end() ){
639 //make the new operator if necessary
640 bool opChanged = false;
641 std::vector< TypeNode > argTypes;
642 for( size_t i=0; i<n.getNumChildren(); i++ ){
643 TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
644 argTypes.push_back( tn );
645 if( tn!=n[i].getType() ){
646 opChanged = true;
647 }
648 }
649 TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
650 if( retType!=n.getType() ){
651 opChanged = true;
652 }
653 if( opChanged ){
654 std::stringstream ss;
655 ss << "io_" << op;
656 TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
657 d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
658 Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
659 d_model_replace_f[op] = d_symbol_map[op];
660 }else{
661 d_symbol_map[op] = op;
662 }
663 }
664 children[0] = d_symbol_map[op];
665 //make sure all children have been taken care of
666 for( size_t i=0; i<n.getNumChildren(); i++ ){
667 TypeNode tn = children[i+1].getType();
668 TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
669 if( tn!=tna ){
670 Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
671 Assert( false );
672 }
673 }
674 ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
675 }else{
676 std::map< Node, Node >::iterator it = var_bound.find( n );
677 if( it!=var_bound.end() ){
678 ret = it->second;
679 }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
680 if( d_symbol_map.find( n )==d_symbol_map.end() ){
681 TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
682 d_symbol_map[n] = getNewSymbol( n, tn );
683 }
684 ret = d_symbol_map[n];
685 }else if( n.isConst() ){
686 //type is determined by context
687 ret = getNewSymbol( n, tnn );
688 }else if( childChanged ){
689 ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
690 }else{
691 ret = n;
692 }
693 }
694 visited[n][tnn] = ret;
695 return ret;
696 }
697 }
698
699 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
700 std::vector< TypeNode > tns;
701 tns.push_back( tn1 );
702 TypeNode typ = NodeManager::currentNM()->mkFunctionType( tns, tn2 );
703 Node f = NodeManager::currentNM()->mkSkolem( "inj", typ, "injection for monotonicity constraint" );
704 Trace("sort-inference") << "-> Make injection " << f << " from " << tn1 << " to " << tn2 << std::endl;
705 Node v1 = NodeManager::currentNM()->mkBoundVar( "?x", tn1 );
706 Node v2 = NodeManager::currentNM()->mkBoundVar( "?y", tn1 );
707 Node ret = NodeManager::currentNM()->mkNode( kind::FORALL,
708 NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, v1, v2 ),
709 NodeManager::currentNM()->mkNode( kind::OR,
710 NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v1 ).eqNode( NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v2 ) ).negate(),
711 v1.eqNode( v2 ) ) );
712 ret = theory::Rewriter::rewrite( ret );
713 return ret;
714 }
715
716 int SortInference::getSortId( Node n ) {
717 Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
718 if( d_op_return_types.find( op )!=d_op_return_types.end() ){
719 return d_type_union_find.getRepresentative( d_op_return_types[op] );
720 }else{
721 return 0;
722 }
723 }
724
725 int SortInference::getSortId( Node f, Node v ) {
726 if( d_var_types.find( f )!=d_var_types.end() ){
727 return d_type_union_find.getRepresentative( d_var_types[f][v] );
728 }else{
729 return 0;
730 }
731 }
732
733 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
734 Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
735 if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
736 //calculate the sort for variables if not done so already
737 std::map< Node, Node > var_bound;
738 std::map< Node, int > visited;
739 process( f, var_bound, visited );
740 }
741 d_op_return_types[sk] = getSortId( f, v );
742 Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
743 }
744
745 bool SortInference::isWellSortedFormula( Node n ) {
746 if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
747 for( unsigned i=0; i<n.getNumChildren(); i++ ){
748 if( !isWellSortedFormula( n[i] ) ){
749 return false;
750 }
751 }
752 return true;
753 }else{
754 return isWellSorted( n );
755 }
756 }
757
758 bool SortInference::isWellSorted( Node n ) {
759 if( getSortId( n )==0 ){
760 return false;
761 }else{
762 if( n.getKind()==kind::APPLY_UF ){
763 for( unsigned i=0; i<n.getNumChildren(); i++ ){
764 int s1 = getSortId( n[i] );
765 int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
766 if( s1!=s2 ){
767 return false;
768 }
769 if( !isWellSorted( n[i] ) ){
770 return false;
771 }
772 }
773 }
774 return true;
775 }
776 }
777
778 void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
779 if( n.getKind()==kind::APPLY_UF ){
780 for( unsigned i=0; i<n.getNumChildren(); i++ ){
781 getSortConstraints( n[i], uf );
782 uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
783 }
784 }
785 }
786
787 bool SortInference::isMonotonic( TypeNode tn ) {
788 Assert( tn.isSort() );
789 return d_non_monotonic_sorts_orig.find( tn )==d_non_monotonic_sorts_orig.end();
790 }
791
792 }/* CVC4 namespace */