Sort inference does not handle APPLY_UF when higher-order is enabled (#4138)
[cvc5.git] / src / theory / sort_inference.cpp
1 /********************* */
2 /*! \file sort_inference.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Paul Meng, Kshitij Bansal
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Sort inference module
13 **
14 ** This class implements sort inference, based on a simple algorithm:
15 ** First, we assume all functions and predicates have distinct uninterpreted types.
16 ** One pass is made through the input assertions, while a union-find data structure
17 ** maintains necessary information regarding constraints on these types.
18 **/
19
20 #include "theory/sort_inference.h"
21
22 #include <vector>
23
24 #include "options/quantifiers_options.h"
25 #include "options/smt_options.h"
26 #include "options/uf_options.h"
27 #include "proof/proof_manager.h"
28 #include "theory/rewriter.h"
29 #include "theory/quantifiers/quant_util.h"
30
31 using namespace CVC4;
32 using namespace CVC4::kind;
33 using namespace std;
34
35 namespace CVC4 {
36
37 void SortInference::UnionFind::print(const char * c){
38 for( std::map< int, int >::iterator it = d_eqc.begin(); it != d_eqc.end(); ++it ){
39 Trace(c) << "s_" << it->first << " = s_" << it->second << ", ";
40 }
41 for( unsigned i=0; i<d_deq.size(); i++ ){
42 Trace(c) << "s_" << d_deq[i].first << " != s_" << d_deq[i].second << ", ";
43 }
44 Trace(c) << std::endl;
45 }
46 void SortInference::UnionFind::set( UnionFind& c ) {
47 clear();
48 for( std::map< int, int >::iterator it = c.d_eqc.begin(); it != c.d_eqc.end(); ++it ){
49 d_eqc[ it->first ] = it->second;
50 }
51 d_deq.insert( d_deq.end(), c.d_deq.begin(), c.d_deq.end() );
52 }
53 int SortInference::UnionFind::getRepresentative( int t ){
54 std::map< int, int >::iterator it = d_eqc.find( t );
55 if( it==d_eqc.end() || it->second==t ){
56 return t;
57 }else{
58 int rt = getRepresentative( it->second );
59 d_eqc[t] = rt;
60 return rt;
61 }
62 }
63 void SortInference::UnionFind::setEqual( int t1, int t2 ){
64 if( t1!=t2 ){
65 int rt1 = getRepresentative( t1 );
66 int rt2 = getRepresentative( t2 );
67 if( rt1>rt2 ){
68 d_eqc[rt1] = rt2;
69 }else{
70 d_eqc[rt2] = rt1;
71 }
72 }
73 }
74 bool SortInference::UnionFind::isValid() {
75 for( unsigned i=0; i<d_deq.size(); i++ ){
76 if( areEqual( d_deq[i].first, d_deq[i].second ) ){
77 return false;
78 }
79 }
80 return true;
81 }
82
83
84 void SortInference::recordSubsort( TypeNode tn, int s ){
85 s = d_type_union_find.getRepresentative( s );
86 if( std::find( d_sub_sorts.begin(), d_sub_sorts.end(), s )==d_sub_sorts.end() ){
87 d_sub_sorts.push_back( s );
88 d_type_sub_sorts[tn].push_back( s );
89 }
90 }
91
92 void SortInference::printSort( const char* c, int t ){
93 int rt = d_type_union_find.getRepresentative( t );
94 if( d_type_types.find( rt )!=d_type_types.end() ){
95 Trace(c) << d_type_types[rt];
96 }else{
97 Trace(c) << "s_" << rt;
98 }
99 }
100
101 void SortInference::reset() {
102 d_sub_sorts.clear();
103 d_non_monotonic_sorts.clear();
104 d_type_sub_sorts.clear();
105 //reset info
106 d_sortCount = 1;
107 d_type_union_find.clear();
108 d_type_types.clear();
109 d_id_for_types.clear();
110 d_op_return_types.clear();
111 d_op_arg_types.clear();
112 d_var_types.clear();
113 //for rewriting
114 d_symbol_map.clear();
115 d_const_map.clear();
116 }
117
118 void SortInference::initialize(const std::vector<Node>& assertions)
119 {
120 Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
121 // process all assertions
122 std::map<Node, int> visited;
123 NodeManager * nm = NodeManager::currentNM();
124 int btId = getIdForType( nm->booleanType() );
125 for (const Node& a : assertions)
126 {
127 Trace("sort-inference-debug") << "Process " << a << std::endl;
128 std::map<Node, Node> var_bound;
129 int pid = process(a, var_bound, visited);
130 // the type of the topmost term must be Boolean
131 setEqual( pid, btId );
132 }
133 Trace("sort-inference-proc") << "...done" << std::endl;
134 for (const std::pair<const Node, int>& rt : d_op_return_types)
135 {
136 Trace("sort-inference") << rt.first << " : ";
137 TypeNode retTn = rt.first.getType();
138 if (!d_op_arg_types[rt.first].empty())
139 {
140 Trace("sort-inference") << "( ";
141 for (size_t i = 0; i < d_op_arg_types[rt.first].size(); i++)
142 {
143 recordSubsort(retTn[i], d_op_arg_types[rt.first][i]);
144 printSort("sort-inference", d_op_arg_types[rt.first][i]);
145 Trace("sort-inference") << " ";
146 }
147 Trace("sort-inference") << ") -> ";
148 retTn = retTn[(int)retTn.getNumChildren() - 1];
149 }
150 recordSubsort(retTn, rt.second);
151 printSort("sort-inference", rt.second);
152 Trace("sort-inference") << std::endl;
153 }
154 for (std::pair<const Node, std::map<Node, int> >& vt : d_var_types)
155 {
156 Trace("sort-inference")
157 << "Quantified formula : " << vt.first << " : " << std::endl;
158 for (const Node& v : vt.first[0])
159 {
160 recordSubsort(v.getType(), vt.second[v]);
161 printSort("sort-inference", vt.second[v]);
162 Trace("sort-inference") << std::endl;
163 }
164 Trace("sort-inference") << std::endl;
165 }
166
167 // determine monotonicity of sorts
168 Trace("sort-inference-proc")
169 << "Calculating monotonicty for subsorts..." << std::endl;
170 std::map<Node, std::map<int, bool> > visitedm;
171 for (const Node& a : assertions)
172 {
173 Trace("sort-inference-debug")
174 << "Process monotonicity for " << a << std::endl;
175 std::map<Node, Node> var_bound;
176 processMonotonic(a, true, true, var_bound, visitedm);
177 }
178 Trace("sort-inference-proc") << "...done" << std::endl;
179
180 Trace("sort-inference") << "We have " << d_sub_sorts.size()
181 << " sub-sorts : " << std::endl;
182 for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++)
183 {
184 printSort("sort-inference", d_sub_sorts[i]);
185 if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end())
186 {
187 Trace("sort-inference") << " is interpreted." << std::endl;
188 }
189 else if (d_non_monotonic_sorts.find(d_sub_sorts[i])
190 == d_non_monotonic_sorts.end())
191 {
192 Trace("sort-inference") << " is monotonic." << std::endl;
193 }
194 else
195 {
196 Trace("sort-inference") << " is not monotonic." << std::endl;
197 }
198 }
199 }
200
201 Node SortInference::simplify(Node n,
202 std::map<Node, Node>& model_replace_f,
203 std::map<Node, std::map<TypeNode, Node> >& visited)
204 {
205 Trace("sort-inference-debug") << "Simplify " << n << std::endl;
206 std::map<Node, Node> var_bound;
207 TypeNode tnn;
208 Node ret = simplifyNode(n, var_bound, tnn, model_replace_f, visited);
209 ret = theory::Rewriter::rewrite(ret);
210 return ret;
211 }
212
213 void SortInference::getNewAssertions(std::vector<Node>& new_asserts)
214 {
215 NodeManager* nm = NodeManager::currentNM();
216 // now, ensure constants are distinct
217 for (const std::pair<const TypeNode, std::map<Node, Node> >& cm : d_const_map)
218 {
219 std::vector<Node> consts;
220 for (const std::pair<const Node, Node>& c : cm.second)
221 {
222 Assert(c.first.isConst());
223 consts.push_back(c.second);
224 }
225 // add lemma enforcing introduced constants to be distinct
226 if (consts.size() > 1)
227 {
228 Node distinct_const = nm->mkNode(kind::DISTINCT, consts);
229 Trace("sort-inference-rewrite")
230 << "Add the constant distinctness lemma: " << std::endl;
231 Trace("sort-inference-rewrite") << " " << distinct_const << std::endl;
232 new_asserts.push_back(distinct_const);
233 }
234 }
235
236 // enforce constraints based on monotonicity
237 Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
238
239 for (const std::pair<const TypeNode, std::vector<int> >& tss :
240 d_type_sub_sorts)
241 {
242 int nmonSort = -1;
243 unsigned nsorts = tss.second.size();
244 for (unsigned i = 0; i < nsorts; i++)
245 {
246 if (d_non_monotonic_sorts.find(tss.second[i])
247 != d_non_monotonic_sorts.end())
248 {
249 nmonSort = tss.second[i];
250 break;
251 }
252 }
253 if (nmonSort != -1)
254 {
255 std::vector<Node> injections;
256 TypeNode base_tn = getOrCreateTypeForId(nmonSort, tss.first);
257 for (unsigned i = 0; i < nsorts; i++)
258 {
259 if (tss.second[i] != nmonSort)
260 {
261 TypeNode new_tn = getOrCreateTypeForId(tss.second[i], tss.first);
262 // make injection to nmonSort
263 Node a1 = mkInjection(new_tn, base_tn);
264 injections.push_back(a1);
265 if (d_non_monotonic_sorts.find(tss.second[i])
266 != d_non_monotonic_sorts.end())
267 {
268 // also must make injection from nmonSort to this
269 Node a2 = mkInjection(base_tn, new_tn);
270 injections.push_back(a2);
271 }
272 }
273 }
274 if (Trace.isOn("sort-inference-rewrite"))
275 {
276 Trace("sort-inference-rewrite")
277 << "Add the following injections for " << tss.first
278 << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
279 for (const Node& i : injections)
280 {
281 Trace("sort-inference-rewrite") << " " << i << std::endl;
282 }
283 }
284 new_asserts.insert(
285 new_asserts.end(), injections.begin(), injections.end());
286 }
287 }
288 Trace("sort-inference-proc") << "...done" << std::endl;
289 // no sub-sort information is stored
290 reset();
291 Trace("sort-inference-debug") << "Finished sort inference" << std::endl;
292 }
293
294 void SortInference::computeMonotonicity(const std::vector<Node>& assertions)
295 {
296 std::map<Node, std::map<int, bool> > visitedmt;
297 Trace("sort-inference-proc")
298 << "Calculating monotonicty for types..." << std::endl;
299 for (const Node& a : assertions)
300 {
301 Trace("sort-inference-debug")
302 << "Process type monotonicity for " << a << std::endl;
303 std::map<Node, Node> var_bound;
304 processMonotonic(a, true, true, var_bound, visitedmt, true);
305 }
306 Trace("sort-inference-proc") << "...done" << std::endl;
307 }
308
309 void SortInference::setEqual( int t1, int t2 ){
310 if( t1!=t2 ){
311 int rt1 = d_type_union_find.getRepresentative( t1 );
312 int rt2 = d_type_union_find.getRepresentative( t2 );
313 if( rt1!=rt2 ){
314 Trace("sort-inference-debug") << "Set equal : ";
315 printSort( "sort-inference-debug", rt1 );
316 Trace("sort-inference-debug") << " ";
317 printSort( "sort-inference-debug", rt2 );
318 Trace("sort-inference-debug") << std::endl;
319 /*
320 d_type_eq_class[rt1].insert( d_type_eq_class[rt1].end(), d_type_eq_class[rt2].begin(), d_type_eq_class[rt2].end() );
321 d_type_eq_class[rt2].clear();
322 Trace("sort-inference-debug") << "EqClass : { ";
323 for( int i=0; i<(int)d_type_eq_class[rt1].size(); i++ ){
324 Trace("sort-inference-debug") << d_type_eq_class[rt1][i] << ", ";
325 }
326 Trace("sort-inference-debug") << "}" << std::endl;
327 */
328 if( rt2>rt1 ){
329 //swap
330 int swap = rt1;
331 rt1 = rt2;
332 rt2 = swap;
333 }
334 std::map< int, TypeNode >::iterator it1 = d_type_types.find( rt1 );
335 if( it1!=d_type_types.end() ){
336 if( d_type_types.find( rt2 )==d_type_types.end() ){
337 d_type_types[rt2] = it1->second;
338 d_type_types.erase( rt1 );
339 }else{
340 Trace("sort-inference-debug") << "...fail : associated with types " << d_type_types[rt1] << " and " << d_type_types[rt2] << std::endl;
341 return;
342 }
343 }
344 d_type_union_find.d_eqc[rt1] = rt2;
345 }
346 }
347 }
348
349 int SortInference::getIdForType( TypeNode tn ){
350 //register the return type
351 std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
352 if( it==d_id_for_types.end() ){
353 int sc = d_sortCount;
354 d_type_types[d_sortCount] = tn;
355 d_id_for_types[tn] = d_sortCount;
356 d_sortCount++;
357 return sc;
358 }else{
359 return it->second;
360 }
361 }
362
/**
 * Assigns a sort id to n and, recursively, to its subterms. Along the way it
 * adds the equalities between sort ids that the structure of n requires.
 *
 * @param n the term to process
 * @param var_bound maps each currently bound variable to the quantified
 *        formula binding it
 * @param visited cache from already-processed terms to their sort id; the
 *        children of a quantified formula are processed with a separate
 *        cache (new_visited) local to that formula
 * @return the sort id assigned to n (a raw id, not necessarily the
 *         representative of its equivalence class)
 */
