Fix sort inference for top-level Boolean variables (#4012)
[cvc5.git] / src / theory / sort_inference.cpp
1 /********************* */
2 /*! \file sort_inference.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Paul Meng, Kshitij Bansal
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Sort inference module
13 **
14 ** This class implements sort inference, based on a simple algorithm:
15 ** First, we assume all functions and predicates have distinct uninterpreted types.
16 ** One pass is made through the input assertions, while a union-find data structure
17 ** maintains necessary information regarding constraints on these types.
18 **/
19
#include "theory/sort_inference.h"

#include <algorithm>
#include <map>
#include <sstream>
#include <vector>

#include "options/quantifiers_options.h"
#include "options/smt_options.h"
#include "options/uf_options.h"
#include "proof/proof_manager.h"
#include "theory/quantifiers/quant_util.h"
#include "theory/rewriter.h"
30
31 using namespace CVC4;
32 using namespace std;
33
34 namespace CVC4 {
35
36 void SortInference::UnionFind::print(const char * c){
37 for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
38 Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
39 }
40 for( unsigned i=0; i<d_deq.size(); i++ ){
41 Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
42 }
43 Trace(c) << std::endl;
44 }
45 void SortInference::UnionFind::set( UnionFind& c ) {
46 clear();
47 for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
48 d_eqc[ it->first ] = it->second;
49 }
50 d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
51 }
52 int SortInference::UnionFind::getRepresentative( int t ){
53 std::map< int, int >::iterator it = d_eqc.find( t );
54 if( it==d_eqc.end() || it->second==t ){
55 return t;
56 }else{
57 int rt = getRepresentative( it->second );
58 d_eqc[t] = rt;
59 return rt;
60 }
61 }
62 void SortInference::UnionFind::setEqual( int t1, int t2 ){
63 if( t1!=t2 ){
64 int rt1 = getRepresentative( t1 );
65 int rt2 = getRepresentative( t2 );
66 if( rt1>rt2 ){
67 d_eqc[rt1] = rt2;
68 }else{
69 d_eqc[rt2] = rt1;
70 }
71 }
72 }
73 bool SortInference::UnionFind::isValid() {
74 for( unsigned i=0; i<d_deq.size(); i++ ){
75 if( areEqual( d_deq[i].first, d_deq[i].second ) ){
76 return false;
77 }
78 }
79 return true;
80 }
81
82
83 void SortInference::recordSubsort( TypeNode tn, int s ){
84 s = d_type_union_find.getRepresentative( s );
85 if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){
86 d_sub_sorts.push_back( s );
87 d_type_sub_sorts[tn].push_back( s );
88 }
89 }
90
91 void SortInference::printSort( const char* c, int t ){
92 int rt = d_type_union_find.getRepresentative( t );
93 if( d_type_types.find( rt )!=d_type_types.end() ){
94 Trace(c) << d_type_types[rt];
95 }else{
96 Trace(c) << "s_" << rt;
97 }
98 }
99
100 void SortInference::reset() {
101 d_sub_sorts.clear();
102 d_non_monotonic_sorts.clear();
103 d_type_sub_sorts.clear();
104 //reset info
105 d_sortCount = 1;
106 d_type_union_find.clear();
107 d_type_types.clear();
108 d_id_for_types.clear();
109 d_op_return_types.clear();
110 d_op_arg_types.clear();
111 d_var_types.clear();
112 //for rewriting
113 d_symbol_map.clear();
114 d_const_map.clear();
115 }
116
/**
 * Runs sort inference over assertions.
 *
 * 1. process() assigns a sort id to every subterm. The id of each assertion
 *    is then set equal to the id of the Boolean type, so that a top-level
 *    Boolean variable (which process() gives a fresh id) is tied to Bool.
 * 2. recordSubsort() registers the inferred sub-sort of every argument and
 *    return position of symbols in d_op_return_types, and of every
 *    quantified variable in d_var_types.
 * 3. processMonotonic() marks which sub-sorts are non-monotonic.
 * The final loop only emits traces.
 */
