Update copyright headers.
[cvc5.git] / src / theory / sort_inference.cpp
1 /********************* */
2 /*! \file sort_inference.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Paul Meng, Kshitij Bansal
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
** (in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Sort inference module
13 **
** This class implements sort inference using a simple algorithm:
** first, all functions and predicates are assumed to have distinct
** uninterpreted types. Then a single pass is made through the input
** assertions, during which a union-find data structure records which of
** these types are constrained to be equal.
18 **/
19
20 #include "theory/sort_inference.h"
21
22 #include <vector>
23
24 #include "options/quantifiers_options.h"
25 #include "options/smt_options.h"
26 #include "options/uf_options.h"
27 #include "proof/proof_manager.h"
28 #include "theory/rewriter.h"
29 #include "theory/quantifiers/quant_util.h"
30
31 using namespace CVC4;
32 using namespace std;
33
34 namespace CVC4 {
35
36 void SortInference::UnionFind::print(const char * c){
37 for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
38 Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
39 }
40 for( unsigned i=0; i<d_deq.size(); i++ ){
41 Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
42 }
43 Trace(c) << std::endl;
44 }
45 void SortInference::UnionFind::set( UnionFind& c ) {
46 clear();
47 for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
48 d_eqc[ it->first ] = it->second;
49 }
50 d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
51 }
52 int SortInference::UnionFind::getRepresentative( int t ){
53 std::map< int, int >::iterator it = d_eqc.find( t );
54 if( it==d_eqc.end() || it->second==t ){
55 return t;
56 }else{
57 int rt = getRepresentative( it->second );
58 d_eqc[t] = rt;
59 return rt;
60 }
61 }
62 void SortInference::UnionFind::setEqual( int t1, int t2 ){
63 if( t1!=t2 ){
64 int rt1 = getRepresentative( t1 );
65 int rt2 = getRepresentative( t2 );
66 if( rt1>rt2 ){
67 d_eqc[rt1] = rt2;
68 }else{
69 d_eqc[rt2] = rt1;
70 }
71 }
72 }
73 bool SortInference::UnionFind::isValid() {
74 for( unsigned i=0; i<d_deq.size(); i++ ){
75 if( areEqual( d_deq[i].first, d_deq[i].second ) ){
76 return false;
77 }
78 }
79 return true;
80 }
81
82
83 void SortInference::recordSubsort( TypeNode tn, int s ){
84 s = d_type_union_find.getRepresentative( s );
85 if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){
86 d_sub_sorts.push_back( s );
87 d_type_sub_sorts[tn].push_back( s );
88 }
89 }
90
91 void SortInference::printSort( const char* c, int t ){
92 int rt = d_type_union_find.getRepresentative( t );
93 if( d_type_types.find( rt )!=d_type_types.end() ){
94 Trace(c) << d_type_types[rt];
95 }else{
96 Trace(c) << "s_" << rt;
97 }
98 }
99
100 void SortInference::reset() {
101 d_sub_sorts.clear();
102 d_non_monotonic_sorts.clear();
103 d_type_sub_sorts.clear();
104 //reset info
105 d_sortCount = 1;
106 d_type_union_find.clear();
107 d_type_types.clear();
108 d_id_for_types.clear();
109 d_op_return_types.clear();
110 d_op_arg_types.clear();
111 d_var_types.clear();
112 //for rewriting
113 d_symbol_map.clear();
114 d_const_map.clear();
115 }
116
117 void SortInference::initialize(const std::vector<Node>& assertions)
118 {
119 Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
120 // process all assertions
121 std::map<Node, int> visited;
122 for (const Node& a : assertions)
123 {
124 Trace("sort-inference-debug") << "Process " << a << std::endl;
125 std::map<Node, Node> var_bound;
126 process(a, var_bound, visited);
127 }
128 Trace("sort-inference-proc") << "...done" << std::endl;
129 for (const std::pair<const Node, int>& rt : d_op_return_types)
130 {
131 Trace("sort-inference") << rt.first << " : ";
132 TypeNode retTn = rt.first.getType();
133 if (!d_op_arg_types[rt.first].empty())
134 {
135 Trace("sort-inference") << "( ";
136 for (size_t i = 0; i < d_op_arg_types[rt.first].size(); i++)
137 {
138 recordSubsort(retTn[i], d_op_arg_types[rt.first][i]);
139 printSort("sort-inference", d_op_arg_types[rt.first][i]);
140 Trace("sort-inference") << " ";
141 }
142 Trace("sort-inference") << ") -> ";
143 retTn = retTn[(int)retTn.getNumChildren() - 1];
144 }
145 recordSubsort(retTn, rt.second);
146 printSort("sort-inference", rt.second);
147 Trace("sort-inference") << std::endl;
148 }
149 for (std::pair<const Node, std::map<Node, int> >& vt : d_var_types)
150 {
151 Trace("sort-inference")
152 << "Quantified formula : " << vt.first << " : " << std::endl;
153 for (const Node& v : vt.first[0])
154 {
155 recordSubsort(v.getType(), vt.second[v]);
156 printSort("sort-inference", vt.second[v]);
157 Trace("sort-inference") << std::endl;
158 }
159 Trace("sort-inference") << std::endl;
160 }
161
162 // determine monotonicity of sorts
163 Trace("sort-inference-proc")
164 << "Calculating monotonicty for subsorts..." << std::endl;
165 std::map<Node, std::map<int, bool> > visitedm;
166 for (const Node& a : assertions)
167 {
168 Trace("sort-inference-debug")
169 << "Process monotonicity for " << a << std::endl;
170 std::map<Node, Node> var_bound;
171 processMonotonic(a, true, true, var_bound, visitedm);
172 }
173 Trace("sort-inference-proc") << "...done" << std::endl;
174
175 Trace("sort-inference") << "We have " << d_sub_sorts.size()
176 << " sub-sorts : " << std::endl;
177 for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++)
178 {
179 printSort("sort-inference", d_sub_sorts[i]);
180 if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end())
181 {
182 Trace("sort-inference") << " is interpreted." << std::endl;
183 }
184 else if (d_non_monotonic_sorts.find(d_sub_sorts[i])
185 == d_non_monotonic_sorts.end())
186 {
187 Trace("sort-inference") << " is monotonic." << std::endl;
188 }
189 else
190 {
191 Trace("sort-inference") << " is not monotonic." << std::endl;
192 }
193 }
194 }
195
196 Node SortInference::simplify(Node n,
197 std::map<Node, Node>& model_replace_f,
198 std::map<Node, std::map<TypeNode, Node> >& visited)
199 {
200 Trace("sort-inference-debug") << "Simplify " << n << std::endl;
201 std::map<Node, Node> var_bound;
202 TypeNode tnn;
203 Node ret = simplifyNode(n, var_bound, tnn, model_replace_f, visited);
204 ret = theory::Rewriter::rewrite(ret);
205 return ret;
206 }
207
208 void SortInference::getNewAssertions(std::vector<Node>& new_asserts)
209 {
210 NodeManager* nm = NodeManager::currentNM();
211 // now, ensure constants are distinct
212 for (const std::pair<const TypeNode, std::map<Node, Node> >& cm : d_const_map)
213 {
214 std::vector<Node> consts;
215 for (const std::pair<const Node, Node>& c : cm.second)
216 {
217 Assert(c.first.isConst());
218 consts.push_back(c.second);
219 }
220 // add lemma enforcing introduced constants to be distinct
221 if (consts.size() > 1)
222 {
223 Node distinct_const = nm->mkNode(kind::DISTINCT, consts);
224 Trace("sort-inference-rewrite")
225 << "Add the constant distinctness lemma: " << std::endl;
226 Trace("sort-inference-rewrite") << " " << distinct_const << std::endl;
227 new_asserts.push_back(distinct_const);
228 }
229 }
230
231 // enforce constraints based on monotonicity
232 Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
233
234 for (const std::pair<const TypeNode, std::vector<int> >& tss :
235 d_type_sub_sorts)
236 {
237 int nmonSort = -1;
238 unsigned nsorts = tss.second.size();
239 for (unsigned i = 0; i < nsorts; i++)
240 {
241 if (d_non_monotonic_sorts.find(tss.second[i])
242 != d_non_monotonic_sorts.end())
243 {
244 nmonSort = tss.second[i];
245 break;
246 }
247 }
248 if (nmonSort != -1)
249 {
250 std::vector<Node> injections;
251 TypeNode base_tn = getOrCreateTypeForId(nmonSort, tss.first);
252 for (unsigned i = 0; i < nsorts; i++)
253 {
254 if (tss.second[i] != nmonSort)
255 {
256 TypeNode new_tn = getOrCreateTypeForId(tss.second[i], tss.first);
257 // make injection to nmonSort
258 Node a1 = mkInjection(new_tn, base_tn);
259 injections.push_back(a1);
260 if (d_non_monotonic_sorts.find(tss.second[i])
261 != d_non_monotonic_sorts.end())
262 {
263 // also must make injection from nmonSort to this
264 Node a2 = mkInjection(base_tn, new_tn);
265 injections.push_back(a2);
266 }
267 }
268 }
269 if (Trace.isOn("sort-inference-rewrite"))
270 {
271 Trace("sort-inference-rewrite")
272 << "Add the following injections for " << tss.first
273 << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
274 for (const Node& i : injections)
275 {
276 Trace("sort-inference-rewrite") << " " << i << std::endl;
277 }
278 }
279 new_asserts.insert(
280 new_asserts.end(), injections.begin(), injections.end());
281 }
282 }
283 Trace("sort-inference-proc") << "...done" << std::endl;
284 // no sub-sort information is stored
285 reset();
286 Trace("sort-inference-debug") << "Finished sort inference" << std::endl;
287 }
288
289 void SortInference::computeMonotonicity(const std::vector<Node>& assertions)
290 {
291 std::map<Node, std::map<int, bool> > visitedmt;
292 Trace("sort-inference-proc")
293 << "Calculating monotonicty for types..." << std::endl;
294 for (const Node& a : assertions)
295 {
296 Trace("sort-inference-debug")
297 << "Process type monotonicity for " << a << std::endl;
298 std::map<Node, Node> var_bound;
299 processMonotonic(a, true, true, var_bound, visitedmt, true);
300 }
301 Trace("sort-inference-proc") << "...done" << std::endl;
302 }
303
304 void SortInference::setEqual( int t1, int t2 ){
305 if( t1!=t2 ){
306 int rt1 = d_type_union_find.getRepresentative( t1 );
307 int rt2 = d_type_union_find.getRepresentative( t2 );
308 if( rt1!=rt2 ){
309 Trace("sort-inference-debug") << "Set equal : ";
310 printSort( "sort-inference-debug", rt1 );
311 Trace("sort-inference-debug") << " ";
312 printSort( "sort-inference-debug", rt2 );
313 Trace("sort-inference-debug") << std::endl;
314 /*
315 d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
316 d_type_eq_class[rt2].clear();
317 Trace("sort-inference-debug") << "EqClass : { ";
318 for( int i=0; i<(int)d_type_eq_class[rt1].size(); i++ ){
319 Trace("sort-inference-debug") << d_type_eq_class[rt1][i] << ", ";
320 }
321 Trace("sort-inference-debug") << "}" << std::endl;
322 */
323 if( rt2>rt1 ){
324 //swap
325 int swap = rt1;
326 rt1 = rt2;
327 rt2 = swap;
328 }
329 std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
330 if( it1!=d_type_types.end() ){
331 if( d_type_types.find( rt2 )==d_type_types.end() ){
332 d_type_types[rt2] = it1->second;
333 d_type_types.erase( rt1 );
334 }else{
335 Trace("sort-inference-debug") << "...fail : associated with types " << d_type_types[rt1] << " and " << d_type_types[rt2] << std::endl;
336 return;
337 }
338 }
339 d_type_union_find.d_eqc[rt1] = rt2;
340 }
341 }
342 }
343
344 int SortInference::getIdForType( TypeNode tn ){
345 //register the return type
346 std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
347 if( it==d_id_for_types.end() ){
348 int sc = d_sortCount;
349 d_type_types[d_sortCount] = tn;
350 d_id_for_types[tn] = d_sortCount;
351 d_sortCount++;
352 return sc;
353 }else{
354 return it->second;
355 }
356 }
357
/**
 * Assigns a sort id to term n and records the constraints n imposes.
 *
 * n         : the term to process.
 * var_bound : maps each currently bound variable to the quantified formula
 *             binding it. Entries are added before and removed after the
 *             children of a FORALL/EXISTS node are processed.
 * visited   : cache from processed terms to their sort id. The children of a
 *             quantified formula use a fresh local cache (new_visited).
 * Returns the sort id computed for n. The caller may need
 * getRepresentative to obtain its class representative.
 *
 * This function may allocate ids from d_sortCount and populates
 * d_var_types, d_op_return_types, d_op_arg_types and d_equality_types. It
 * merges classes through setEqual. The order of these calls determines
 * which ids are allocated and which representatives survive.
 */