int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ){
  std::map< Node, int >::iterator itv = visited.find( n );
  if( itv!=visited.end() ){
    return itv->second;
  }else{
    //add to variable bindings
    bool use_new_visited = false;
    std::map< Node, int > new_visited;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      if( d_var_types.find( n )!=d_var_types.end() ){
        // The variables of this quantified formula already have sort ids in
        // d_var_types: the formula is not re-processed, and the id of its
        // own type is returned.
        return getIdForType( n.getType() );
      }else{
        //apply sort inference to quantified variables
        for( size_t i=0; i<n[0].getNumChildren(); i++ ){
          TypeNode nitn = n[0][i].getType();
          if( !nitn.isSort() )
          {
            // If the variable is of an interpreted sort, we assume the
            // sort of the variable will stay the same sort.
            d_var_types[n][n[0][i]] = getIdForType( nitn );
          }
          else
          {
            // If it is of an uninterpreted sort, infer subsorts.
            d_var_types[n][n[0][i]] = d_sortCount;
            d_sortCount++;
          }
          var_bound[ n[0][i] ] = n;
        }
      }
      // children of a binder are processed with a cache local to it
      use_new_visited = true;
    }

    //process children
    // For quantified formulas the bound variable list (child 0) is skipped.
    // Children after the body (index >= 2) are skipped as well when user
    // patterns are ignored. Otherwise children[j] / child_types[j]
    // correspond to n[j].
    std::vector< Node > children;
    std::vector< int > child_types;
    for( size_t i=0; i<n.getNumChildren(); i++ ){
      bool processChild = true;
      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
        processChild =
            options::userPatternsQuant() == options::UserPatMode::IGNORE
                ? i == 1
                : i >= 1;
      }
      if( processChild ){
        children.push_back( n[i] );
        child_types.push_back( process( n[i], var_bound, use_new_visited ? new_visited : visited ) );
      }
    }

    //remove from variable bindings
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //erase from variable bound
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        var_bound.erase( n[0][i] );
      }
    }
    Trace("sort-inference-debug") << "...Process " << n << std::endl;

    int retType;
    if( n.getKind()==kind::EQUAL && !n[0].getType().isBoolean() ){
      Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
      //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
      if( n[0].getType()!=n[1].getType() ){
        //for now, assume the original types
        for( unsigned i=0; i<2; i++ ){
          int ct = getIdForType( n[i].getType() );
          setEqual( child_types[i], ct );
        }
      }else{
        //we only require that the left and right hand side must be equal
        setEqual( child_types[0], child_types[1] );
      }
      // remembered so that simplifyNode can use it as the type context of
      // the left-hand side
      d_equality_types[n] = child_types[0];
      retType = getIdForType( n.getType() );
    }
    else if (isHandledApplyUf(n.getKind()))
    {
      Node op = n.getOperator();
      TypeNode tn_op = op.getType();
      if( d_op_return_types.find( op )==d_op_return_types.end() ){
        // first occurrence of op: allocate its return and argument sort ids
        if( n.getType().isBoolean() ){
          //use booleans
          d_op_return_types[op] = getIdForType( n.getType() );
        }else{
          //assign arbitrary sort for return type
          d_op_return_types[op] = d_sortCount;
          d_sortCount++;
        }
        // d_type_eq_class[d_sortCount].push_back( op );
        // assign arbitrary sort for argument types
        for( size_t i=0; i<n.getNumChildren(); i++ ){
          d_op_arg_types[op].push_back(d_sortCount);
          d_sortCount++;
        }
      }
      for( size_t i=0; i<n.getNumChildren(); i++ ){
        //the argument of the operator must match the return type of the subterm
        if( n[i].getType()!=tn_op[i] ){
          //if type mismatch, assume original types
          Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
          Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
          int ct1 = getIdForType( n[i].getType() );
          setEqual( child_types[i], ct1 );
          int ct2 = getIdForType( tn_op[i] );
          setEqual( d_op_arg_types[op][i], ct2 );
        }else{
          setEqual( child_types[i], d_op_arg_types[op][i] );
        }
      }
      //return type is the return type
      retType = d_op_return_types[op];
    }else{
      std::map< Node, Node >::iterator it = var_bound.find( n );
      if( it!=var_bound.end() ){
        Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
        //the return type was specified while binding
        retType = d_var_types[it->second][n];
      }else if( n.isVar() ){
        Trace("sort-inference-debug") << n << " is a variable." << std::endl;
        if( d_op_return_types.find( n )==d_op_return_types.end() ){
          //assign arbitrary sort
          d_op_return_types[n] = d_sortCount;
          d_sortCount++;
          // d_type_eq_class[d_sortCount].push_back( n );
        }
        retType = d_op_return_types[n];
      }else if( n.isConst() ){
        Trace("sort-inference-debug") << n << " is a constant." << std::endl;
        //can be any type we want
        // (each occurrence gets a fresh id; results are cached in visited)
        retType = d_sortCount;
        d_sortCount++;
      }else{
        Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
        //it is an interpreted term
        for( size_t i=0; i<children.size(); i++ ){
          Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
          //must enforce the actual type of the operator on the children
          int ct = getIdForType( children[i].getType() );
          setEqual( child_types[i], ct );
        }
        //return type must be the actual return type
        retType = getIdForType( n.getType() );
      }
    }
    Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
    printSort("sort-inference-debug", retType );
    Trace("sort-inference-debug") << std::endl;
    visited[n] = retType;
    return retType;
  }
}
515
/**
 * Traverses n with polarity information and records non-monotonic sorts.
 *
 * A sort is marked non-monotonic when both of the following hold:
 *  - a variable of that sort is bound by a FORALL that may be asserted
 *    positively;
 *  - the variable occurs directly as one side of an equality that may be
 *    asserted positively.
 *
 * EXISTS is not special-cased; it is traversed like any other node.
 *
 * @param n the term being traversed
 * @param pol the polarity of n (meaningful only when hasPol is true)
 * @param hasPol whether n has a fixed polarity
 * @param var_bound maps currently bound variables to the FORALL binding them
 * @param visited (term, polarity index) pairs already traversed. The index is
 *        1 for positive, -1 for negative and 0 for no fixed polarity.
 *        NOTE(review): the cache key does not include var_bound, so a shared
 *        subterm is only examined under the first binding context reached.
 * @param typeMode if false, marks the inferred subsort id in
 *        d_non_monotonic_sorts; if true, marks the variable's original type
 *        in d_non_monotonic_sorts_orig
 */
