Make sort inference a preprocessing pass (#2309)
[cvc5.git] / src / theory / sort_inference.cpp
1 /********************* */
2 /*! \file sort_inference.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Paul Meng, Tim King
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Sort inference module
13 **
14 ** This class implements sort inference, based on a simple algorithm:
15 ** First, we assume all functions and predicates have distinct uninterpreted types.
16 ** One pass is made through the input assertions, while a union-find data structure
17 ** maintains necessary information regarding constraints on these types.
18 **/
19
20 #include "theory/sort_inference.h"
21
22 #include <vector>
23
24 #include "options/quantifiers_options.h"
25 #include "options/smt_options.h"
26 #include "options/uf_options.h"
27 #include "proof/proof_manager.h"
28 #include "theory/rewriter.h"
29 #include "theory/quantifiers/quant_util.h"
30
31 using namespace CVC4;
32 using namespace std;
33
34 namespace CVC4 {
35
36 void SortInference::UnionFind::print(const char * c){
37 for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
38 Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
39 }
40 for( unsigned i=0; i<d_deq.size(); i++ ){
41 Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
42 }
43 Trace(c) << std::endl;
44 }
45 void SortInference::UnionFind::set( UnionFind& c ) {
46 clear();
47 for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
48 d_eqc[ it->first ] = it->second;
49 }
50 d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
51 }
52 int SortInference::UnionFind::getRepresentative( int t ){
53 std::map< int, int >::iterator it = d_eqc.find( t );
54 if( it==d_eqc.end() || it->second==t ){
55 return t;
56 }else{
57 int rt = getRepresentative( it->second );
58 d_eqc[t] = rt;
59 return rt;
60 }
61 }
62 void SortInference::UnionFind::setEqual( int t1, int t2 ){
63 if( t1!=t2 ){
64 int rt1 = getRepresentative( t1 );
65 int rt2 = getRepresentative( t2 );
66 if( rt1>rt2 ){
67 d_eqc[rt1] = rt2;
68 }else{
69 d_eqc[rt2] = rt1;
70 }
71 }
72 }
73 bool SortInference::UnionFind::isValid() {
74 for( unsigned i=0; i<d_deq.size(); i++ ){
75 if( areEqual( d_deq[i].first, d_deq[i].second ) ){
76 return false;
77 }
78 }
79 return true;
80 }
81
82
83 void SortInference::recordSubsort( TypeNode tn, int s ){
84 s = d_type_union_find.getRepresentative( s );
85 if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){
86 d_sub_sorts.push_back( s );
87 d_type_sub_sorts[tn].push_back( s );
88 }
89 }
90
91 void SortInference::printSort( const char* c, int t ){
92 int rt = d_type_union_find.getRepresentative( t );
93 if( d_type_types.find( rt )!=d_type_types.end() ){
94 Trace(c) << d_type_types[rt];
95 }else{
96 Trace(c) << "s_" << rt;
97 }
98 }
99
100 void SortInference::reset() {
101 d_sub_sorts.clear();
102 d_non_monotonic_sorts.clear();
103 d_type_sub_sorts.clear();
104 //reset info
105 d_sortCount = 1;
106 d_type_union_find.clear();
107 d_type_types.clear();
108 d_id_for_types.clear();
109 d_op_return_types.clear();
110 d_op_arg_types.clear();
111 d_var_types.clear();
112 //for rewriting
113 d_symbol_map.clear();
114 d_const_map.clear();
115 }
116
117 void SortInference::initialize(const std::vector<Node>& assertions)
118 {
119 Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
120 // process all assertions
121 std::map<Node, int> visited;
122 for (const Node& a : assertions)
123 {
124 Trace("sort-inference-debug") << "Process " << a << std::endl;
125 std::map<Node, Node> var_bound;
126 process(a, var_bound, visited);
127 }
128 Trace("sort-inference-proc") << "...done" << std::endl;
129 for (const std::pair<const Node, int>& rt : d_op_return_types)
130 {
131 Trace("sort-inference") << rt.first << " : ";
132 TypeNode retTn = rt.first.getType();
133 if (!d_op_arg_types[rt.first].empty())
134 {
135 Trace("sort-inference") << "( ";
136 for (size_t i = 0; i < d_op_arg_types[rt.first].size(); i++)
137 {
138 recordSubsort(retTn[i], d_op_arg_types[rt.first][i]);
139 printSort("sort-inference", d_op_arg_types[rt.first][i]);
140 Trace("sort-inference") << " ";
141 }
142 Trace("sort-inference") << ") -> ";
143 retTn = retTn[(int)retTn.getNumChildren() - 1];
144 }
145 recordSubsort(retTn, rt.second);
146 printSort("sort-inference", rt.second);
147 Trace("sort-inference") << std::endl;
148 }
149 for (std::pair<const Node, std::map<Node, int> >& vt : d_var_types)
150 {
151 Trace("sort-inference")
152 << "Quantified formula : " << vt.first << " : " << std::endl;
153 for (const Node& v : vt.first[0])
154 {
155 recordSubsort(v.getType(), vt.second[v]);
156 printSort("sort-inference", vt.second[v]);
157 Trace("sort-inference") << std::endl;
158 }
159 Trace("sort-inference") << std::endl;
160 }
161
162 // determine monotonicity of sorts
163 Trace("sort-inference-proc")
164 << "Calculating monotonicty for subsorts..." << std::endl;
165 std::map<Node, std::map<int, bool> > visitedm;
166 for (const Node& a : assertions)
167 {
168 Trace("sort-inference-debug")
169 << "Process monotonicity for " << a << std::endl;
170 std::map<Node, Node> var_bound;
171 processMonotonic(a, true, true, var_bound, visitedm);
172 }
173 Trace("sort-inference-proc") << "...done" << std::endl;
174
175 Trace("sort-inference") << "We have " << d_sub_sorts.size()
176 << " sub-sorts : " << std::endl;
177 for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++)
178 {
179 printSort("sort-inference", d_sub_sorts[i]);
180 if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end())
181 {
182 Trace("sort-inference") << " is interpreted." << std::endl;
183 }
184 else if (d_non_monotonic_sorts.find(d_sub_sorts[i])
185 == d_non_monotonic_sorts.end())
186 {
187 Trace("sort-inference") << " is monotonic." << std::endl;
188 }
189 else
190 {
191 Trace("sort-inference") << " is not monotonic." << std::endl;
192 }
193 }
194 }
195
196 Node SortInference::simplify(Node n,
197 std::map<Node, Node>& model_replace_f,
198 std::map<Node, std::map<TypeNode, Node> >& visited)
199 {
200 Trace("sort-inference-debug") << "Simplify " << n << std::endl;
201 std::map<Node, Node> var_bound;
202 TypeNode tnn;
203 Node ret = simplifyNode(n, var_bound, tnn, model_replace_f, visited);
204 ret = theory::Rewriter::rewrite(ret);
205 return ret;
206 }
207
208 void SortInference::getNewAssertions(std::vector<Node>& new_asserts)
209 {
210 NodeManager* nm = NodeManager::currentNM();
211 // now, ensure constants are distinct
212 for (const std::pair<const TypeNode, std::map<Node, Node> >& cm : d_const_map)
213 {
214 std::vector<Node> consts;
215 for (const std::pair<const Node, Node>& c : cm.second)
216 {
217 Assert(c.first.isConst());
218 consts.push_back(c.second);
219 }
220 // add lemma enforcing introduced constants to be distinct
221 if (consts.size() > 1)
222 {
223 Node distinct_const = nm->mkNode(kind::DISTINCT, consts);
224 Trace("sort-inference-rewrite")
225 << "Add the constant distinctness lemma: " << std::endl;
226 Trace("sort-inference-rewrite") << " " << distinct_const << std::endl;
227 new_asserts.push_back(distinct_const);
228 }
229 }
230
231 // enforce constraints based on monotonicity
232 Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
233
234 for (const std::pair<const TypeNode, std::vector<int> >& tss :
235 d_type_sub_sorts)
236 {
237 int nmonSort = -1;
238 unsigned nsorts = tss.second.size();
239 for (unsigned i = 0; i < nsorts; i++)
240 {
241 if (d_non_monotonic_sorts.find(tss.second[i])
242 != d_non_monotonic_sorts.end())
243 {
244 nmonSort = tss.second[i];
245 break;
246 }
247 }
248 if (nmonSort != -1)
249 {
250 std::vector<Node> injections;
251 TypeNode base_tn = getOrCreateTypeForId(nmonSort, tss.first);
252 for (unsigned i = 0; i < nsorts; i++)
253 {
254 if (tss.second[i] != nmonSort)
255 {
256 TypeNode new_tn = getOrCreateTypeForId(tss.second[i], tss.first);
257 // make injection to nmonSort
258 Node a1 = mkInjection(new_tn, base_tn);
259 injections.push_back(a1);
260 if (d_non_monotonic_sorts.find(tss.second[i])
261 != d_non_monotonic_sorts.end())
262 {
263 // also must make injection from nmonSort to this
264 Node a2 = mkInjection(base_tn, new_tn);
265 injections.push_back(a2);
266 }
267 }
268 }
269 if (Trace.isOn("sort-inference-rewrite"))
270 {
271 Trace("sort-inference-rewrite")
272 << "Add the following injections for " << tss.first
273 << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
274 for (const Node& i : injections)
275 {
276 Trace("sort-inference-rewrite") << " " << i << std::endl;
277 }
278 }
279 new_asserts.insert(
280 new_asserts.end(), injections.begin(), injections.end());
281 }
282 }
283 Trace("sort-inference-proc") << "...done" << std::endl;
284 // no sub-sort information is stored
285 reset();
286 Trace("sort-inference-debug") << "Finished sort inference" << std::endl;
287 }
288
289 void SortInference::computeMonotonicity(const std::vector<Node>& assertions)
290 {
291 std::map<Node, std::map<int, bool> > visitedmt;
292 Trace("sort-inference-proc")
293 << "Calculating monotonicty for types..." << std::endl;
294 for (const Node& a : assertions)
295 {
296 Trace("sort-inference-debug")
297 << "Process type monotonicity for " << a << std::endl;
298 std::map<Node, Node> var_bound;
299 processMonotonic(a, true, true, var_bound, visitedmt, true);
300 }
301 Trace("sort-inference-proc") << "...done" << std::endl;
302 }
303
304 void SortInference::setEqual( int t1, int t2 ){
305 if( t1!=t2 ){
306 int rt1 = d_type_union_find.getRepresentative( t1 );
307 int rt2 = d_type_union_find.getRepresentative( t2 );
308 if( rt1!=rt2 ){
309 Trace("sort-inference-debug") << "Set equal : ";
310 printSort( "sort-inference-debug", rt1 );
311 Trace("sort-inference-debug") << " ";
312 printSort( "sort-inference-debug", rt2 );
313 Trace("sort-inference-debug") << std::endl;
314 /*
315 d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
316 d_type_eq_class[rt2].clear();
317 Trace("sort-inference-debug") << "EqClass : { ";
318 for( int i=0; i<(int)d_type_eq_class[rt1].size(); i++ ){
319 Trace("sort-inference-debug") << d_type_eq_class[rt1][i] << ", ";
320 }
321 Trace("sort-inference-debug") << "}" << std::endl;
322 */
323 if( rt2>rt1 ){
324 //swap
325 int swap = rt1;
326 rt1 = rt2;
327 rt2 = swap;
328 }
329 std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
330 if( it1!=d_type_types.end() ){
331 if( d_type_types.find( rt2 )==d_type_types.end() ){
332 d_type_types[rt2] = it1->second;
333 d_type_types.erase( rt1 );
334 }else{
335 Trace("sort-inference-debug") << "...fail : associated with types " << d_type_types[rt1] << " and " << d_type_types[rt2] << std::endl;
336 return;
337 }
338 }
339 d_type_union_find.d_eqc[rt1] = rt2;
340 }
341 }
342 }
343
344 int SortInference::getIdForType( TypeNode tn ){
345 //register the return type
346 std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
347 if( it==d_id_for_types.end() ){
348 int sc = d_sortCount;
349 d_type_types[d_sortCount] = tn;
350 d_id_for_types[tn] = d_sortCount;
351 d_sortCount++;
352 return sc;
353 }else{
354 return it->second;
355 }
356 }
357
/**
 * Assign a sort id to n and record the sort constraints between n and its
 * subterms via setEqual. Returns the sort id of n.
 *
 * @param n the term or formula to process
 * @param var_bound maps each variable bound by an enclosing quantified
 *        formula to that formula (used to look up d_var_types)
 * @param visited cache of sort ids already computed for nodes. The body of
 *        a quantified formula is processed with a fresh cache (new_visited).
 *        NOTE(review): presumably so that cached ids are not shared across
 *        different variable bindings; confirm.
 *
 * Fresh sort ids are drawn from d_sortCount in traversal order, so the
 * order of the statements below determines the numbering.
 */
