Fixes for quantifiers + incremental (#2009)
[cvc5.git] / src / theory / sort_inference.cpp
1 /********************* */
2 /*! \file sort_inference.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Paul Meng, Morgan Deters
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Sort inference module
13 **
14 ** This class implements sort inference, based on a simple algorithm:
15 ** First, we assume all functions and predicates have distinct uninterpreted types.
16 ** One pass is made through the input assertions, while a union-find data structure
17 ** maintains necessary information regarding constraints on these types.
18 **/
19
20 #include "theory/sort_inference.h"
21
22 #include <vector>
23
24 #include "options/quantifiers_options.h"
25 #include "options/smt_options.h"
26 #include "options/uf_options.h"
27 #include "proof/proof_manager.h"
28 #include "theory/rewriter.h"
29 #include "theory/quantifiers/quant_util.h"
30
31 using namespace CVC4;
32 using namespace std;
33
34 namespace CVC4 {
35
36 void SortInference::UnionFind::print(const char * c){
37 for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
38 Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
39 }
40 for( unsigned i=0; i<d_deq.size(); i++ ){
41 Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
42 }
43 Trace(c) << std::endl;
44 }
45 void SortInference::UnionFind::set( UnionFind& c ) {
46 clear();
47 for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
48 d_eqc[ it->first ] = it->second;
49 }
50 d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
51 }
52 int SortInference::UnionFind::getRepresentative( int t ){
53 std::map< int, int >::iterator it = d_eqc.find( t );
54 if( it==d_eqc.end() || it->second==t ){
55 return t;
56 }else{
57 int rt = getRepresentative( it->second );
58 d_eqc[t] = rt;
59 return rt;
60 }
61 }
62 void SortInference::UnionFind::setEqual( int t1, int t2 ){
63 if( t1!=t2 ){
64 int rt1 = getRepresentative( t1 );
65 int rt2 = getRepresentative( t2 );
66 if( rt1>rt2 ){
67 d_eqc[rt1] = rt2;
68 }else{
69 d_eqc[rt2] = rt1;
70 }
71 }
72 }
73 bool SortInference::UnionFind::isValid() {
74 for( unsigned i=0; i<d_deq.size(); i++ ){
75 if( areEqual( d_deq[i].first, d_deq[i].second ) ){
76 return false;
77 }
78 }
79 return true;
80 }
81
82
83 void SortInference::recordSubsort( TypeNode tn, int s ){
84 s = d_type_union_find.getRepresentative( s );
85 if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){
86 d_sub_sorts.push_back( s );
87 d_type_sub_sorts[tn].push_back( s );
88 }
89 }
90
91 void SortInference::printSort( const char* c, int t ){
92 int rt = d_type_union_find.getRepresentative( t );
93 if( d_type_types.find( rt )!=d_type_types.end() ){
94 Trace(c) << d_type_types[rt];
95 }else{
96 Trace(c) << "s_" << rt;
97 }
98 }
99
100 void SortInference::reset() {
101 d_sub_sorts.clear();
102 d_non_monotonic_sorts.clear();
103 d_type_sub_sorts.clear();
104 //reset info
105 sortCount = 1;
106 d_type_union_find.clear();
107 d_type_types.clear();
108 d_id_for_types.clear();
109 d_op_return_types.clear();
110 d_op_arg_types.clear();
111 d_var_types.clear();
112 //for rewriting
113 d_symbol_map.clear();
114 d_const_map.clear();
115 }
116
117 void SortInference::simplify( std::vector< Node >& assertions, bool doSortInference, bool doMonotonicyInference ){
118 if( doSortInference ){
119 Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
120 NodeManager* nm = NodeManager::currentNM();
121 //process all assertions
122 std::map< Node, int > visited;
123 for( unsigned i=0; i<assertions.size(); i++ ){
124 Trace("sort-inference-debug") << "Process " << assertions[i] << std::endl;
125 std::map< Node, Node > var_bound;
126 process( assertions[i], var_bound, visited );
127 }
128 Trace("sort-inference-proc") << "...done" << std::endl;
129 for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
130 Trace("sort-inference") << it->first << " : ";
131 TypeNode retTn = it->first.getType();
132 if( !d_op_arg_types[ it->first ].empty() ){
133 Trace("sort-inference") << "( ";
134 for( size_t i=0; i<d_op_arg_types[ it->first ].size(); i++ ){
135 recordSubsort( retTn[i], d_op_arg_types[ it->first ][i] );
136 printSort( "sort-inference", d_op_arg_types[ it->first ][i] );
137 Trace("sort-inference") << " ";
138 }
139 Trace("sort-inference") << ") -> ";
140 retTn = retTn[(int)retTn.getNumChildren()-1];
141 }
142 recordSubsort( retTn, it->second );
143 printSort( "sort-inference", it->second );
144 Trace("sort-inference") << std::endl;
145 }
146 for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){
147 Trace("sort-inference") << "Quantified formula : " << it->first << " : " << std::endl;
148 for( unsigned i=0; i<it->first[0].getNumChildren(); i++ ){
149 recordSubsort( it->first[0][i].getType(), it->second[it->first[0][i]] );
150 printSort( "sort-inference", it->second[it->first[0][i]] );
151 Trace("sort-inference") << std::endl;
152 }
153 Trace("sort-inference") << std::endl;
154 }
155
156 bool rewritten = false;
157 // determine monotonicity of sorts
158 Trace("sort-inference-proc") << "Calculating monotonicty for subsorts..."
159 << std::endl;
160 std::map<Node, std::map<int, bool> > visitedm;
161 for (const Node& a : assertions)
162 {
163 Trace("sort-inference-debug") << "Process monotonicity for " << a
164 << std::endl;
165 std::map<Node, Node> var_bound;
166 processMonotonic(a, true, true, var_bound, visitedm);
167 }
168 Trace("sort-inference-proc") << "...done" << std::endl;
169
170 Trace("sort-inference") << "We have " << d_sub_sorts.size()
171 << " sub-sorts : " << std::endl;
172 for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++)
173 {
174 printSort("sort-inference", d_sub_sorts[i]);
175 if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end())
176 {
177 Trace("sort-inference") << " is interpreted." << std::endl;
178 }
179 else if (d_non_monotonic_sorts.find(d_sub_sorts[i])
180 == d_non_monotonic_sorts.end())
181 {
182 Trace("sort-inference") << " is monotonic." << std::endl;
183 }
184 else
185 {
186 Trace("sort-inference") << " is not monotonic." << std::endl;
187 }
188 }
189
190 // simplify all assertions by introducing new symbols wherever necessary
191 Trace("sort-inference-proc") << "Perform simplification..." << std::endl;
192 std::map<Node, std::map<TypeNode, Node> > visited2;
193 for (unsigned i = 0, size = assertions.size(); i < size; i++)
194 {
195 Node prev = assertions[i];
196 std::map<Node, Node> var_bound;
197 Trace("sort-inference-debug") << "Simplify " << prev << std::endl;
198 TypeNode tnn;
199 Node curr = simplifyNode(assertions[i], var_bound, tnn, visited2);
200 Trace("sort-inference-debug") << "Done." << std::endl;
201 if (curr != assertions[i])
202 {
203 Trace("sort-inference-debug") << "Rewrite " << curr << std::endl;
204 curr = theory::Rewriter::rewrite(curr);
205 rewritten = true;
206 Trace("sort-inference-rewrite") << assertions << std::endl;
207 Trace("sort-inference-rewrite") << " --> " << curr << std::endl;
208 PROOF(ProofManager::currentPM()->addDependence(curr, assertions[i]););
209 assertions[i] = curr;
210 }
211 }
212 Trace("sort-inference-proc") << "...done" << std::endl;
213 // now, ensure constants are distinct
214 for (std::map<TypeNode, std::map<Node, Node> >::iterator it =
215 d_const_map.begin();
216 it != d_const_map.end();
217 ++it)
218 {
219 std::vector<Node> consts;
220 for (std::map<Node, Node>::iterator it2 = it->second.begin();
221 it2 != it->second.end();
222 ++it2)
223 {
224 Assert(it2->first.isConst());
225 consts.push_back(it2->second);
226 }
227 // add lemma enforcing introduced constants to be distinct
228 if (consts.size() > 1)
229 {
230 Node distinct_const = nm->mkNode(kind::DISTINCT, consts);
231 Trace("sort-inference-rewrite")
232 << "Add the constant distinctness lemma: " << std::endl;
233 Trace("sort-inference-rewrite") << " " << distinct_const << std::endl;
234 assertions.push_back(distinct_const);
235 rewritten = true;
236 }
237 }
238
239 // enforce constraints based on monotonicity
240 Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
241 for (std::map<TypeNode, std::vector<int> >::iterator it =
242 d_type_sub_sorts.begin();
243 it != d_type_sub_sorts.end();
244 ++it)
245 {
246 int nmonSort = -1;
247 unsigned nsorts = it->second.size();
248 for (unsigned i = 0; i < nsorts; i++)
249 {
250 if (d_non_monotonic_sorts.find(it->second[i])
251 != d_non_monotonic_sorts.end())
252 {
253 nmonSort = it->second[i];
254 break;
255 }
256 }
257 if (nmonSort != -1)
258 {
259 std::vector<Node> injections;
260 TypeNode base_tn = getOrCreateTypeForId(nmonSort, it->first);
261 for (unsigned i = 0; i < nsorts; i++)
262 {
263 if (it->second[i] != nmonSort)
264 {
265 TypeNode new_tn = getOrCreateTypeForId(it->second[i], it->first);
266 // make injection to nmonSort
267 Node a1 = mkInjection(new_tn, base_tn);
268 injections.push_back(a1);
269 if (d_non_monotonic_sorts.find(it->second[i])
270 != d_non_monotonic_sorts.end())
271 {
272 // also must make injection from nmonSort to this
273 Node a2 = mkInjection(base_tn, new_tn);
274 injections.push_back(a2);
275 }
276 }
277 }
278 if (Trace.isOn("sort-inference-rewrite"))
279 {
280 Trace("sort-inference-rewrite") << "Add the following injections for " << it->first << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
281 for (const Node& i : injections)
282 {
283 Trace("sort-inference-rewrite") << " " << i << std::endl;
284 }
285 }
286 assertions.insert(
287 assertions.end(), injections.begin(), injections.end());
288 if (!injections.empty())
289 {
290 rewritten = true;
291 }
292 }
293 }
294 Trace("sort-inference-proc") << "...done" << std::endl;
295 // no sub-sort information is stored
296 reset();
297 Trace("sort-inference-debug")
298 << "Finished sort inference, rewritten = " << rewritten << std::endl;
299
300 initialSortCount = sortCount;
301 }
302 if( doMonotonicyInference ){
303 std::map<Node, std::map<int, bool> > visitedmt;
304 Trace("sort-inference-proc") << "Calculating monotonicty for types..." << std::endl;
305 for (const Node& a : assertions)
306 {
307 Trace("sort-inference-debug") << "Process type monotonicity for " << a
308 << std::endl;
309 std::map< Node, Node > var_bound;
310 processMonotonic(a, true, true, var_bound, visitedmt, true);
311 }
312 Trace("sort-inference-proc") << "...done" << std::endl;
313 }
314 }
315
316 void SortInference::setEqual( int t1, int t2 ){
317 if( t1!=t2 ){
318 int rt1 = d_type_union_find.getRepresentative( t1 );
319 int rt2 = d_type_union_find.getRepresentative( t2 );
320 if( rt1!=rt2 ){
321 Trace("sort-inference-debug") << "Set equal : ";
322 printSort( "sort-inference-debug", rt1 );
323 Trace("sort-inference-debug") << " ";
324 printSort( "sort-inference-debug", rt2 );
325 Trace("sort-inference-debug") << std::endl;
326 /*
327 d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
328 d_type_eq_class[rt2].clear();
329 Trace("sort-inference-debug") << "EqClass : { ";
330 for( int i=0; i<(int)d_type_eq_class[rt1].size(); i++ ){
331 Trace("sort-inference-debug") << d_type_eq_class[rt1][i] << ", ";
332 }
333 Trace("sort-inference-debug") << "}" << std::endl;
334 */
335 if( rt2>rt1 ){
336 //swap
337 int swap = rt1;
338 rt1 = rt2;
339 rt2 = swap;
340 }
341 std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
342 if( it1!=d_type_types.end() ){
343 if( d_type_types.find( rt2 )==d_type_types.end() ){
344 d_type_types[rt2] = it1->second;
345 d_type_types.erase( rt1 );
346 }else{
347 Trace("sort-inference-debug") << "...fail : associated with types " << d_type_types[rt1] << " and " << d_type_types[rt2] << std::endl;
348 return;
349 }
350 }
351 d_type_union_find.d_eqc[rt1] = rt2;
352 }
353 }
354 }
355
356 int SortInference::getIdForType( TypeNode tn ){
357 //register the return type
358 std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
359 if( it==d_id_for_types.end() ){
360 int sc = sortCount;
361 d_type_types[ sortCount ] = tn;
362 d_id_for_types[ tn ] = sortCount;
363 sortCount++;
364 return sc;
365 }else{
366 return it->second;
367 }
368 }
369
370 int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ){
371 std::map< Node, int >::iterator itv = visited.find( n );
372 if( itv!=visited.end() ){
373 return itv->second;
374 }else{
375 //add to variable bindings
376 bool use_new_visited = false;
377 std::map< Node, int > new_visited;
378 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
379 if( d_var_types.find( n )!=d_var_types.end() ){
380 return getIdForType( n.getType() );
381 }else{
382 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
383 //apply sort inference to quantified variables
384 d_var_types[n][ n[0][i] ] = sortCount;
385 sortCount++;
386
387 //type of the quantified variable must be the same
388 var_bound[ n[0][i] ] = n;
389 }
390 }
391 use_new_visited = true;
392 }
393
394 //process children
395 std::vector< Node > children;
396 std::vector< int > child_types;
397 for( size_t i=0; i<n.getNumChildren(); i++ ){
398 bool processChild = true;
399 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
400 processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
401 }
402 if( processChild ){
403 children.push_back( n[i] );
404 child_types.push_back( process( n[i], var_bound, use_new_visited ? new_visited : visited ) );
405 }
406 }
407
408 //remove from variable bindings
409 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
410 //erase from variable bound
411 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
412 var_bound.erase( n[0][i] );
413 }
414 }
415 Trace("sort-inference-debug") << "...Process " << n << std::endl;
416
417 int retType;
418 if( n.getKind()==kind::EQUAL && !n[0].getType().isBoolean() ){
419 Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
420 //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
421 if( n[0].getType()!=n[1].getType() ){
422 //for now, assume the original types
423 for( unsigned i=0; i<2; i++ ){
424 int ct = getIdForType( n[i].getType() );
425 setEqual( child_types[i], ct );
426 }
427 }else{
428 //we only require that the left and right hand side must be equal
429 setEqual( child_types[0], child_types[1] );
430 }
431 d_equality_types[n] = child_types[0];
432 retType = getIdForType( n.getType() );
433 }else if( n.getKind()==kind::APPLY_UF ){
434 Node op = n.getOperator();
435 TypeNode tn_op = op.getType();
436 if( d_op_return_types.find( op )==d_op_return_types.end() ){
437 if( n.getType().isBoolean() ){
438 //use booleans
439 d_op_return_types[op] = getIdForType( n.getType() );
440 }else{
441 //assign arbitrary sort for return type
442 d_op_return_types[op] = sortCount;
443 sortCount++;
444 }
445 //d_type_eq_class[sortCount].push_back( op );
446 //assign arbitrary sort for argument types
447 for( size_t i=0; i<n.getNumChildren(); i++ ){
448 d_op_arg_types[op].push_back( sortCount );
449 sortCount++;
450 }
451 }
452 for( size_t i=0; i<n.getNumChildren(); i++ ){
453 //the argument of the operator must match the return type of the subterm
454 if( n[i].getType()!=tn_op[i] ){
455 //if type mismatch, assume original types
456 Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
457 Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
458 int ct1 = getIdForType( n[i].getType() );
459 setEqual( child_types[i], ct1 );
460 int ct2 = getIdForType( tn_op[i] );
461 setEqual( d_op_arg_types[op][i], ct2 );
462 }else{
463 setEqual( child_types[i], d_op_arg_types[op][i] );
464 }
465 }
466 //return type is the return type
467 retType = d_op_return_types[op];
468 }else{
469 std::map< Node, Node >::iterator it = var_bound.find( n );
470 if( it!=var_bound.end() ){
471 Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
472 //the return type was specified while binding
473 retType = d_var_types[it->second][n];
474 }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
475 Trace("sort-inference-debug") << n << " is a variable." << std::endl;
476 if( d_op_return_types.find( n )==d_op_return_types.end() ){
477 //assign arbitrary sort
478 d_op_return_types[n] = sortCount;
479 sortCount++;
480 //d_type_eq_class[sortCount].push_back( n );
481 }
482 retType = d_op_return_types[n];
483 }else if( n.isConst() ){
484 Trace("sort-inference-debug") << n << " is a constant." << std::endl;
485 //can be any type we want
486 retType = sortCount;
487 sortCount++;
488 }else{
489 Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
490 //it is an interpreted term
491 for( size_t i=0; i<children.size(); i++ ){
492 Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
493 //must enforce the actual type of the operator on the children
494 int ct = getIdForType( children[i].getType() );
495 setEqual( child_types[i], ct );
496 }
497 //return type must be the actual return type
498 retType = getIdForType( n.getType() );
499 }
500 }
501 Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
502 printSort("sort-inference-debug", retType );
503 Trace("sort-inference-debug") << std::endl;
504 visited[n] = retType;
505 return retType;
506 }
507 }
508
509 void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode ) {
510 int pindex = hasPol ? ( pol ? 1 : -1 ) : 0;
511 if( visited[n].find( pindex )==visited[n].end() ){
512 visited[n][pindex] = true;
513 Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
514 if( n.getKind()==kind::FORALL ){
515 //only consider variables universally if it is possible this quantified formula is asserted positively
516 if( !hasPol || pol ){
517 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
518 var_bound[n[0][i]] = n;
519 }
520 }
521 processMonotonic( n[1], pol, hasPol, var_bound, visited, typeMode );
522 if( !hasPol || pol ){
523 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
524 var_bound.erase( n[0][i] );
525 }
526 }
527 return;
528 }else if( n.getKind()==kind::EQUAL ){
529 if( !hasPol || pol ){
530 for( unsigned i=0; i<2; i++ ){
531 if( var_bound.find( n[i] )!=var_bound.end() ){
532 if( !typeMode ){
533 int sid = getSortId( var_bound[n[i]], n[i] );
534 d_non_monotonic_sorts[sid] = true;
535 }else{
536 d_non_monotonic_sorts_orig[n[i].getType()] = true;
537 }
538 break;
539 }
540 }
541 }
542 }
543 for( unsigned i=0; i<n.getNumChildren(); i++ ){
544 bool npol;
545 bool nhasPol;
546 theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
547 processMonotonic( n[i], npol, nhasPol, var_bound, visited, typeMode );
548 }
549 }
550 }
551
552
553 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
554 int rt = d_type_union_find.getRepresentative( t );
555 if( d_type_types.find( rt )!=d_type_types.end() ){
556 return d_type_types[rt];
557 }else{
558 TypeNode retType;
559 //see if we can assign pref
560 if( !pref.isNull() && d_id_for_types.find( pref )==d_id_for_types.end() ){
561 retType = pref;
562 }else{
563 //must create new type
564 std::stringstream ss;
565 ss << "it_" << t << "_" << pref;
566 retType = NodeManager::currentNM()->mkSort( ss.str() );
567 }
568 Trace("sort-inference") << "-> Make type " << retType << " to correspond to ";
569 printSort("sort-inference", t );
570 Trace("sort-inference") << std::endl;
571 d_id_for_types[ retType ] = rt;
572 d_type_types[ rt ] = retType;
573 return retType;
574 }
575 }
576
577 TypeNode SortInference::getTypeForId( int t ){
578 int rt = d_type_union_find.getRepresentative( t );
579 if( d_type_types.find( rt )!=d_type_types.end() ){
580 return d_type_types[rt];
581 }else{
582 return TypeNode::null();
583 }
584 }
585
586 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
587 // if no sort was inferred for this node, return original
588 if( tn.isNull() || tn.isComparableTo( old.getType() ) ){
589 return old;
590 }else if( old.isConst() ){
591 //must make constant of type tn
592 if( d_const_map[tn].find( old )==d_const_map[tn].end() ){
593 std::stringstream ss;
594 ss << "ic_" << tn << "_" << old;
595 d_const_map[tn][ old ] = NodeManager::currentNM()->mkSkolem( ss.str(), tn, "constant created during sort inference" ); //use mkConst???
596 }
597 return d_const_map[tn][ old ];
598 }else if( old.getKind()==kind::BOUND_VARIABLE ){
599 std::stringstream ss;
600 ss << "b_" << old;
601 return NodeManager::currentNM()->mkBoundVar( ss.str(), tn );
602 }else{
603 std::stringstream ss;
604 ss << "i_" << old;
605 return NodeManager::currentNM()->mkSkolem( ss.str(), tn, "created during sort inference" );
606 }
607 }
608
609 Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited ){
610 std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
611 if( itv!=visited[n].end() ){
612 return itv->second;
613 }else{
614 Trace("sort-inference-debug2") << "Simplify " << n << ", type context=" << tnn << std::endl;
615 std::vector< Node > children;
616 std::map< Node, std::map< TypeNode, Node > > new_visited;
617 bool use_new_visited = false;
618 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
619 //recreate based on types of variables
620 std::vector< Node > new_children;
621 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
622 TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
623 Node v = getNewSymbol( n[0][i], tn );
624 Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
625 new_children.push_back( v );
626 var_bound[ n[0][i] ] = v;
627 }
628 children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
629 use_new_visited = true;
630 }
631
632 //process children
633 if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
634 children.push_back( n.getOperator() );
635 }
636 Node op;
637 if( n.hasOperator() ){
638 op = n.getOperator();
639 }
640 bool childChanged = false;
641 TypeNode tnnc;
642 for( size_t i=0; i<n.getNumChildren(); i++ ){
643 bool processChild = true;
644 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
645 processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
646 }
647 if( processChild ){
648 if( n.getKind()==kind::APPLY_UF ){
649 Assert( d_op_arg_types.find( op )!=d_op_arg_types.end() );
650 tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
651 Assert( !tnnc.isNull() );
652 }else if( n.getKind()==kind::EQUAL && i==0 ){
653 Assert( d_equality_types.find( n )!=d_equality_types.end() );
654 tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
655 Assert( !tnnc.isNull() );
656 }
657 Node nc = simplifyNode( n[i], var_bound, tnnc, use_new_visited ? new_visited : visited );
658 Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
659 children.push_back( nc );
660 childChanged = childChanged || nc!=n[i];
661 }
662 }
663
664 //remove from variable bindings
665 Node ret;
666 if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
667 //erase from variable bound
668 for( size_t i=0; i<n[0].getNumChildren(); i++ ){
669 Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
670 var_bound.erase( n[0][i] );
671 }
672 ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
673 }else if( n.getKind()==kind::EQUAL ){
674 TypeNode tn1 = children[0].getType();
675 TypeNode tn2 = children[1].getType();
676 if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
677 Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
678 Trace("sort-inference-warn") << " Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
679 Assert( false );
680 }
681 ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
682 }else if( n.getKind()==kind::APPLY_UF ){
683 if( d_symbol_map.find( op )==d_symbol_map.end() ){
684 //make the new operator if necessary
685 bool opChanged = false;
686 std::vector< TypeNode > argTypes;
687 for( size_t i=0; i<n.getNumChildren(); i++ ){
688 TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
689 argTypes.push_back( tn );
690 if( tn!=n[i].getType() ){
691 opChanged = true;
692 }
693 }
694 TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
695 if( retType!=n.getType() ){
696 opChanged = true;
697 }
698 if( opChanged ){
699 std::stringstream ss;
700 ss << "io_" << op;
701 TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
702 d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
703 Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
704 d_model_replace_f[op] = d_symbol_map[op];
705 }else{
706 d_symbol_map[op] = op;
707 }
708 }
709 children[0] = d_symbol_map[op];
710 //make sure all children have been taken care of
711 for( size_t i=0; i<n.getNumChildren(); i++ ){
712 TypeNode tn = children[i+1].getType();
713 TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
714 if( tn!=tna ){
715 Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
716 Assert( false );
717 }
718 }
719 ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
720 }else{
721 std::map< Node, Node >::iterator it = var_bound.find( n );
722 if( it!=var_bound.end() ){
723 ret = it->second;
724 }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
725 if( d_symbol_map.find( n )==d_symbol_map.end() ){
726 TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
727 d_symbol_map[n] = getNewSymbol( n, tn );
728 }
729 ret = d_symbol_map[n];
730 }else if( n.isConst() ){
731 //type is determined by context
732 ret = getNewSymbol( n, tnn );
733 }else if( childChanged ){
734 ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
735 }else{
736 ret = n;
737 }
738 }
739 visited[n][tnn] = ret;
740 return ret;
741 }
742 }
743
744 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
745 std::vector< TypeNode > tns;
746 tns.push_back( tn1 );
747 TypeNode typ = NodeManager::currentNM()->mkFunctionType( tns, tn2 );
748 Node f = NodeManager::currentNM()->mkSkolem( "inj", typ, "injection for monotonicity constraint" );
749 Trace("sort-inference") << "-> Make injection " << f << " from " << tn1 << " to " << tn2 << std::endl;
750 Node v1 = NodeManager::currentNM()->mkBoundVar( "?x", tn1 );
751 Node v2 = NodeManager::currentNM()->mkBoundVar( "?y", tn1 );
752 Node ret = NodeManager::currentNM()->mkNode( kind::FORALL,
753 NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, v1, v2 ),
754 NodeManager::currentNM()->mkNode( kind::OR,
755 NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v1 ).eqNode( NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v2 ) ).negate(),
756 v1.eqNode( v2 ) ) );
757 ret = theory::Rewriter::rewrite( ret );
758 return ret;
759 }
760
761 int SortInference::getSortId( Node n ) {
762 Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
763 if( d_op_return_types.find( op )!=d_op_return_types.end() ){
764 return d_type_union_find.getRepresentative( d_op_return_types[op] );
765 }else{
766 return 0;
767 }
768 }
769
770 int SortInference::getSortId( Node f, Node v ) {
771 if( d_var_types.find( f )!=d_var_types.end() ){
772 return d_type_union_find.getRepresentative( d_var_types[f][v] );
773 }else{
774 return 0;
775 }
776 }
777
778 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
779 Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
780 if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
781 //calculate the sort for variables if not done so already
782 std::map< Node, Node > var_bound;
783 std::map< Node, int > visited;
784 process( f, var_bound, visited );
785 }
786 d_op_return_types[sk] = getSortId( f, v );
787 Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
788 }
789
790 bool SortInference::isWellSortedFormula( Node n ) {
791 if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
792 for( unsigned i=0; i<n.getNumChildren(); i++ ){
793 if( !isWellSortedFormula( n[i] ) ){
794 return false;
795 }
796 }
797 return true;
798 }else{
799 return isWellSorted( n );
800 }
801 }
802
803 bool SortInference::isWellSorted( Node n ) {
804 if( getSortId( n )==0 ){
805 return false;
806 }else{
807 if( n.getKind()==kind::APPLY_UF ){
808 for( unsigned i=0; i<n.getNumChildren(); i++ ){
809 int s1 = getSortId( n[i] );
810 int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
811 if( s1!=s2 ){
812 return false;
813 }
814 if( !isWellSorted( n[i] ) ){
815 return false;
816 }
817 }
818 }
819 return true;
820 }
821 }
822
823 void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
824 if( n.getKind()==kind::APPLY_UF ){
825 for( unsigned i=0; i<n.getNumChildren(); i++ ){
826 getSortConstraints( n[i], uf );
827 uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
828 }
829 }
830 }
831
832 bool SortInference::isMonotonic( TypeNode tn ) {
833 Assert( tn.isSort() );
834 return d_non_monotonic_sorts_orig.find( tn )==d_non_monotonic_sorts_orig.end();
835 }
836
837 }/* CVC4 namespace */