int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ){
  std::map< Node, int >::iterator itv = visited.find( n );
  if( itv!=visited.end() ){
    return itv->second;
  }else{
    //add to variable bindings
    bool use_new_visited = false;
    std::map< Node, int > new_visited;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      if( d_var_types.find( n )!=d_var_types.end() ){
        // quantified formula already processed (e.g. via setSkolemVar):
        // return the id of the formula's own type, without re-binding
        return getIdForType( n.getType() );
      }else{
        //apply sort inference to quantified variables
        for( size_t i=0; i<n[0].getNumChildren(); i++ ){
          TypeNode nitn = n[0][i].getType();
          if( !nitn.isSort() )
          {
            // If the variable is of an interpreted sort, we assume the
            // sort of the variable will stay the same sort.
            d_var_types[n][n[0][i]] = getIdForType( nitn );
          }
          else
          {
            // If it is of an uninterpreted sort, infer subsorts.
            d_var_types[n][n[0][i]] = d_sortCount;
            d_sortCount++;
          }
          var_bound[ n[0][i] ] = n;
        }
      }
      // the body refers to the variables bound here, so do not share the
      // cache with the enclosing context
      use_new_visited = true;
    }

    //process children
    std::vector< Node > children;
    std::vector< int > child_types;
    for( size_t i=0; i<n.getNumChildren(); i++ ){
      bool processChild = true;
      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
        // child 0 is the variable list; child 1 is the body; child 2 (if
        // present) is only processed when user patterns are not ignored
        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
      }
      if( processChild ){
        children.push_back( n[i] );
        child_types.push_back( process( n[i], var_bound, use_new_visited ? new_visited : visited ) );
      }
    }

    //remove from variable bindings
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //erase from variable bound
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        var_bound.erase( n[0][i] );
      }
    }
    Trace("sort-inference-debug") << "...Process " << n << std::endl;

    int retType;
    if( n.getKind()==kind::EQUAL && !n[0].getType().isBoolean() ){
      // non-Boolean equality: both sides must end up in the same sort
      Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
      //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
      if( n[0].getType()!=n[1].getType() ){
        //for now, assume the original types
        for( unsigned i=0; i<2; i++ ){
          int ct = getIdForType( n[i].getType() );
          setEqual( child_types[i], ct );
        }
      }else{
        //we only require that the left and right hand side must be equal
        setEqual( child_types[0], child_types[1] );
      }
      // remembered so that simplifyNode can give the left side its sort
      d_equality_types[n] = child_types[0];
      retType = getIdForType( n.getType() );
    }else if( n.getKind()==kind::APPLY_UF ){
      Node op = n.getOperator();
      TypeNode tn_op = op.getType();
      if( d_op_return_types.find( op )==d_op_return_types.end() ){
        // first occurrence of op: allocate its signature
        if( n.getType().isBoolean() ){
          //use booleans
          d_op_return_types[op] = getIdForType( n.getType() );
        }else{
          //assign arbitrary sort for return type
          d_op_return_types[op] = d_sortCount;
          d_sortCount++;
        }
        // assign arbitrary sort for argument types
        for( size_t i=0; i<n.getNumChildren(); i++ ){
          d_op_arg_types[op].push_back(d_sortCount);
          d_sortCount++;
        }
      }
      for( size_t i=0; i<n.getNumChildren(); i++ ){
        //the argument of the operator must match the return type of the subterm
        if( n[i].getType()!=tn_op[i] ){
          //if type mismatch, assume original types
          Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
          Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
          int ct1 = getIdForType( n[i].getType() );
          setEqual( child_types[i], ct1 );
          int ct2 = getIdForType( tn_op[i] );
          setEqual( d_op_arg_types[op][i], ct2 );
        }else{
          setEqual( child_types[i], d_op_arg_types[op][i] );
        }
      }
      //return type is the return type
      retType = d_op_return_types[op];
    }else{
      std::map< Node, Node >::iterator it = var_bound.find( n );
      if( it!=var_bound.end() ){
        Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
        //the return type was specified while binding
        retType = d_var_types[it->second][n];
      }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
        Trace("sort-inference-debug") << n << " is a variable." << std::endl;
        if( d_op_return_types.find( n )==d_op_return_types.end() ){
          //assign arbitrary sort
          d_op_return_types[n] = d_sortCount;
          d_sortCount++;
        }
        retType = d_op_return_types[n];
      }else if( n.isConst() ){
        Trace("sort-inference-debug") << n << " is a constant." << std::endl;
        //can be any type we want: each occurrence gets a fresh id
        retType = d_sortCount;
        d_sortCount++;
      }else{
        Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
        //it is an interpreted term
        for( size_t i=0; i<children.size(); i++ ){
          Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
          //must enforce the actual type of the operator on the children
          int ct = getIdForType( children[i].getType() );
          setEqual( child_types[i], ct );
        }
        //return type must be the actual return type
        retType = getIdForType( n.getType() );
      }
    }
    Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
    printSort("sort-inference-debug", retType );
    Trace("sort-inference-debug") << std::endl;
    visited[n] = retType;
    return retType;
  }
}
505
506 void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode ) {
507 int pindex = hasPol ? ( pol ? 1 : -1 ) : 0;
508 if( visited[n].find( pindex )==visited[n].end() ){
509 visited[n][pindex] = true;
510 Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
511 if( n.getKind()==kind::FORALL ){
512 //only consider variables universally if it is possible this quantified formula is asserted positively
513 if( !hasPol || pol ){
514 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
515 var_bound[n[0][i]] = n;
516 }
517 }
518 processMonotonic( n[1], pol, hasPol, var_bound, visited, typeMode );
519 if( !hasPol || pol ){
520 for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
521 var_bound.erase( n[0][i] );
522 }
523 }
524 return;
525 }else if( n.getKind()==kind::EQUAL ){
526 if( !hasPol || pol ){
527 for( unsigned i=0; i<2; i++ ){
528 if( var_bound.find( n[i] )!=var_bound.end() ){
529 if( !typeMode ){
530 int sid = getSortId( var_bound[n[i]], n[i] );
531 d_non_monotonic_sorts[sid] = true;
532 }else{
533 d_non_monotonic_sorts_orig[n[i].getType()] = true;
534 }
535 break;
536 }
537 }
538 }
539 }
540 for( unsigned i=0; i<n.getNumChildren(); i++ ){
541 bool npol;
542 bool nhasPol;
543 theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
544 processMonotonic( n[i], npol, nhasPol, var_bound, visited, typeMode );
545 }
546 }
547 }
548
549
550 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
551 int rt = d_type_union_find.getRepresentative( t );
552 if( d_type_types.find( rt )!=d_type_types.end() ){
553 return d_type_types[rt];
554 }else{
555 TypeNode retType;
556 // See if we can assign pref. This is an optimization for reusing an
557 // uninterpreted sort as the first subsort, so that fewer symbols needed
558 // to be rewritten in the sort-inferred signature. Notice we only assign
559 // pref here if it is an uninterpreted sort.
560 if (!pref.isNull() && d_id_for_types.find(pref) == d_id_for_types.end()
561 && pref.isSort())
562 {
563 retType = pref;
564 }else{
565 //must create new type
566 std::stringstream ss;
567 ss << "it_" << t << "_" << pref;
568 retType = NodeManager::currentNM()->mkSort( ss.str() );
569 }
570 Trace("sort-inference") << "-> Make type " << retType << " to correspond to ";
571 printSort("sort-inference", t );
572 Trace("sort-inference") << std::endl;
573 d_id_for_types[ retType ] = rt;
574 d_type_types[ rt ] = retType;
575 return retType;
576 }
577 }
578
579 TypeNode SortInference::getTypeForId( int t ){
580 int rt = d_type_union_find.getRepresentative( t );
581 if( d_type_types.find( rt )!=d_type_types.end() ){
582 return d_type_types[rt];
583 }else{
584 return TypeNode::null();
585 }
586 }
587
588 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
589 // if no sort was inferred for this node, return original
590 if( tn.isNull() || tn.isComparableTo( old.getType() ) ){
591 return old;
592 }else if( old.isConst() ){
593 //must make constant of type tn
594 if( d_const_map[tn].find( old )==d_const_map[tn].end() ){
595 std::stringstream ss;
596 ss << "ic_" << tn << "_" << old;
597 d_const_map[tn][ old ] = NodeManager::currentNM()->mkSkolem( ss.str(), tn, "constant created during sort inference" ); //use mkConst???
598 }
599 return d_const_map[tn][ old ];
600 }else if( old.getKind()==kind::BOUND_VARIABLE ){
601 std::stringstream ss;
602 ss << "b_" << old;
603 return NodeManager::currentNM()->mkBoundVar( ss.str(), tn );
604 }else{
605 std::stringstream ss;
606 ss << "i_" << old;
607 return NodeManager::currentNM()->mkSkolem( ss.str(), tn, "created during sort inference" );
608 }
609 }
610
/**
 * Rebuilds n over the sort-inferred signature.
 *
 * n               : term to rebuild.
 * var_bound       : maps original bound variables in scope to their
 *                   replacements. Filled and emptied around quantifier
 *                   children.
 * tnn             : type context imposed by the parent, or null if none. It
 *                   is only used to pick the sort of a constant.
 * model_replace_f : receives op -> new op for every function symbol that had
 *                   to be replaced.
 * visited         : cache keyed by (term, type context). The children of a
 *                   quantified formula use a fresh local cache.
 * Returns the rebuilt term. Results are cached in visited.
 */