void SortInference::initialize(const std::vector<Node>& assertions)
{
  Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
  // process all assertions
  std::map<Node, int> visited;
  NodeManager * nm = NodeManager::currentNM();
  int btId = getIdForType( nm->booleanType() );
  for (const Node& a : assertions)
  {
    Trace("sort-inference-debug") << "Process " << a << std::endl;
    std::map<Node, Node> var_bound;
    int pid = process(a, var_bound, visited);
    // the type of the topmost term must be Boolean
    setEqual( pid, btId );
  }
  Trace("sort-inference-proc") << "...done" << std::endl;
  // record the sub-sorts of every symbol's arguments and return value
  for (const std::pair<const Node, int>& rt : d_op_return_types)
  {
    Trace("sort-inference") << rt.first << " : ";
    TypeNode retTn = rt.first.getType();
    // NOTE: operator[] inserts an empty argument list for symbols that have
    // none recorded (e.g. variables); those skip this branch.
    if (!d_op_arg_types[rt.first].empty())
    {
      Trace("sort-inference") << "( ";
      for (size_t i = 0; i < d_op_arg_types[rt.first].size(); i++)
      {
        // retTn is still the symbol's full type here; child i is the type of
        // argument i
        recordSubsort(retTn[i], d_op_arg_types[rt.first][i]);
        printSort("sort-inference", d_op_arg_types[rt.first][i]);
        Trace("sort-inference") << " ";
      }
      Trace("sort-inference") << ") -> ";
      // the last child of the symbol's type is its return type
      retTn = retTn[(int)retTn.getNumChildren() - 1];
    }
    recordSubsort(retTn, rt.second);
    printSort("sort-inference", rt.second);
    Trace("sort-inference") << std::endl;
  }
  // record the sub-sorts of quantified variables
  for (std::pair<const Node, std::map<Node, int> >& vt : d_var_types)
  {
    Trace("sort-inference")
        << "Quantified formula : " << vt.first << " : " << std::endl;
    for (const Node& v : vt.first[0])
    {
      recordSubsort(v.getType(), vt.second[v]);
      printSort("sort-inference", vt.second[v]);
      Trace("sort-inference") << std::endl;
    }
    Trace("sort-inference") << std::endl;
  }

  // determine monotonicity of sorts
  Trace("sort-inference-proc")
      << "Calculating monotonicty for subsorts..." << std::endl;
  std::map<Node, std::map<int, bool> > visitedm;
  for (const Node& a : assertions)
  {
    Trace("sort-inference-debug")
        << "Process monotonicity for " << a << std::endl;
    std::map<Node, Node> var_bound;
    processMonotonic(a, true, true, var_bound, visitedm);
  }
  Trace("sort-inference-proc") << "...done" << std::endl;

  // trace-only summary of the recorded sub-sorts
  Trace("sort-inference") << "We have " << d_sub_sorts.size()
                          << " sub-sorts : " << std::endl;
  for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++)
  {
    printSort("sort-inference", d_sub_sorts[i]);
    if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end())
    {
      Trace("sort-inference") << " is interpreted." << std::endl;
    }
    else if (d_non_monotonic_sorts.find(d_sub_sorts[i])
             == d_non_monotonic_sorts.end())
    {
      Trace("sort-inference") << " is monotonic." << std::endl;
    }
    else
    {
      Trace("sort-inference") << " is not monotonic." << std::endl;
    }
  }
}
199
200 Node SortInference::simplify(Node n,
201 std::map<Node, Node>& model_replace_f,
202 std::map<Node, std::map<TypeNode, Node> >& visited)
203 {
204 Trace("sort-inference-debug") << "Simplify " << n << std::endl;
205 std::map<Node, Node> var_bound;
206 TypeNode tnn;
207 Node ret = simplifyNode(n, var_bound, tnn, model_replace_f, visited);
208 ret = theory::Rewriter::rewrite(ret);
209 return ret;
210 }
211
/**
 * Appends to new_asserts the side conditions required by the rewritten
 * assertions, then calls reset().
 *
 * 1. For every type with more than one constant introduced by getNewSymbol
 *    (d_const_map), a DISTINCT lemma over those constants.
 * 2. For every original type with at least one non-monotonic sub-sort, the
 *    first such sub-sort found is taken as the base. An injection axiom
 *    (mkInjection) is added from every other sub-sort into the base, plus one
 *    from the base back into that sub-sort when it is non-monotonic too.
 */