int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ){
  std::map< Node, int >::iterator itv = visited.find( n );
  if( itv!=visited.end() ){
    return itv->second;
  }else{
    //add to variable bindings
    bool use_new_visited = false;
    std::map< Node, int > new_visited;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      if( d_var_types.find( n )!=d_var_types.end() ){
        // this quantified formula already has variable sorts; only the id of
        // its own type is returned (the result is not cached in visited)
        return getIdForType( n.getType() );
      }else{
        for( size_t i=0; i<n[0].getNumChildren(); i++ ){
          //apply sort inference to quantified variables
          d_var_types[n][n[0][i]] = d_sortCount;
          d_sortCount++;

          //type of the quantified variable must be the same
          var_bound[ n[0][i] ] = n;
        }
      }
      use_new_visited = true;
    }

    //process children
    // children and child_types are kept index-aligned: both are pushed
    // together, and only for the children that are actually processed
    std::vector< Node > children;
    std::vector< int > child_types;
    for( size_t i=0; i<n.getNumChildren(); i++ ){
      bool processChild = true;
      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
        // child 0 (the bound variable list) is never processed; children
        // beyond 1 (presumably the pattern list) are skipped when user
        // patterns are ignored
        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
      }
      if( processChild ){
        children.push_back( n[i] );
        child_types.push_back( process( n[i], var_bound, use_new_visited ? new_visited : visited ) );
      }
    }

    //remove from variable bindings
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //erase from variable bound
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        var_bound.erase( n[0][i] );
      }
    }
    Trace("sort-inference-debug") << "...Process " << n << std::endl;

    int retType;
    if( n.getKind()==kind::EQUAL && !n[0].getType().isBoolean() ){
      Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
      //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
      if( n[0].getType()!=n[1].getType() ){
        //for now, assume the original types
        for( unsigned i=0; i<2; i++ ){
          int ct = getIdForType( n[i].getType() );
          setEqual( child_types[i], ct );
        }
      }else{
        //we only require that the left and right hand side must be equal
        setEqual( child_types[0], child_types[1] );
      }
      // remembered so that simplifyNode can give both sides the same type
      d_equality_types[n] = child_types[0];
      retType = getIdForType( n.getType() );
    }else if( n.getKind()==kind::APPLY_UF ){
      Node op = n.getOperator();
      TypeNode tn_op = op.getType();
      if( d_op_return_types.find( op )==d_op_return_types.end() ){
        // first occurrence of op: allocate its result and argument sort ids
        if( n.getType().isBoolean() ){
          //use booleans
          d_op_return_types[op] = getIdForType( n.getType() );
        }else{
          //assign arbitrary sort for return type
          d_op_return_types[op] = d_sortCount;
          d_sortCount++;
        }
        // d_type_eq_class[d_sortCount].push_back( op );
        // assign arbitrary sort for argument types
        for( size_t i=0; i<n.getNumChildren(); i++ ){
          d_op_arg_types[op].push_back(d_sortCount);
          d_sortCount++;
        }
      }
      for( size_t i=0; i<n.getNumChildren(); i++ ){
        //the argument of the operator must match the return type of the subterm
        if( n[i].getType()!=tn_op[i] ){
          //if type mismatch, assume original types
          Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
          Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
          int ct1 = getIdForType( n[i].getType() );
          setEqual( child_types[i], ct1 );
          int ct2 = getIdForType( tn_op[i] );
          setEqual( d_op_arg_types[op][i], ct2 );
        }else{
          setEqual( child_types[i], d_op_arg_types[op][i] );
        }
      }
      //return type is the return type
      retType = d_op_return_types[op];
    }else{
      std::map< Node, Node >::iterator it = var_bound.find( n );
      if( it!=var_bound.end() ){
        Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
        //the return type was specified while binding
        retType = d_var_types[it->second][n];
      }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
        Trace("sort-inference-debug") << n << " is a variable." << std::endl;
        // free constants share the d_op_return_types table with functions
        if( d_op_return_types.find( n )==d_op_return_types.end() ){
          //assign arbitrary sort
          d_op_return_types[n] = d_sortCount;
          d_sortCount++;
          // d_type_eq_class[d_sortCount].push_back( n );
        }
        retType = d_op_return_types[n];
      }else if( n.isConst() ){
        Trace("sort-inference-debug") << n << " is a constant." << std::endl;
        //can be any type we want
        retType = d_sortCount;
        d_sortCount++;
      }else{
        Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
        //it is an interpreted term
        for( size_t i=0; i<children.size(); i++ ){
          Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
          //must enforce the actual type of the operator on the children
          int ct = getIdForType( children[i].getType() );
          setEqual( child_types[i], ct );
        }
        //return type must be the actual return type
        retType = getIdForType( n.getType() );
      }
    }
    Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
    printSort("sort-inference-debug", retType );
    Trace("sort-inference-debug") << std::endl;
    visited[n] = retType;
    return retType;
  }
}
496
497 void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode ) {
498 int pindex = hasPol ? ( pol ? 1 : -1 ) : 0;
499 if( visited[n].find( pindex )==visited[n].end() ){
500 visited[n][pindex] = true;
501 Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
502 if( n.getKind()==kind::FORALL ){
503 //only consider variables universally if it is possible this quantified formula is asserted positively
504 if( !hasPol || pol ){
505 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
506 var_bound[n[0][i]] = n;
507 }
508 }
509 processMonotonic( n[1], pol, hasPol, var_bound, visited, typeMode );
510 if( !hasPol || pol ){
511 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
512 var_bound.erase( n[0][i] );
513 }
514 }
515 return;
516 }else if( n.getKind()==kind::EQUAL ){
517 if( !hasPol || pol ){
518 for( unsigned i=0; i<2; i++ ){
519 if( var_bound.find( n[i] )!=var_bound.end() ){
520 if( !typeMode ){
521 int sid = getSortId( var_bound[n[i]], n[i] );
522 d_non_monotonic_sorts[sid] = true;
523 }else{
524 d_non_monotonic_sorts_orig[n[i].getType()] = true;
525 }
526 break;
527 }
528 }
529 }
530 }
531 for( unsigned i=0; i<n.getNumChildren(); i++ ){
532 bool npol;
533 bool nhasPol;
534 theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
535 processMonotonic( n[i], npol, nhasPol, var_bound, visited, typeMode );
536 }
537 }
538 }
539
540
541 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
542 int rt = d_type_union_find.getRepresentative( t );
543 if( d_type_types.find( rt )!=d_type_types.end() ){
544 return d_type_types[rt];
545 }else{
546 TypeNode retType;
547 // See if we can assign pref. This is an optimization for reusing an
548 // uninterpreted sort as the first subsort, so that fewer symbols needed
549 // to be rewritten in the sort-inferred signature. Notice we only assign
550 // pref here if it is an uninterpreted sort.
551 if (!pref.isNull() && d_id_for_types.find(pref) == d_id_for_types.end()
552 && pref.isSort())
553 {
554 retType = pref;
555 }else{
556 //must create new type
557 std::stringstream ss;
558 ss << "it_" << t << "_" << pref;
559 retType = NodeManager::currentNM()->mkSort( ss.str() );
560 }
561 Trace("sort-inference") << "-> Make type " << retType << " to correspond to ";
562 printSort("sort-inference", t );
563 Trace("sort-inference") << std::endl;
564 d_id_for_types[ retType ] = rt;
565 d_type_types[ rt ] = retType;
566 return retType;
567 }
568 }
569
570 TypeNode SortInference::getTypeForId( int t ){
571 int rt = d_type_union_find.getRepresentative( t );
572 if( d_type_types.find( rt )!=d_type_types.end() ){
573 return d_type_types[rt];
574 }else{
575 return TypeNode::null();
576 }
577 }
578
579 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
580 // if no sort was inferred for this node, return original
581 if( tn.isNull() || tn.isComparableTo( old.getType() ) ){
582 return old;
583 }else if( old.isConst() ){
584 //must make constant of type tn
585 if( d_const_map[tn].find( old )==d_const_map[tn].end() ){
586 std::stringstream ss;
587 ss << "ic_" << tn << "_" << old;
588 d_const_map[tn][ old ] = NodeManager::currentNM()->mkSkolem( ss.str(), tn, "constant created during sort inference" ); //use mkConst???
589 }
590 return d_const_map[tn][ old ];
591 }else if( old.getKind()==kind::BOUND_VARIABLE ){
592 std::stringstream ss;
593 ss << "b_" << old;
594 return NodeManager::currentNM()->mkBoundVar( ss.str(), tn );
595 }else{
596 std::stringstream ss;
597 ss << "i_" << old;
598 return NodeManager::currentNM()->mkSkolem( ss.str(), tn, "created during sort inference" );
599 }
600 }
601
/**
 * Rebuild n over the sort-inferred signature.
 *
 * @param n the term or formula to rebuild
 * @param var_bound maps each original bound variable of an enclosing
 *        quantified formula to its replacement variable
 * @param tnn type context for n. It is only consulted when n is a
 *        constant, and is null at the top level (see simplify).
 * @param model_replace_f receives op -> new op for every function symbol
 *        whose signature changed
 * @param visited cache of results per (node, type context). The body of a
 *        quantified formula uses a fresh cache (new_visited).
 *        NOTE(review): presumably because its bound variables are replaced
 *        per quantifier; confirm.
 * @return the rebuilt node
 */