void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode ) {
  int pindex = hasPol ? ( pol ? 1 : -1 ) : 0;
  if( visited[n].find( pindex )==visited[n].end() ){
    visited[n][pindex] = true;
    Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
    if( n.getKind()==kind::FORALL ){
      //only consider variables universally if it is possible this quantified formula is asserted positively
      if( !hasPol || pol ){
        for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
          var_bound[n[0][i]] = n;
        }
      }
      // only the body (child 1) is traversed; patterns are not
      processMonotonic( n[1], pol, hasPol, var_bound, visited, typeMode );
      if( !hasPol || pol ){
        for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
          var_bound.erase( n[0][i] );
        }
      }
      return;
    }else if( n.getKind()==kind::EQUAL ){
      if( !hasPol || pol ){
        // mark the sort of the first side that is a bound variable
        for( unsigned i=0; i<2; i++ ){
          if( var_bound.find( n[i] )!=var_bound.end() ){
            if( !typeMode ){
              int sid = getSortId( var_bound[n[i]], n[i] );
              d_non_monotonic_sorts[sid] = true;
            }else{
              d_non_monotonic_sorts_orig[n[i].getType()] = true;
            }
            break;
          }
        }
      }
    }
    // recurse into all children with their induced polarities
    for( unsigned i=0; i<n.getNumChildren(); i++ ){
      bool npol;
      bool nhasPol;
      theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
      processMonotonic( n[i], npol, nhasPol, var_bound, visited, typeMode );
    }
  }
}
558
559
560 TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
561 int rt = d_type_union_find.getRepresentative( t );
562 if( d_type_types.find( rt )!=d_type_types.end() ){
563 return d_type_types[rt];
564 }else{
565 TypeNode retType;
566 // See if we can assign pref. This is an optimization for reusing an
567 // uninterpreted sort as the first subsort, so that fewer symbols needed
568 // to be rewritten in the sort-inferred signature. Notice we only assign
569 // pref here if it is an uninterpreted sort.
570 if (!pref.isNull() && d_id_for_types.find(pref) == d_id_for_types.end()
571 && pref.isSort())
572 {
573 retType = pref;
574 }else{
575 //must create new type
576 std::stringstream ss;
577 ss << "it_" << t << "_" << pref;
578 retType = NodeManager::currentNM()->mkSort( ss.str() );
579 }
580 Trace("sort-inference") << "-> Make type " << retType << " to correspond to ";
581 printSort("sort-inference", t );
582 Trace("sort-inference") << std::endl;
583 d_id_for_types[ retType ] = rt;
584 d_type_types[ rt ] = retType;
585 return retType;
586 }
587 }
588
589 TypeNode SortInference::getTypeForId( int t ){
590 int rt = d_type_union_find.getRepresentative( t );
591 if( d_type_types.find( rt )!=d_type_types.end() ){
592 return d_type_types[rt];
593 }else{
594 return TypeNode::null();
595 }
596 }
597
598 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
599 // if no sort was inferred for this node, return original
600 if( tn.isNull() || tn.isComparableTo( old.getType() ) ){
601 return old;
602 }else if( old.isConst() ){
603 //must make constant of type tn
604 if( d_const_map[tn].find( old )==d_const_map[tn].end() ){
605 std::stringstream ss;
606 ss << "ic_" << tn << "_" << old;
607 d_const_map[tn][ old ] = NodeManager::currentNM()->mkSkolem( ss.str(), tn, "constant created during sort inference" ); //use mkConst???
608 }
609 return d_const_map[tn][ old ];
610 }else if( old.getKind()==kind::BOUND_VARIABLE ){
611 std::stringstream ss;
612 ss << "b_" << old;
613 return NodeManager::currentNM()->mkBoundVar( ss.str(), tn );
614 }else{
615 std::stringstream ss;
616 ss << "i_" << old;
617 return NodeManager::currentNM()->mkSkolem( ss.str(), tn, "created during sort inference" );
618 }
619 }
620
/**
 * Rebuilds n over the inferred sorts.
 *
 * @param n the term to rebuild
 * @param var_bound maps original bound variables to their replacements
 * @param tnn type context for n. Besides keying the cache, it is used only
 *        when n is a constant, which is then replaced by a symbol of that
 *        type (see getNewSymbol). Null means no context.
 * @param model_replace_f records each function symbol replaced by a newly
 *        typed symbol
 * @param visited cache of results keyed by term and type context
 * @return the rebuilt term
 */