void SortInference::getNewAssertions(std::vector<Node>& new_asserts)
{
  NodeManager* nm = NodeManager::currentNM();
  // now, ensure constants are distinct
  for (const std::pair<const TypeNode, std::map<Node, Node> >& cm : d_const_map)
  {
    std::vector<Node> consts;
    for (const std::pair<const Node, Node>& c : cm.second)
    {
      Assert(c.first.isConst());
      consts.push_back(c.second);
    }
    // add lemma enforcing introduced constants to be distinct
    if (consts.size() > 1)
    {
      Node distinct_const = nm->mkNode(kind::DISTINCT, consts);
      Trace("sort-inference-rewrite")
          << "Add the constant distinctness lemma: " << std::endl;
      Trace("sort-inference-rewrite") << " " << distinct_const << std::endl;
      new_asserts.push_back(distinct_const);
    }
  }

  // enforce constraints based on monotonicity
  Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;

  for (const std::pair<const TypeNode, std::vector<int> >& tss :
       d_type_sub_sorts)
  {
    // find the first non-monotonic sub-sort of this original type, if any
    int nmonSort = -1;
    unsigned nsorts = tss.second.size();
    for (unsigned i = 0; i < nsorts; i++)
    {
      if (d_non_monotonic_sorts.find(tss.second[i])
          != d_non_monotonic_sorts.end())
      {
        nmonSort = tss.second[i];
        break;
      }
    }
    if (nmonSort != -1)
    {
      std::vector<Node> injections;
      TypeNode base_tn = getOrCreateTypeForId(nmonSort, tss.first);
      for (unsigned i = 0; i < nsorts; i++)
      {
        if (tss.second[i] != nmonSort)
        {
          TypeNode new_tn = getOrCreateTypeForId(tss.second[i], tss.first);
          // make injection to nmonSort
          Node a1 = mkInjection(new_tn, base_tn);
          injections.push_back(a1);
          if (d_non_monotonic_sorts.find(tss.second[i])
              != d_non_monotonic_sorts.end())
          {
            // also must make injection from nmonSort to this
            Node a2 = mkInjection(base_tn, new_tn);
            injections.push_back(a2);
          }
        }
      }
      if (Trace.isOn("sort-inference-rewrite"))
      {
        Trace("sort-inference-rewrite")
            << "Add the following injections for " << tss.first
            << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
        for (const Node& i : injections)
        {
          Trace("sort-inference-rewrite") << " " << i << std::endl;
        }
      }
      new_asserts.insert(
          new_asserts.end(), injections.begin(), injections.end());
    }
  }
  Trace("sort-inference-proc") << "...done" << std::endl;
  // no sub-sort information is stored
  reset();
  Trace("sort-inference-debug") << "Finished sort inference" << std::endl;
}
292
293 void SortInference::computeMonotonicity(const std::vector<Node>& assertions)
294 {
295 std::map<Node, std::map<int, bool> > visitedmt;
296 Trace("sort-inference-proc")
297 << "Calculating monotonicty for types..." << std::endl;
298 for (const Node& a : assertions)
299 {
300 Trace("sort-inference-debug")
301 << "Process type monotonicity for " << a << std::endl;
302 std::map<Node, Node> var_bound;
303 processMonotonic(a, true, true, var_bound, visitedmt, true);
304 }
305 Trace("sort-inference-proc") << "...done" << std::endl;
306 }
307
308 void SortInference::setEqual( int t1, int t2 ){
309 if( t1!=t2 ){
310 int rt1 = d_type_union_find.getRepresentative( t1 );
311 int rt2 = d_type_union_find.getRepresentative( t2 );
312 if( rt1!=rt2 ){
313 Trace("sort-inference-debug") << "Set equal : ";
314 printSort( "sort-inference-debug", rt1 );
315 Trace("sort-inference-debug") << " ";
316 printSort( "sort-inference-debug", rt2 );
317 Trace("sort-inference-debug") << std::endl;
318 /*
319 d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
320 d_type_eq_class[rt2].clear();
321 Trace("sort-inference-debug") << "EqClass : { ";
322 for( int i=0; i<(int)d_type_eq_class[rt1].size(); i++ ){
323 Trace("sort-inference-debug") << d_type_eq_class[rt1][i] << ", ";
324 }
325 Trace("sort-inference-debug") << "}" << std::endl;
326 */
327 if( rt2>rt1 ){
328 //swap
329 int swap = rt1;
330 rt1 = rt2;
331 rt2 = swap;
332 }
333 std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
334 if( it1!=d_type_types.end() ){
335 if( d_type_types.find( rt2 )==d_type_types.end() ){
336 d_type_types[rt2] = it1->second;
337 d_type_types.erase( rt1 );
338 }else{
339 Trace("sort-inference-debug") << "...fail : associated with types " << d_type_types[rt1] << " and " << d_type_types[rt2] << std::endl;
340 return;
341 }
342 }
343 d_type_union_find.d_eqc[rt1] = rt2;
344 }
345 }
346 }
347
348 int SortInference::getIdForType( TypeNode tn ){
349 //register the return type
350 std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
351 if( it==d_id_for_types.end() ){
352 int sc = d_sortCount;
353 d_type_types[d_sortCount] = tn;
354 d_id_for_types[tn] = d_sortCount;
355 d_sortCount++;
356 return sc;
357 }else{
358 return it->second;
359 }
360 }
361
/**
 * Assigns a sort id to n (recursing into its children) and adds the
 * equality constraints between sort ids implied by n's structure.
 *
 * @param n the term to process
 * @param var_bound maps each bound variable currently in scope to its
 *   binding quantified formula; entries are added and removed here for
 *   FORALL/EXISTS
 * @param visited cache from already-processed terms to their sort id; the
 *   children of a quantified formula are processed with a fresh cache
 * @return the sort id of n
 *
 * Fresh ids are allocated for: quantified variables of uninterpreted sort,
 * non-Boolean UF return values, every UF argument position, free variables
 * and constants. Other positions are tied to the id of their actual type
 * through getIdForType().
 */