Node SortInference::simplifyNode(
    Node n,
    std::map<Node, Node>& var_bound,
    TypeNode tnn,
    std::map<Node, Node>& model_replace_f,
    std::map<Node, std::map<TypeNode, Node> >& visited)
{
  std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
  if( itv!=visited[n].end() ){
    return itv->second;
  }else{
    Trace("sort-inference-debug2") << "Simplify " << n << ", type context=" << tnn << std::endl;
    std::vector< Node > children;
    std::map< Node, std::map< TypeNode, Node > > new_visited;
    bool use_new_visited = false;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //recreate based on types of variables
      std::vector< Node > new_children;
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
        Node v = getNewSymbol( n[0][i], tn );
        Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
        new_children.push_back( v );
        var_bound[ n[0][i] ] = v;
      }
      // the rebuilt bound variable list becomes children[0]
      children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
      use_new_visited = true;
    }

    //process children
    // for parameterized kinds the operator is children[0]; the APPLY_UF case
    // below relies on this (children[0] = new op, children[i+1] = arg i)
    if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
      children.push_back( n.getOperator() );
    }
    Node op;
    if( n.hasOperator() ){
      op = n.getOperator();
    }
    bool childChanged = false;
    // declared outside the loop on purpose: for an EQUAL, the type context
    // computed for child 0 is reused for child 1, so both sides get one type
    TypeNode tnnc;
    for( size_t i=0; i<n.getNumChildren(); i++ ){
      bool processChild = true;
      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
        // mirrors process(): the bound variable list (child 0) was handled
        // above; later children are dropped when user patterns are ignored
        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
      }
      if( processChild ){
        if( n.getKind()==kind::APPLY_UF ){
          Assert( d_op_arg_types.find( op )!=d_op_arg_types.end() );
          tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
          Assert( !tnnc.isNull() );
        }else if( n.getKind()==kind::EQUAL && i==0 ){
          Assert( d_equality_types.find( n )!=d_equality_types.end() );
          tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
          Assert( !tnnc.isNull() );
        }
        Node nc = simplifyNode(n[i],
                               var_bound,
                               tnnc,
                               model_replace_f,
                               use_new_visited ? new_visited : visited);
        Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
        children.push_back( nc );
        childChanged = childChanged || nc!=n[i];
      }
    }

    //remove from variable bindings
    Node ret;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //erase from variable bound
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
        var_bound.erase( n[0][i] );
      }
      ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
    }else if( n.getKind()==kind::EQUAL ){
      // sanity check: both rebuilt sides must still have related types
      TypeNode tn1 = children[0].getType();
      TypeNode tn2 = children[1].getType();
      if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
        Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
        Trace("sort-inference-warn") << " Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
        Assert( false );
      }
      ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
    }else if( n.getKind()==kind::APPLY_UF ){
      if( d_symbol_map.find( op )==d_symbol_map.end() ){
        //make the new operator if necessary
        bool opChanged = false;
        std::vector< TypeNode > argTypes;
        for( size_t i=0; i<n.getNumChildren(); i++ ){
          TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
          argTypes.push_back( tn );
          if( tn!=n[i].getType() ){
            opChanged = true;
          }
        }
        TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
        if( retType!=n.getType() ){
          opChanged = true;
        }
        if( opChanged ){
          std::stringstream ss;
          ss << "io_" << op;
          TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
          d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
          Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
          model_replace_f[op] = d_symbol_map[op];
        }else{
          // signature unchanged: keep the original operator
          d_symbol_map[op] = op;
        }
      }
      children[0] = d_symbol_map[op];
      //make sure all children have been taken care of
      for( size_t i=0; i<n.getNumChildren(); i++ ){
        TypeNode tn = children[i+1].getType();
        TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
        if( tn!=tna ){
          Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
          Assert( false );
        }
      }
      ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
    }else{
      std::map< Node, Node >::iterator it = var_bound.find( n );
      if( it!=var_bound.end() ){
        // bound variable: use the replacement made for its quantifier
        ret = it->second;
      }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
        // free constant symbol: replaced once, then reused via d_symbol_map
        if( d_symbol_map.find( n )==d_symbol_map.end() ){
          TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
          d_symbol_map[n] = getNewSymbol( n, tn );
        }
        ret = d_symbol_map[n];
      }else if( n.isConst() ){
        //type is determined by context
        ret = getNewSymbol( n, tnn );
      }else if( childChanged ){
        ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
      }else{
        ret = n;
      }
    }
    visited[n][tnn] = ret;
    return ret;
  }
}
746
747 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
748 std::vector< TypeNode > tns;
749 tns.push_back( tn1 );
750 TypeNode typ = NodeManager::currentNM()->mkFunctionType( tns, tn2 );
751 Node f = NodeManager::currentNM()->mkSkolem( "inj", typ, "injection for monotonicity constraint" );
752 Trace("sort-inference") << "-> Make injection " << f << " from " << tn1 << " to " << tn2 << std::endl;
753 Node v1 = NodeManager::currentNM()->mkBoundVar( "?x", tn1 );
754 Node v2 = NodeManager::currentNM()->mkBoundVar( "?y", tn1 );
755 Node ret = NodeManager::currentNM()->mkNode( kind::FORALL,
756 NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, v1, v2 ),
757 NodeManager::currentNM()->mkNode( kind::OR,
758 NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v1 ).eqNode( NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v2 ) ).negate(),
759 v1.eqNode( v2 ) ) );
760 ret = theory::Rewriter::rewrite( ret );
761 return ret;
762 }
763
764 int SortInference::getSortId( Node n ) {
765 Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
766 if( d_op_return_types.find( op )!=d_op_return_types.end() ){
767 return d_type_union_find.getRepresentative( d_op_return_types[op] );
768 }else{
769 return 0;
770 }
771 }
772
773 int SortInference::getSortId( Node f, Node v ) {
774 if( d_var_types.find( f )!=d_var_types.end() ){
775 return d_type_union_find.getRepresentative( d_var_types[f][v] );
776 }else{
777 return 0;
778 }
779 }
780
781 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
782 Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
783 if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
784 //calculate the sort for variables if not done so already
785 std::map< Node, Node > var_bound;
786 std::map< Node, int > visited;
787 process( f, var_bound, visited );
788 }
789 d_op_return_types[sk] = getSortId( f, v );
790 Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
791 }
792
793 bool SortInference::isWellSortedFormula( Node n ) {
794 if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
795 for( unsigned i=0; i<n.getNumChildren(); i++ ){
796 if( !isWellSortedFormula( n[i] ) ){
797 return false;
798 }
799 }
800 return true;
801 }else{
802 return isWellSorted( n );
803 }
804 }
805
806 bool SortInference::isWellSorted( Node n ) {
807 if( getSortId( n )==0 ){
808 return false;
809 }else{
810 if( n.getKind()==kind::APPLY_UF ){
811 for( unsigned i=0; i<n.getNumChildren(); i++ ){
812 int s1 = getSortId( n[i] );
813 int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
814 if( s1!=s2 ){
815 return false;
816 }
817 if( !isWellSorted( n[i] ) ){
818 return false;
819 }
820 }
821 }
822 return true;
823 }
824 }
825
826 void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
827 if( n.getKind()==kind::APPLY_UF ){
828 for( unsigned i=0; i<n.getNumChildren(); i++ ){
829 getSortConstraints( n[i], uf );
830 uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
831 }
832 }
833 }
834
835 bool SortInference::isMonotonic( TypeNode tn ) {
836 Assert( tn.isSort() );
837 return d_non_monotonic_sorts_orig.find( tn )==d_non_monotonic_sorts_orig.end();
838 }
839
840 }/* CVC4 namespace */