Node SortInference::simplifyNode(
    Node n,
    std::map<Node, Node>& var_bound,
    TypeNode tnn,
    std::map<Node, Node>& model_replace_f,
    std::map<Node, std::map<TypeNode, Node> >& visited)
{
  std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
  if( itv!=visited[n].end() ){
    return itv->second;
  }else{
    Trace("sort-inference-debug2") << "Simplify " << n << ", type context=" << tnn << std::endl;
    std::vector< Node > children;
    std::map< Node, std::map< TypeNode, Node > > new_visited;
    bool use_new_visited = false;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //recreate based on types of variables
      std::vector< Node > new_children;
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
        Node v = getNewSymbol( n[0][i], tn );
        Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
        new_children.push_back( v );
        var_bound[ n[0][i] ] = v;
      }
      // the new variable list becomes the first child of the result
      children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
      use_new_visited = true;
    }

    //process children
    if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
      // operator slot; for APPLY_UF it is overwritten below with the
      // (possibly new) symbol
      children.push_back( n.getOperator() );
    }
    Node op;
    if( n.hasOperator() ){
      op = n.getOperator();
    }
    bool childChanged = false;
    // type context passed to children; declared outside the loop so that,
    // for an equality, the context computed for side 0 is reused for side 1
    TypeNode tnnc;
    for( size_t i=0; i<n.getNumChildren(); i++ ){
      bool processChild = true;
      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
      }
      if( processChild ){
        if( n.getKind()==kind::APPLY_UF ){
          Assert( d_op_arg_types.find( op )!=d_op_arg_types.end() );
          tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
          Assert( !tnnc.isNull() );
        }else if( n.getKind()==kind::EQUAL && i==0 ){
          Assert( d_equality_types.find( n )!=d_equality_types.end() );
          tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
          Assert( !tnnc.isNull() );
        }
        Node nc = simplifyNode(n[i],
                               var_bound,
                               tnnc,
                               model_replace_f,
                               use_new_visited ? new_visited : visited);
        Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
        children.push_back( nc );
        childChanged = childChanged || nc!=n[i];
      }
    }

    //remove from variable bindings
    Node ret;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //erase from variable bound
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
        var_bound.erase( n[0][i] );
      }
      ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
    }else if( n.getKind()==kind::EQUAL ){
      TypeNode tn1 = children[0].getType();
      TypeNode tn2 = children[1].getType();
      // sanity check: both sides must have been given related types
      if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
        Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
        Trace("sort-inference-warn") << " Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
        Assert( false );
      }
      ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
    }else if( n.getKind()==kind::APPLY_UF ){
      if( d_symbol_map.find( op )==d_symbol_map.end() ){
        //make the new operator if necessary
        bool opChanged = false;
        std::vector< TypeNode > argTypes;
        for( size_t i=0; i<n.getNumChildren(); i++ ){
          TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
          argTypes.push_back( tn );
          if( tn!=n[i].getType() ){
            opChanged = true;
          }
        }
        TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
        if( retType!=n.getType() ){
          opChanged = true;
        }
        if( opChanged ){
          // signature changed: introduce "io_<op>" and report it to the caller
          std::stringstream ss;
          ss << "io_" << op;
          TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
          d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
          Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
          model_replace_f[op] = d_symbol_map[op];
        }else{
          d_symbol_map[op] = op;
        }
      }
      children[0] = d_symbol_map[op];
      // make sure all children have been given proper types
      // (children[i+1] because children[0] is the operator)
      for (size_t i = 0, size = n.getNumChildren(); i < size; i++)
      {
        TypeNode tn = children[i+1].getType();
        TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
        if (!tn.isSubtypeOf(tna))
        {
          Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
          Assert( false );
        }
      }
      ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
    }else{
      std::map< Node, Node >::iterator it = var_bound.find( n );
      if( it!=var_bound.end() ){
        // bound variable: use the replacement created at its binder
        ret = it->second;
      }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
        // free constant symbol: replaced once and cached in d_symbol_map
        if( d_symbol_map.find( n )==d_symbol_map.end() ){
          TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
          d_symbol_map[n] = getNewSymbol( n, tn );
        }
        ret = d_symbol_map[n];
      }else if( n.isConst() ){
        //type is determined by context
        ret = getNewSymbol( n, tnn );
      }else if( childChanged ){
        ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
      }else{
        ret = n;
      }
    }
    visited[n][tnn] = ret;
    return ret;
  }
}
757
758 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
759 std::vector< TypeNode > tns;
760 tns.push_back( tn1 );
761 TypeNode typ = NodeManager::currentNM()->mkFunctionType( tns, tn2 );
762 Node f = NodeManager::currentNM()->mkSkolem( "inj", typ, "injection for monotonicity constraint" );
763 Trace("sort-inference") << "-> Make injection " << f << " from " << tn1 << " to " << tn2 << std::endl;
764 Node v1 = NodeManager::currentNM()->mkBoundVar( "?x", tn1 );
765 Node v2 = NodeManager::currentNM()->mkBoundVar( "?y", tn1 );
766 Node ret = NodeManager::currentNM()->mkNode( kind::FORALL,
767 NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, v1, v2 ),
768 NodeManager::currentNM()->mkNode( kind::OR,
769 NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v1 ).eqNode( NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v2 ) ).negate(),
770 v1.eqNode( v2 ) ) );
771 ret = theory::Rewriter::rewrite( ret );
772 return ret;
773 }
774
775 int SortInference::getSortId( Node n ) {
776 Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
777 if( d_op_return_types.find( op )!=d_op_return_types.end() ){
778 return d_type_union_find.getRepresentative( d_op_return_types[op] );
779 }else{
780 return 0;
781 }
782 }
783
784 int SortInference::getSortId( Node f, Node v ) {
785 if( d_var_types.find( f )!=d_var_types.end() ){
786 return d_type_union_find.getRepresentative( d_var_types[f][v] );
787 }else{
788 return 0;
789 }
790 }
791
792 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
793 Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
794 if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
795 //calculate the sort for variables if not done so already
796 std::map< Node, Node > var_bound;
797 std::map< Node, int > visited;
798 process( f, var_bound, visited );
799 }
800 d_op_return_types[sk] = getSortId( f, v );
801 Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
802 }
803
804 bool SortInference::isWellSortedFormula( Node n ) {
805 if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
806 for( unsigned i=0; i<n.getNumChildren(); i++ ){
807 if( !isWellSortedFormula( n[i] ) ){
808 return false;
809 }
810 }
811 return true;
812 }else{
813 return isWellSorted( n );
814 }
815 }
816
817 bool SortInference::isWellSorted( Node n ) {
818 if( getSortId( n )==0 ){
819 return false;
820 }else{
821 if( n.getKind()==kind::APPLY_UF ){
822 for( unsigned i=0; i<n.getNumChildren(); i++ ){
823 int s1 = getSortId( n[i] );
824 int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
825 if( s1!=s2 ){
826 return false;
827 }
828 if( !isWellSorted( n[i] ) ){
829 return false;
830 }
831 }
832 }
833 return true;
834 }
835 }
836
837 void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
838 if( n.getKind()==kind::APPLY_UF ){
839 for( unsigned i=0; i<n.getNumChildren(); i++ ){
840 getSortConstraints( n[i], uf );
841 uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
842 }
843 }
844 }
845
846 bool SortInference::isMonotonic( TypeNode tn ) {
847 Assert( tn.isSort() );
848 return d_non_monotonic_sorts_orig.find( tn )==d_non_monotonic_sorts_orig.end();
849 }
850
851 }/* CVC4 namespace */