int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ){
  std::map< Node, int >::iterator itv = visited.find( n );
  if( itv!=visited.end() ){
    return itv->second;
  }else{
    //add to variable bindings
    bool use_new_visited = false;
    std::map< Node, int > new_visited;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      if( d_var_types.find( n )!=d_var_types.end() ){
        // quantified formula already processed: return the id of its type
        return getIdForType( n.getType() );
      }else{
        //apply sort inference to quantified variables
        for( size_t i=0; i<n[0].getNumChildren(); i++ ){
          TypeNode nitn = n[0][i].getType();
          if( !nitn.isSort() )
          {
            // If the variable is of an interpreted sort, we assume the
            // sort of the variable will stay the same sort.
            d_var_types[n][n[0][i]] = getIdForType( nitn );
          }
          else
          {
            // If it is of an uninterpreted sort, infer subsorts.
            d_var_types[n][n[0][i]] = d_sortCount;
            d_sortCount++;
          }
          var_bound[ n[0][i] ] = n;
        }
      }
      use_new_visited = true;
    }

    //process children
    std::vector< Node > children;
    std::vector< int > child_types;
    for( size_t i=0; i<n.getNumChildren(); i++ ){
      bool processChild = true;
      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
        // child 0 is the bound variable list; child 2 (patterns) is only
        // processed when user patterns are not ignored
        processChild =
            options::userPatternsQuant() == options::UserPatMode::IGNORE
                ? i == 1
                : i >= 1;
      }
      if( processChild ){
        children.push_back( n[i] );
        child_types.push_back( process( n[i], var_bound, use_new_visited ? new_visited : visited ) );
      }
    }

    //remove from variable bindings
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //erase from variable bound
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        var_bound.erase( n[0][i] );
      }
    }
    Trace("sort-inference-debug") << "...Process " << n << std::endl;

    int retType;
    if( n.getKind()==kind::EQUAL && !n[0].getType().isBoolean() ){
      Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
      //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
      if( n[0].getType()!=n[1].getType() ){
        //for now, assume the original types
        for( unsigned i=0; i<2; i++ ){
          int ct = getIdForType( n[i].getType() );
          setEqual( child_types[i], ct );
        }
      }else{
        //we only require that the left and right hand side must be equal
        setEqual( child_types[0], child_types[1] );
      }
      // remembered so simplifyNode can give both sides the same type
      d_equality_types[n] = child_types[0];
      retType = getIdForType( n.getType() );
    }else if( n.getKind()==kind::APPLY_UF ){
      Node op = n.getOperator();
      TypeNode tn_op = op.getType();
      if( d_op_return_types.find( op )==d_op_return_types.end() ){
        // first occurrence of op: allocate its signature
        if( n.getType().isBoolean() ){
          //use booleans
          d_op_return_types[op] = getIdForType( n.getType() );
        }else{
          //assign arbitrary sort for return type
          d_op_return_types[op] = d_sortCount;
          d_sortCount++;
        }
        // assign arbitrary sort for argument types
        for( size_t i=0; i<n.getNumChildren(); i++ ){
          d_op_arg_types[op].push_back(d_sortCount);
          d_sortCount++;
        }
      }
      for( size_t i=0; i<n.getNumChildren(); i++ ){
        //the argument of the operator must match the return type of the subterm
        if( n[i].getType()!=tn_op[i] ){
          //if type mismatch, assume original types
          Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
          Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
          int ct1 = getIdForType( n[i].getType() );
          setEqual( child_types[i], ct1 );
          int ct2 = getIdForType( tn_op[i] );
          setEqual( d_op_arg_types[op][i], ct2 );
        }else{
          setEqual( child_types[i], d_op_arg_types[op][i] );
        }
      }
      //return type is the return type
      retType = d_op_return_types[op];
    }else{
      std::map< Node, Node >::iterator it = var_bound.find( n );
      if( it!=var_bound.end() ){
        Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
        //the return type was specified while binding
        retType = d_var_types[it->second][n];
      }else if( n.isVar() ){
        Trace("sort-inference-debug") << n << " is a variable." << std::endl;
        // Free variables (Boolean ones included) get a fresh id here. A
        // top-level assertion's id is tied to Bool by initialize(), and an
        // interpreted parent forces its children to their actual types.
        if( d_op_return_types.find( n )==d_op_return_types.end() ){
          //assign arbitrary sort
          d_op_return_types[n] = d_sortCount;
          d_sortCount++;
        }
        retType = d_op_return_types[n];
      }else if( n.isConst() ){
        Trace("sort-inference-debug") << n << " is a constant." << std::endl;
        //can be any type we want
        retType = d_sortCount;
        d_sortCount++;
      }else{
        Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
        //it is an interpreted term
        for( size_t i=0; i<children.size(); i++ ){
          Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
          //must enforce the actual type of the operator on the children
          int ct = getIdForType( children[i].getType() );
          setEqual( child_types[i], ct );
        }
        //return type must be the actual return type
        retType = getIdForType( n.getType() );
      }
    }
    Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
    printSort("sort-inference-debug", retType );
    Trace("sort-inference-debug") << std::endl;
    visited[n] = retType;
    return retType;
  }
}
512
/**
 * Flags sorts that are not monotonic in n.
 *
 * A sort is flagged when a variable bound by a FORALL that may occur
 * positively (!hasPol || pol at the quantifier) appears directly as a side
 * of an equality that may occur positively. Only the first such side of the
 * equality is flagged. With typeMode false, the flag is stored in
 * d_non_monotonic_sorts under the variable's inferred sort id. With typeMode
 * true, it is stored in d_non_monotonic_sorts_orig under the variable's
 * original type.
 *
 * @param pol polarity of n (meaningful only when hasPol is true)
 * @param hasPol whether n has a fixed polarity
 * @param var_bound bound variables currently in scope -> binding FORALL
 * @param visited (term, polarity index) pairs already processed
 */