Node SortInference::simplifyNode(
    Node n,
    std::map<Node, Node>& var_bound,
    TypeNode tnn,
    std::map<Node, Node>& model_replace_f,
    std::map<Node, std::map<TypeNode, Node> >& visited)
{
  std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
  if( itv!=visited[n].end() ){
    return itv->second;
  }else{
    Trace("sort-inference-debug2") << "Simplify " << n << ", type context=" << tnn << std::endl;
    std::vector< Node > children;
    std::map< Node, std::map< TypeNode, Node > > new_visited;
    bool use_new_visited = false;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //recreate based on types of variables
      // (the rebuilt bound variable list becomes children[0])
      std::vector< Node > new_children;
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
        Node v = getNewSymbol( n[0][i], tn );
        Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
        new_children.push_back( v );
        var_bound[ n[0][i] ] = v;
      }
      children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
      use_new_visited = true;
    }

    //process children
    // For parameterized kinds the operator comes first in children; the
    // APPLY_UF case below overwrites children[0] with the new operator.
    if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
      children.push_back( n.getOperator() );
    }
    Node op;
    if( n.hasOperator() ){
      op = n.getOperator();
    }
    bool childChanged = false;
    // Declared outside the loop on purpose: for a non-Boolean equality it is
    // computed at i == 0 and reused for i == 1, so both sides are simplified
    // under the same type context. For kinds other than APPLY_UF / EQUAL it
    // stays null.
    TypeNode tnnc;
    for( size_t i=0; i<n.getNumChildren(); i++ ){
      bool processChild = true;
      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
        processChild =
            options::userPatternsQuant() == options::UserPatMode::IGNORE
                ? i == 1
                : i >= 1;
      }
      if( processChild ){
        if (isHandledApplyUf(n.getKind()))
        {
          Assert(d_op_arg_types.find(op) != d_op_arg_types.end());
          tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
          Assert(!tnnc.isNull());
        }
        else if (n.getKind() == kind::EQUAL && !n[0].getType().isBoolean()
                 && i == 0)
        {
          Assert(d_equality_types.find(n) != d_equality_types.end());
          tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
          Assert(!tnnc.isNull());
        }
        Node nc = simplifyNode(n[i],
                               var_bound,
                               tnnc,
                               model_replace_f,
                               use_new_visited ? new_visited : visited);
        Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
        children.push_back( nc );
        childChanged = childChanged || nc!=n[i];
      }
    }

    //remove from variable bindings
    Node ret;
    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
      //erase from variable bound
      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
        Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
        var_bound.erase( n[0][i] );
      }
      ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
    }else if( n.getKind()==kind::EQUAL ){
      TypeNode tn1 = children[0].getType();
      TypeNode tn2 = children[1].getType();
      if( !tn1.isComparableTo( tn2 ) ){
        Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
        Trace("sort-inference-warn") << " Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
        Assert(false);
      }
      ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
    }
    else if (isHandledApplyUf(n.getKind()))
    {
      if( d_symbol_map.find( op )==d_symbol_map.end() ){
        //make the new operator if necessary
        // (a fresh "io_<op>" skolem only when some argument or return type
        // differs from the original; otherwise op maps to itself)
        bool opChanged = false;
        std::vector< TypeNode > argTypes;
        for( size_t i=0; i<n.getNumChildren(); i++ ){
          TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
          argTypes.push_back( tn );
          if( tn!=n[i].getType() ){
            opChanged = true;
          }
        }
        TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
        if( retType!=n.getType() ){
          opChanged = true;
        }
        if( opChanged ){
          std::stringstream ss;
          ss << "io_" << op;
          TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
          d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
          Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
          model_replace_f[op] = d_symbol_map[op];
        }else{
          d_symbol_map[op] = op;
        }
      }
      children[0] = d_symbol_map[op];
      // make sure all children have been given proper types
      for (size_t i = 0, size = n.getNumChildren(); i < size; i++)
      {
        TypeNode tn = children[i+1].getType();
        TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
        if (!tn.isSubtypeOf(tna))
        {
          Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
          Assert(false);
        }
      }
      ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
    }else{
      std::map< Node, Node >::iterator it = var_bound.find( n );
      if( it!=var_bound.end() ){
        ret = it->second;
      }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
        if( d_symbol_map.find( n )==d_symbol_map.end() ){
          TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
          d_symbol_map[n] = getNewSymbol( n, tn );
        }
        ret = d_symbol_map[n];
      }else if( n.isConst() ){
        //type is determined by context
        ret = getNewSymbol( n, tnn );
      }else if( childChanged ){
        ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
      }else{
        ret = n;
      }
    }
    visited[n][tnn] = ret;
    return ret;
  }
}
776
777 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
778 std::vector< TypeNode > tns;
779 tns.push_back( tn1 );
780 TypeNode typ = NodeManager::currentNM()->mkFunctionType( tns, tn2 );
781 Node f = NodeManager::currentNM()->mkSkolem( "inj", typ, "injection for monotonicity constraint" );
782 Trace("sort-inference") << "-> Make injection " << f << " from " << tn1 << " to " << tn2 << std::endl;
783 Node v1 = NodeManager::currentNM()->mkBoundVar( "?x", tn1 );
784 Node v2 = NodeManager::currentNM()->mkBoundVar( "?y", tn1 );
785 Node ret = NodeManager::currentNM()->mkNode( kind::FORALL,
786 NodeManager::currentNM()->mkNode( kind::BOUND_VAR_LIST, v1, v2 ),
787 NodeManager::currentNM()->mkNode( kind::OR,
788 NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v1 ).eqNode( NodeManager::currentNM()->mkNode( kind::APPLY_UF, f, v2 ) ).negate(),
789 v1.eqNode( v2 ) ) );
790 ret = theory::Rewriter::rewrite( ret );
791 return ret;
792 }
793
794 int SortInference::getSortId( Node n ) {
795 Node op = n.getKind()==kind::APPLY_UF ? n.getOperator() : n;
796 if( d_op_return_types.find( op )!=d_op_return_types.end() ){
797 return d_type_union_find.getRepresentative( d_op_return_types[op] );
798 }else{
799 return 0;
800 }
801 }
802
803 int SortInference::getSortId( Node f, Node v ) {
804 if( d_var_types.find( f )!=d_var_types.end() ){
805 return d_type_union_find.getRepresentative( d_var_types[f][v] );
806 }else{
807 return 0;
808 }
809 }
810
811 void SortInference::setSkolemVar( Node f, Node v, Node sk ){
812 Trace("sort-inference-temp") << "Set skolem var for " << f << ", variable " << v << std::endl;
813 if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
814 //calculate the sort for variables if not done so already
815 std::map< Node, Node > var_bound;
816 std::map< Node, int > visited;
817 process( f, var_bound, visited );
818 }
819 d_op_return_types[sk] = getSortId( f, v );
820 Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
821 }
822
823 bool SortInference::isWellSortedFormula( Node n ) {
824 if (n.getType().isBoolean() && !isHandledApplyUf(n.getKind()))
825 {
826 for( unsigned i=0; i<n.getNumChildren(); i++ ){
827 if( !isWellSortedFormula( n[i] ) ){
828 return false;
829 }
830 }
831 return true;
832 }else{
833 return isWellSorted( n );
834 }
835 }
836
837 bool SortInference::isWellSorted( Node n ) {
838 if( getSortId( n )==0 ){
839 return false;
840 }else{
841 if (isHandledApplyUf(n.getKind()))
842 {
843 for( unsigned i=0; i<n.getNumChildren(); i++ ){
844 int s1 = getSortId( n[i] );
845 int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
846 if( s1!=s2 ){
847 return false;
848 }
849 if( !isWellSorted( n[i] ) ){
850 return false;
851 }
852 }
853 }
854 return true;
855 }
856 }
857
858 bool SortInference::isMonotonic( TypeNode tn ) {
859 Assert(tn.isSort());
860 return d_non_monotonic_sorts_orig.find( tn )==d_non_monotonic_sorts_orig.end();
861 }
862
863 bool SortInference::isHandledApplyUf(Kind k) const
864 {
865 return k == APPLY_UF && !options::ufHo();
866 }
867
868 }/* CVC4 namespace */