void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode ) {
  // polarity index used as cache key: 1 positive, -1 negative, 0 no polarity
  int pindex = hasPol ? ( pol ? 1 : -1 ) : 0;
  if( visited[n].find( pindex )==visited[n].end() ){
    visited[n][pindex] = true;
    Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
    if( n.getKind()==kind::FORALL ){
      //only consider variables universally if it is possible this quantified formula is asserted positively
      if( !hasPol || pol ){
        for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
          var_bound[n[0][i]] = n;
        }
      }
      // only the body (child 1) is visited, with the quantifier's polarity
      processMonotonic( n[1], pol, hasPol, var_bound, visited, typeMode );
      if( !hasPol || pol ){
        for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
          var_bound.erase( n[0][i] );
        }
      }
      return;
    }else if( n.getKind()==kind::EQUAL ){
      if( !hasPol || pol ){
        for( unsigned i=0; i<2; i++ ){
          if( var_bound.find( n[i] )!=var_bound.end() ){
            if( !typeMode ){
              int sid = getSortId( var_bound[n[i]], n[i] );
              d_non_monotonic_sorts[sid] = true;
            }else{
              d_non_monotonic_sorts_orig[n[i].getType()] = true;
            }
            break;
          }
        }
      }
    }
    for( unsigned i=0; i<n.getNumChildren(); i++ ){
      bool npol;
      bool nhasPol;
      theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
      processMonotonic( n[i], npol, nhasPol, var_bound, visited, typeMode );
    }
  }
}
555
556
557 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
558 int rt = d_type_union_find.getRepresentative( t );
559 if( d_type_types.find( rt )!=d_type_types.end() ){
560 return d_type_types[rt];
561 }else{
562 TypeNode retType;
563 // See if we can assign pref. This is an optimization for reusing an
564 // uninterpreted sort as the first subsort, so that fewer symbols needed
565 // to be rewritten in the sort-inferred signature. Notice we only assign
566 // pref here if it is an uninterpreted sort.
567 if (!pref.isNull() && d_id_for_types.find(pref) == d_id_for_types.end()
568 && pref.isSort())
569 {
570 retType = pref;
571 }else{
572 //must create new type
573 std::stringstream ss;
574 ss << "it_" << t << "_" << pref;
575 retType = NodeManager::currentNM()->mkSort( ss.str() );
576 }
577 Trace("sort-inference") << "-> Make type " << retType << " to correspond to ";
578 printSort("sort-inference", t );
579 Trace("sort-inference") << std::endl;
580 d_id_for_types[ retType ] = rt;
581 d_type_types[ rt ] = retType;
582 return retType;
583 }
584 }
585
586 TypeNode SortInference::getTypeForId( int t ){
587 int rt = d_type_union_find.getRepresentative( t );
588 if( d_type_types.find( rt )!=d_type_types.end() ){
589 return d_type_types[rt];
590 }else{
591 return TypeNode::null();
592 }
593 }
594
595 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
596 // if no sort was inferred for this node, return original
597 if( tn.isNull() || tn.isComparableTo( old.getType() ) ){
598 return old;
599 }else if( old.isConst() ){
600 //must make constant of type tn
601 if( d_const_map[tn].find( old )==d_const_map[tn].end() ){
602 std::stringstream ss;
603 ss << "ic_" << tn << "_" << old;
604 d_const_map[tn][ old ] = NodeManager::currentNM()->mkSkolem( ss.str(), tn, "constant created during sort inference" ); //use mkConst???
605 }
606 return d_const_map[tn][ old ];
607 }else if( old.getKind()==kind::BOUND_VARIABLE ){
608 std::stringstream ss;
609 ss << "b_" << old;
610 return NodeManager::currentNM()->mkBoundVar( ss.str(), tn );
611 }else{
612 std::stringstream ss;
613 ss << "i_" << old;
614 return NodeManager::currentNM()->mkSkolem( ss.str(), tn, "created during sort inference" );
615 }
616 }
617
/**
 * Rebuilds n over the sort-inferred signature.
 *
 * @param n the term to rebuild
 * @param var_bound maps original bound variables in scope to their
 *   replacements; entries are added and removed here for FORALL/EXISTS
 * @param tnn the type n's context expects (may be null). It is used only to
 *   choose the type of a replacement constant, and it is part of the cache
 *   key.
 * @param model_replace_f receives (original UF -> replacement UF) for every
 *   function symbol whose signature changed
 * @param visited cache keyed by (term, context type); the children of a
 *   quantified formula use a fresh cache
 * @return the rebuilt term
 */
Node SortInference::simplifyNode(
    Node n,
    std::map<Node, Node>& var_bound,
    TypeNode tnn,
    std::map<Node, Node>& model_replace_f,
    std::map<Node, std::map<TypeNode, Node> >& visited)
{
  std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
  if( itv!=visited[n].end() ){
    return itv->second;
  }else{
    Trace("sort-inference-debug2") << "Simplify " << n << ", type context=" << tnn << std::endl;
    std::vector< Node > children;
    std::map< Node, std::map< TypeNode, Node > > new_visited;
    bool use_new_visited = false;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //recreate based on types of variables
      std::vector< Node > new_children;
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
        Node v = getNewSymbol( n[0][i], tn );
        Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
        new_children.push_back( v );
        var_bound[ n[0][i] ] = v;
      }
      children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
      use_new_visited = true;
    }

    //process children
    if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
      children.push_back( n.getOperator() );
    }
    Node op;
    if( n.hasOperator() ){
      op = n.getOperator();
    }
    bool childChanged = false;
    // NOTE: tnnc is declared outside the loop, so a child that does not set it
    // inherits the previous child's value. For a non-Boolean equality, n[1]
    // therefore gets the same context type as n[0]; for kinds other than
    // APPLY_UF / EQUAL it stays null.
    TypeNode tnnc;
    for( size_t i=0; i<n.getNumChildren(); i++ ){
      bool processChild = true;
      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
        processChild =
            options::userPatternsQuant() == options::UserPatMode::IGNORE
                ? i == 1
                : i >= 1;
      }
      if( processChild ){
        if( n.getKind()==kind::APPLY_UF ){
          Assert(d_op_arg_types.find(op) != d_op_arg_types.end());
          tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
          Assert(!tnnc.isNull());
        }
        else if (n.getKind() == kind::EQUAL && !n[0].getType().isBoolean()
                 && i == 0)
        {
          Assert(d_equality_types.find(n) != d_equality_types.end());
          tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
          Assert(!tnnc.isNull());
        }
        Node nc = simplifyNode(n[i],
                               var_bound,
                               tnnc,
                               model_replace_f,
                               use_new_visited ? new_visited : visited);
        Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
        children.push_back( nc );
        childChanged = childChanged || nc!=n[i];
      }
    }

    //remove from variable bindings
    Node ret;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //erase from variable bound
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
        var_bound.erase( n[0][i] );
      }
      ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
    }else if( n.getKind()==kind::EQUAL ){
      // sanity check: both rewritten sides must have comparable types
      TypeNode tn1 = children[0].getType();
      TypeNode tn2 = children[1].getType();
      if( !tn1.isComparableTo( tn2 ) ){
        Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
        Trace("sort-inference-warn") << " Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
        Assert(false);
      }
      ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
    }else if( n.getKind()==kind::APPLY_UF ){
      if( d_symbol_map.find( op )==d_symbol_map.end() ){
        //make the new operator if necessary
        bool opChanged = false;
        std::vector< TypeNode > argTypes;
        for( size_t i=0; i<n.getNumChildren(); i++ ){
          TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
          argTypes.push_back( tn );
          if( tn!=n[i].getType() ){
            opChanged = true;
          }
        }
        TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
        if( retType!=n.getType() ){
          opChanged = true;
        }
        if( opChanged ){
          std::stringstream ss;
          ss << "io_" << op;
          TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
          d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
          Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
          model_replace_f[op] = d_symbol_map[op];
        }else{
          d_symbol_map[op] = op;
        }
      }
      // children[0] holds the operator (pushed above for PARAMETERIZED kinds)
      children[0] = d_symbol_map[op];
      // make sure all children have been given proper types
      for (size_t i = 0, size = n.getNumChildren(); i < size; i++)
      {
        TypeNode tn = children[i+1].getType();
        TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
        if (!tn.isSubtypeOf(tna))
        {
          Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
          Assert(false);
        }
      }
      ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
    }else{
      std::map< Node, Node >::iterator it = var_bound.find( n );
      if( it!=var_bound.end() ){
        // bound variable: use the replacement created at its quantifier
        ret = it->second;
      }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
        // free symbol: replaced once, cached in d_symbol_map
        if( d_symbol_map.find( n )==d_symbol_map.end() ){
          TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
          d_symbol_map[n] = getNewSymbol( n, tn );
        }
        ret = d_symbol_map[n];
      }else if( n.isConst() ){
        //type is determined by context
        ret = getNewSymbol( n, tnn );
      }else if( childChanged ){
        ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
      }else{
        ret = n;
      }
    }
    visited[n][tnn] = ret;
    return ret;
  }
}
770
771 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
772 std::vector< TypeNode > tns;
773 tns.push_back( tn1 );
774 TypeNode typ = NodeManager::currentNM()->mkFunctionType( tns, tn2 );
775 Node f = NodeManager::currentNM()->mkSkolem( "inj", typ, "injection for monotonicity constraint" );
776 Trace("sort-inference") << "-> Make injection " << f << " from " << tn1 << " to " << tn2 << std::endl;
777 Node v1 = NodeManager::currentNM()->mkBoundVar( "?x", tn1 );
778 Node v2 = NodeManager::currentNM()->mkBoundVar( "?y", tn1 );
779 Node ret = NodeManager::currentNM()->mkNode( kind::FORALL,
780 NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, v1, v2 ),
781 NodeManager::currentNM()->mkNode( kind::OR,
782 NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v1 ).eqNode( NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v2 ) ).negate(),
783 v1.eqNode( v2 ) ) );
784 ret = theory::Rewriter::rewrite( ret );
785 return ret;
786 }
787
788 int SortInference::getSortId( Node n ) {
789 Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
790 if( d_op_return_types.find( op )!=d_op_return_types.end() ){
791 return d_type_union_find.getRepresentative( d_op_return_types[op] );
792 }else{
793 return 0;
794 }
795 }
796
797 int SortInference::getSortId( Node f, Node v ) {
798 if( d_var_types.find( f )!=d_var_types.end() ){
799 return d_type_union_find.getRepresentative( d_var_types[f][v] );
800 }else{
801 return 0;
802 }
803 }
804
805 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
806 Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
807 if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
808 //calculate the sort for variables if not done so already
809 std::map< Node, Node > var_bound;
810 std::map< Node, int > visited;
811 process( f, var_bound, visited );
812 }
813 d_op_return_types[sk] = getSortId( f, v );
814 Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
815 }
816
817 bool SortInference::isWellSortedFormula( Node n ) {
818 if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
819 for( unsigned i=0; i<n.getNumChildren(); i++ ){
820 if( !isWellSortedFormula( n[i] ) ){
821 return false;
822 }
823 }
824 return true;
825 }else{
826 return isWellSorted( n );
827 }
828 }
829
830 bool SortInference::isWellSorted( Node n ) {
831 if( getSortId( n )==0 ){
832 return false;
833 }else{
834 if( n.getKind()==kind::APPLY_UF ){
835 for( unsigned i=0; i<n.getNumChildren(); i++ ){
836 int s1 = getSortId( n[i] );
837 int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
838 if( s1!=s2 ){
839 return false;
840 }
841 if( !isWellSorted( n[i] ) ){
842 return false;
843 }
844 }
845 }
846 return true;
847 }
848 }
849
850 void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
851 if( n.getKind()==kind::APPLY_UF ){
852 for( unsigned i=0; i<n.getNumChildren(); i++ ){
853 getSortConstraints( n[i], uf );
854 uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
855 }
856 }
857 }
858
859 bool SortInference::isMonotonic( TypeNode tn ) {
860 Assert(tn.isSort());
861 return d_non_monotonic_sorts_orig.find( tn )==d_non_monotonic_sorts_orig.end();
862 }
863
864 }/* CVC4 namespace */