Cleanup some includes (#5847)
[cvc5.git] / src / theory / uf / cardinality_extension.cpp
1 /********************* */
2 /*! \file cardinality_extension.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Morgan Deters, Tim King
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of theory of UF with cardinality.
13 **/
14
15 #include "theory/uf/cardinality_extension.h"
16
17 #include "options/smt_options.h"
18 #include "options/uf_options.h"
19 #include "theory/quantifiers/term_database.h"
20 #include "theory/quantifiers_engine.h"
21 #include "theory/theory_engine.h"
22 #include "theory/theory_model.h"
23 #include "theory/uf/equality_engine.h"
24 #include "theory/uf/theory_uf.h"
25
26 using namespace std;
27 using namespace CVC4::kind;
28 using namespace CVC4::context;
29
30 namespace CVC4 {
31 namespace theory {
32 namespace uf {
33
34 /* These are names are unambigious are we use abbreviations. */
35 typedef CardinalityExtension::SortModel SortModel;
36 typedef SortModel::Region Region;
37 typedef Region::RegionNodeInfo RegionNodeInfo;
38 typedef RegionNodeInfo::DiseqList DiseqList;
39
40 Region::Region(SortModel* cf, context::Context* c)
41 : d_cf( cf )
42 , d_testCliqueSize( c, 0 )
43 , d_splitsSize( c, 0 )
44 , d_testClique( c )
45 , d_splits( c )
46 , d_reps_size( c, 0 )
47 , d_total_diseq_external( c, 0 )
48 , d_total_diseq_internal( c, 0 )
49 , d_valid( c, true ) {}
50
51 Region::~Region() {
52 for(iterator i = begin(), iend = end(); i != iend; ++i) {
53 RegionNodeInfo* regionNodeInfo = (*i).second;
54 delete regionNodeInfo;
55 }
56 d_nodes.clear();
57 }
58
59 void Region::addRep( Node n ) {
60 setRep( n, true );
61 }
62
63 void Region::takeNode( Region* r, Node n ){
64 Assert(!hasRep(n));
65 Assert(r->hasRep(n));
66 //add representative
67 setRep( n, true );
68 //take disequalities from r
69 RegionNodeInfo* rni = r->d_nodes[n];
70 for( int t=0; t<2; t++ ){
71 DiseqList* del = rni->get(t);
72 for(DiseqList::iterator it = del->begin(); it != del->end(); ++it ){
73 if( (*it).second ){
74 r->setDisequal( n, (*it).first, t, false );
75 if( t==0 ){
76 if( hasRep( (*it).first ) ){
77 setDisequal( (*it).first, n, 0, false );
78 setDisequal( (*it).first, n, 1, true );
79 setDisequal( n, (*it).first, 1, true );
80 }else{
81 setDisequal( n, (*it).first, 0, true );
82 }
83 }else{
84 r->setDisequal( (*it).first, n, 1, false );
85 r->setDisequal( (*it).first, n, 0, true );
86 setDisequal( n, (*it).first, 0, true );
87 }
88 }
89 }
90 }
91 //remove representative
92 r->setRep( n, false );
93 }
94
95 void Region::combine( Region* r ){
96 //take all nodes from r
97 for(Region::iterator it = r->d_nodes.begin(); it != r->d_nodes.end(); ++it) {
98 if( it->second->valid() ){
99 setRep( it->first, true );
100 }
101 }
102 for(Region::iterator it = r->d_nodes.begin(); it != r->d_nodes.end(); ++it){
103 if( it->second->valid() ){
104 //take disequalities from r
105 Node n = it->first;
106 RegionNodeInfo* rni = it->second;
107 for( int t=0; t<2; t++ ){
108 RegionNodeInfo::DiseqList* del = rni->get(t);
109 for( RegionNodeInfo::DiseqList::iterator it2 = del->begin(),
110 it2end = del->end(); it2 != it2end; ++it2 ){
111 if( (*it2).second ){
112 if( t==0 && hasRep( (*it2).first ) ){
113 setDisequal( (*it2).first, n, 0, false );
114 setDisequal( (*it2).first, n, 1, true );
115 setDisequal( n, (*it2).first, 1, true );
116 }else{
117 setDisequal( n, (*it2).first, t, true );
118 }
119 }
120 }
121 }
122 }
123 }
124 r->d_valid = false;
125 }
126
127 /** setEqual */
128 void Region::setEqual( Node a, Node b ){
129 Assert(hasRep(a) && hasRep(b));
130 //move disequalities of b over to a
131 for( int t=0; t<2; t++ ){
132 DiseqList* del = d_nodes[b]->get(t);
133 for( DiseqList::iterator it = del->begin(); it != del->end(); ++it ){
134 if( (*it).second ){
135 Node n = (*it).first;
136 //get the region that contains the endpoint of the disequality b != ...
137 Region* nr = d_cf->d_regions[ d_cf->d_regions_map[ n ] ];
138 if( !isDisequal( a, n, t ) ){
139 setDisequal( a, n, t, true );
140 nr->setDisequal( n, a, t, true );
141 }
142 setDisequal( b, n, t, false );
143 nr->setDisequal( n, b, t, false );
144 }
145 }
146 }
147 //remove b from representatives
148 setRep( b, false );
149 }
150
151 void Region::setDisequal( Node n1, Node n2, int type, bool valid ){
152 //Debug("uf-ss-region-debug") << "set disequal " << n1 << " " << n2 << " "
153 // << type << " " << valid << std::endl;
154 //debugPrint("uf-ss-region-debug");
155 //Assert( isDisequal( n1, n2, type )!=valid );
156 if( isDisequal( n1, n2, type )!=valid ){ //DO_THIS: make assertion
157 d_nodes[ n1 ]->get(type)->setDisequal( n2, valid );
158 if( type==0 ){
159 d_total_diseq_external = d_total_diseq_external + ( valid ? 1 : -1 );
160 }else{
161 d_total_diseq_internal = d_total_diseq_internal + ( valid ? 1 : -1 );
162 if( valid ){
163 //if they are both a part of testClique, then remove split
164 if( d_testClique.find( n1 )!=d_testClique.end() && d_testClique[n1] &&
165 d_testClique.find( n2 )!=d_testClique.end() && d_testClique[n2] ){
166 Node eq = NodeManager::currentNM()->mkNode( EQUAL, n1, n2 );
167 if( d_splits.find( eq )!=d_splits.end() && d_splits[ eq ] ){
168 Debug("uf-ss-debug") << "removing split for " << n1 << " " << n2
169 << std::endl;
170 d_splits[ eq ] = false;
171 d_splitsSize = d_splitsSize - 1;
172 }
173 }
174 }
175 }
176 }
177 }
178
179 void Region::setRep( Node n, bool valid ) {
180 Assert(hasRep(n) != valid);
181 if( valid && d_nodes.find( n )==d_nodes.end() ){
182 d_nodes[n] = new RegionNodeInfo(d_cf->d_state.getSatContext());
183 }
184 d_nodes[n]->setValid(valid);
185 d_reps_size = d_reps_size + ( valid ? 1 : -1 );
186 //removing a member of the test clique from this region
187 if( d_testClique.find( n ) != d_testClique.end() && d_testClique[n] ){
188 Assert(!valid);
189 d_testClique[n] = false;
190 d_testCliqueSize = d_testCliqueSize - 1;
191 //remove all splits involving n
192 for( split_iterator it = begin_splits(); it != end_splits(); ++it ){
193 if( (*it).second ){
194 if( (*it).first[0]==n || (*it).first[1]==n ){
195 d_splits[ (*it).first ] = false;
196 d_splitsSize = d_splitsSize - 1;
197 }
198 }
199 }
200 }
201 }
202
203 bool Region::isDisequal( Node n1, Node n2, int type ) {
204 RegionNodeInfo::DiseqList* del = d_nodes[ n1 ]->get(type);
205 return del->isSet(n2) && del->getDisequalityValue(n2);
206 }
207
208 struct sortInternalDegree {
209 Region* r;
210 bool operator() (Node i, Node j) {
211 return (r->getRegionInfo(i)->getNumInternalDisequalities() >
212 r->getRegionInfo(j)->getNumInternalDisequalities());
213 }
214 };
215
216 struct sortExternalDegree {
217 Region* r;
218 bool operator() (Node i,Node j) {
219 return (r->getRegionInfo(i)->getNumExternalDisequalities() >
220 r->getRegionInfo(j)->getNumExternalDisequalities());
221 }
222 };
223
224 int gmcCount = 0;
225
226 bool Region::getMustCombine( int cardinality ){
227 if (d_total_diseq_external >= static_cast<unsigned>(cardinality))
228 {
229 //The number of external disequalities is greater than or equal to
230 //cardinality. Thus, a clique of size cardinality+1 may exist
231 //between nodes in d_regions[i] and other regions Check if this is
232 //actually the case: must have n nodes with outgoing degree
233 //(cardinality+1-n) for some n>0
234 std::vector< int > degrees;
235 for( Region::iterator it = begin(); it != end(); ++it ){
236 RegionNodeInfo* rni = it->second;
237 if( rni->valid() ){
238 if( rni->getNumDisequalities() >= cardinality ){
239 int outDeg = rni->getNumExternalDisequalities();
240 if( outDeg>=cardinality ){
241 //we have 1 node of degree greater than (cardinality)
242 return true;
243 }else if( outDeg>=1 ){
244 degrees.push_back( outDeg );
245 if( (int)degrees.size()>=cardinality ){
246 //we have (cardinality) nodes of degree 1
247 return true;
248 }
249 }
250 }
251 }
252 }
253 gmcCount++;
254 if( gmcCount%100==0 ){
255 Trace("gmc-count") << gmcCount << " " << cardinality
256 << " sample : " << degrees.size() << std::endl;
257 }
258 //this should happen relatively infrequently....
259 std::sort( degrees.begin(), degrees.end() );
260 for( int i=0; i<(int)degrees.size(); i++ ){
261 if( degrees[i]>=cardinality+1-((int)degrees.size()-i) ){
262 return true;
263 }
264 }
265 }
266 return false;
267 }
268
269 bool Region::check( Theory::Effort level, int cardinality,
270 std::vector< Node >& clique ) {
271 if( d_reps_size>unsigned(cardinality) ){
272 if( d_total_diseq_internal==d_reps_size*( d_reps_size - 1 ) ){
273 if( d_reps_size>1 ){
274 //quick clique check, all reps form a clique
275 for( iterator it = begin(); it != end(); ++it ){
276 if( it->second->valid() ){
277 clique.push_back( it->first );
278 }
279 }
280 Trace("quick-clique") << "Found quick clique" << std::endl;
281 return true;
282 }else{
283 return false;
284 }
285 }
286 else if (level==Theory::EFFORT_FULL)
287 {
288 //build test clique, up to size cardinality+1
289 if( d_testCliqueSize<=unsigned(cardinality) ){
290 std::vector< Node > newClique;
291 if( d_testCliqueSize<unsigned(cardinality) ){
292 for( iterator it = begin(); it != end(); ++it ){
293 //if not in the test clique, add it to the set of new members
294 if( it->second->valid() &&
295 ( d_testClique.find( it->first ) == d_testClique.end() ||
296 !d_testClique[ it->first ] ) ){
297 //if( it->second->getNumInternalDisequalities()>cardinality ||
298 // level==Theory::EFFORT_FULL ){
299 newClique.push_back( it->first );
300 //}
301 }
302 }
303 //choose remaining nodes with the highest degrees
304 sortInternalDegree sidObj;
305 sidObj.r = this;
306 std::sort( newClique.begin(), newClique.end(), sidObj );
307 int offset = ( cardinality - d_testCliqueSize ) + 1;
308 newClique.erase( newClique.begin() + offset, newClique.end() );
309 }else{
310 //scan for the highest degree
311 int maxDeg = -1;
312 Node maxNode;
313 for( std::map< Node, RegionNodeInfo* >::iterator
314 it = d_nodes.begin(); it != d_nodes.end(); ++it ){
315 //if not in the test clique, add it to the set of new members
316 if( it->second->valid() &&
317 ( d_testClique.find( it->first )==d_testClique.end() ||
318 !d_testClique[ it->first ] ) ){
319 if( it->second->getNumInternalDisequalities()>maxDeg ){
320 maxDeg = it->second->getNumInternalDisequalities();
321 maxNode = it->first;
322 }
323 }
324 }
325 Assert(maxNode != Node::null());
326 newClique.push_back( maxNode );
327 }
328 //check splits internal to new members
329 for( int j=0; j<(int)newClique.size(); j++ ){
330 Debug("uf-ss-debug") << "Choose to add clique member "
331 << newClique[j] << std::endl;
332 for( int k=(j+1); k<(int)newClique.size(); k++ ){
333 if( !isDisequal( newClique[j], newClique[k], 1 ) ){
334 Node at_j = newClique[j];
335 Node at_k = newClique[k];
336 Node j_eq_k =
337 NodeManager::currentNM()->mkNode( EQUAL, at_j, at_k );
338 d_splits[ j_eq_k ] = true;
339 d_splitsSize = d_splitsSize + 1;
340 }
341 }
342 //check disequalities with old members
343 for( NodeBoolMap::iterator it = d_testClique.begin();
344 it != d_testClique.end(); ++it ){
345 if( (*it).second ){
346 if( !isDisequal( (*it).first, newClique[j], 1 ) ){
347 Node at_it = (*it).first;
348 Node at_j = newClique[j];
349 Node it_eq_j = at_it.eqNode(at_j);
350 d_splits[ it_eq_j ] = true;
351 d_splitsSize = d_splitsSize + 1;
352 }
353 }
354 }
355 }
356 //add new clique members to test clique
357 for( int j=0; j<(int)newClique.size(); j++ ){
358 d_testClique[ newClique[j] ] = true;
359 d_testCliqueSize = d_testCliqueSize + 1;
360 }
361 }
362 // Check if test clique has larger size than cardinality, and
363 // forms a clique.
364 if( d_testCliqueSize >= unsigned(cardinality+1) && d_splitsSize==0 ){
365 //test clique is a clique
366 for( NodeBoolMap::iterator it = d_testClique.begin();
367 it != d_testClique.end(); ++it ){
368 if( (*it).second ){
369 clique.push_back( (*it).first );
370 }
371 }
372 return true;
373 }
374 }
375 }
376 return false;
377 }
378
379 void Region::getNumExternalDisequalities(
380 std::map< Node, int >& num_ext_disequalities ){
381 for( Region::iterator it = begin(); it != end(); ++it ){
382 RegionNodeInfo* rni = it->second;
383 if( rni->valid() ){
384 DiseqList* del = rni->get(0);
385 for( DiseqList::iterator it2 = del->begin(); it2 != del->end(); ++it2 ){
386 if( (*it2).second ){
387 num_ext_disequalities[ (*it2).first ]++;
388 }
389 }
390 }
391 }
392 }
393
394 void Region::debugPrint( const char* c, bool incClique ) {
395 Debug( c ) << "Num reps: " << d_reps_size << std::endl;
396 for( Region::iterator it = begin(); it != end(); ++it ){
397 RegionNodeInfo* rni = it->second;
398 if( rni->valid() ){
399 Node n = it->first;
400 Debug( c ) << " " << n << std::endl;
401 for( int i=0; i<2; i++ ){
402 Debug( c ) << " " << ( i==0 ? "Ext" : "Int" ) << " disequal:";
403 DiseqList* del = rni->get(i);
404 for( DiseqList::iterator it2 = del->begin(); it2 != del->end(); ++it2 ){
405 if( (*it2).second ){
406 Debug( c ) << " " << (*it2).first;
407 }
408 }
409 Debug( c ) << ", total = " << del->size() << std::endl;
410 }
411 }
412 }
413 Debug( c ) << "Total disequal: " << d_total_diseq_external << " external,";
414 Debug( c ) << " " << d_total_diseq_internal << " internal." << std::endl;
415
416 if( incClique ){
417 if( !d_testClique.empty() ){
418 Debug( c ) << "Candidate clique members: " << std::endl;
419 Debug( c ) << " ";
420 for( NodeBoolMap::iterator it = d_testClique.begin();
421 it != d_testClique.end(); ++ it ){
422 if( (*it).second ){
423 Debug( c ) << (*it).first << " ";
424 }
425 }
426 Debug( c ) << ", size = " << d_testCliqueSize << std::endl;
427 }
428 if( !d_splits.empty() ){
429 Debug( c ) << "Required splits: " << std::endl;
430 Debug( c ) << " ";
431 for( NodeBoolMap::iterator it = d_splits.begin(); it != d_splits.end();
432 ++ it ){
433 if( (*it).second ){
434 Debug( c ) << (*it).first << " ";
435 }
436 }
437 Debug( c ) << ", size = " << d_splitsSize << std::endl;
438 }
439 }
440 }
441
442 SortModel::CardinalityDecisionStrategy::CardinalityDecisionStrategy(
443 Node t, context::Context* satContext, Valuation valuation)
444 : DecisionStrategyFmf(satContext, valuation), d_cardinality_term(t)
445 {
446 }
447 Node SortModel::CardinalityDecisionStrategy::mkLiteral(unsigned i)
448 {
449 NodeManager* nm = NodeManager::currentNM();
450 return nm->mkNode(
451 CARDINALITY_CONSTRAINT, d_cardinality_term, nm->mkConst(Rational(i + 1)));
452 }
453 std::string SortModel::CardinalityDecisionStrategy::identify() const
454 {
455 return std::string("uf_card");
456 }
457
458 SortModel::SortModel(Node n,
459 TheoryState& state,
460 TheoryInferenceManager& im,
461 CardinalityExtension* thss)
462 : d_type(n.getType()),
463 d_state(state),
464 d_im(im),
465 d_thss(thss),
466 d_regions_index(d_state.getSatContext(), 0),
467 d_regions_map(d_state.getSatContext()),
468 d_split_score(d_state.getSatContext()),
469 d_disequalities_index(d_state.getSatContext(), 0),
470 d_reps(d_state.getSatContext(), 0),
471 d_cardinality(d_state.getSatContext(), 1),
472 d_hasCard(d_state.getSatContext(), false),
473 d_maxNegCard(d_state.getSatContext(), 0),
474 d_initialized(d_state.getUserContext(), false),
475 d_c_dec_strat(nullptr)
476 {
477 d_cardinality_term = n;
478
479 if (options::ufssMode() == options::UfssMode::FULL)
480 {
481 // Register the strategy with the decision manager of the theory.
482 // We are guaranteed that the decision manager is ready since we
483 // construct this module during TheoryUF::finishInit.
484 d_c_dec_strat.reset(new CardinalityDecisionStrategy(
485 n, d_state.getSatContext(), thss->getTheory()->getValuation()));
486 }
487 }
488
489 SortModel::~SortModel() {
490 for(std::vector<Region*>::iterator i = d_regions.begin();
491 i != d_regions.end(); ++i) {
492 Region* region = *i;
493 delete region;
494 }
495 d_regions.clear();
496 }
497
498 /** initialize */
499 void SortModel::initialize()
500 {
501 if (d_c_dec_strat.get() != nullptr && !d_initialized)
502 {
503 d_initialized = true;
504 // Strategy is user-context-dependent, since it is in sync with
505 // user-context-dependent flag d_initialized.
506 d_thss->getTheory()->getDecisionManager()->registerStrategy(
507 DecisionManager::STRAT_UF_CARD, d_c_dec_strat.get());
508 }
509 }
510
511 /** new node */
512 void SortModel::newEqClass( Node n ){
513 if (!d_state.isInConflict())
514 {
515 if( d_regions_map.find( n )==d_regions_map.end() ){
516 d_regions_map[n] = d_regions_index;
517 Debug("uf-ss") << "CardinalityExtension: New Eq Class " << n << std::endl;
518 Debug("uf-ss-debug") << d_regions_index << " " << (int)d_regions.size()
519 << std::endl;
520 if (d_regions_index < d_regions.size())
521 {
522 d_regions[d_regions_index]->debugPrint("uf-ss-debug", true);
523 d_regions[d_regions_index]->setValid(true);
524 Assert(d_regions[d_regions_index]->getNumReps() == 0);
525 }else{
526 d_regions.push_back(new Region(this, d_state.getSatContext()));
527 }
528 d_regions[d_regions_index]->addRep(n);
529 d_regions_index = d_regions_index + 1;
530
531 d_reps = d_reps + 1;
532 }
533 }
534 }
535
536 /** merge */
537 void SortModel::merge( Node a, Node b ){
538 if (d_state.isInConflict())
539 {
540 return;
541 }
542 Debug("uf-ss") << "CardinalityExtension: Merging " << a << " = " << b << "..."
543 << std::endl;
544 if (a != b)
545 {
546 Assert(d_regions_map.find(a) != d_regions_map.end());
547 Assert(d_regions_map.find(b) != d_regions_map.end());
548 int ai = d_regions_map[a];
549 int bi = d_regions_map[b];
550 Debug("uf-ss") << " regions: " << ai << " " << bi << std::endl;
551 if (ai != bi)
552 {
553 if (d_regions[ai]->getNumReps() == 1)
554 {
555 int ri = combineRegions(bi, ai);
556 d_regions[ri]->setEqual(a, b);
557 checkRegion(ri);
558 }
559 else if (d_regions[bi]->getNumReps() == 1)
560 {
561 int ri = combineRegions(ai, bi);
562 d_regions[ri]->setEqual(a, b);
563 checkRegion(ri);
564 }
565 else
566 {
567 // Either move a to d_regions[bi], or b to d_regions[ai].
568 RegionNodeInfo* a_region_info = d_regions[ai]->getRegionInfo(a);
569 RegionNodeInfo* b_region_info = d_regions[bi]->getRegionInfo(b);
570 int aex = (a_region_info->getNumInternalDisequalities()
571 - getNumDisequalitiesToRegion(a, bi));
572 int bex = (b_region_info->getNumInternalDisequalities()
573 - getNumDisequalitiesToRegion(b, ai));
574 // Based on which would produce the fewest number of
575 // external disequalities.
576 if (aex < bex)
577 {
578 moveNode(a, bi);
579 d_regions[bi]->setEqual(a, b);
580 }else{
581 moveNode(b, ai);
582 d_regions[ai]->setEqual( a, b );
583 }
584 checkRegion(ai);
585 checkRegion(bi);
586 }
587 }
588 else
589 {
590 d_regions[ai]->setEqual(a, b);
591 checkRegion(ai);
592 }
593 d_regions_map[b] = -1;
594 }
595 d_reps = d_reps - 1;
596 }
597
598 /** assert terms are disequal */
599 void SortModel::assertDisequal( Node a, Node b, Node reason ){
600 if (d_state.isInConflict())
601 {
602 return;
603 }
604 // if they are not already disequal
605 eq::EqualityEngine* ee = d_thss->getTheory()->getEqualityEngine();
606 a = ee->getRepresentative(a);
607 b = ee->getRepresentative(b);
608 int ai = d_regions_map[a];
609 int bi = d_regions_map[b];
610 if (d_regions[ai]->isDisequal(a, b, ai == bi))
611 {
612 // already disequal
613 return;
614 }
615 Debug("uf-ss") << "Assert disequal " << a << " != " << b << "..."
616 << std::endl;
617 Debug("uf-ss-disequal") << "Assert disequal " << a << " != " << b << "..."
618 << std::endl;
619 // add to list of disequalities
620 if (d_disequalities_index < d_disequalities.size())
621 {
622 d_disequalities[d_disequalities_index] = reason;
623 }
624 else
625 {
626 d_disequalities.push_back(reason);
627 }
628 d_disequalities_index = d_disequalities_index + 1;
629 // now, add disequalities to regions
630 Assert(d_regions_map.find(a) != d_regions_map.end());
631 Assert(d_regions_map.find(b) != d_regions_map.end());
632 Debug("uf-ss") << " regions: " << ai << " " << bi << std::endl;
633 if (ai == bi)
634 {
635 // internal disequality
636 d_regions[ai]->setDisequal(a, b, 1, true);
637 d_regions[ai]->setDisequal(b, a, 1, true);
638 // do not need to check if it needs to combine (no new ext. disequalities)
639 checkRegion(ai, false);
640 }
641 else
642 {
643 // external disequality
644 d_regions[ai]->setDisequal(a, b, 0, true);
645 d_regions[bi]->setDisequal(b, a, 0, true);
646 checkRegion(ai);
647 checkRegion(bi);
648 }
649 }
650
651 bool SortModel::areDisequal( Node a, Node b ) {
652 Assert(a == d_thss->getTheory()->getEqualityEngine()->getRepresentative(a));
653 Assert(b == d_thss->getTheory()->getEqualityEngine()->getRepresentative(b));
654 if( d_regions_map.find( a )!=d_regions_map.end() &&
655 d_regions_map.find( b )!=d_regions_map.end() ){
656 int ai = d_regions_map[a];
657 int bi = d_regions_map[b];
658 return d_regions[ai]->isDisequal(a, b, ai==bi ? 1 : 0);
659 }else{
660 return false;
661 }
662 }
663
664 void SortModel::check(Theory::Effort level)
665 {
666 Assert(options::ufssMode() == options::UfssMode::FULL);
667 if (!d_hasCard && d_state.isInConflict())
668 {
669 // not necessary to check
670 return;
671 }
672 Debug("uf-ss") << "CardinalityExtension: Check " << level << " " << d_type
673 << std::endl;
674 if (level == Theory::EFFORT_FULL)
675 {
676 Debug("fmf-full-check") << std::endl;
677 Debug("fmf-full-check")
678 << "Full check for SortModel " << d_type << ", status : " << std::endl;
679 debugPrint("fmf-full-check");
680 Debug("fmf-full-check") << std::endl;
681 }
682 if (d_reps <= (unsigned)d_cardinality)
683 {
684 Debug("uf-ss-debug") << "We have " << d_reps << " representatives for type "
685 << d_type << ", <= " << d_cardinality << std::endl;
686 if( level==Theory::EFFORT_FULL ){
687 Debug("uf-ss-sat") << "We have " << d_reps << " representatives for type "
688 << d_type << ", <= " << d_cardinality << std::endl;
689 }
690 return;
691 }
692 // first check if we can generate a clique conflict
693 // do a check within each region
694 for (size_t i = 0; i < d_regions_index; i++)
695 {
696 if (d_regions[i]->valid())
697 {
698 std::vector<Node> clique;
699 if (d_regions[i]->check(level, d_cardinality, clique))
700 {
701 // add clique lemma
702 addCliqueLemma(clique);
703 return;
704 }
705 else
706 {
707 Trace("uf-ss-debug") << "No clique in Region #" << i << std::endl;
708 }
709 }
710 }
711 // do splitting on demand
712 bool addedLemma = false;
713 if (level == Theory::EFFORT_FULL)
714 {
715 Trace("uf-ss-debug") << "Add splits?" << std::endl;
716 // see if we have any recommended splits from large regions
717 for (size_t i = 0; i < d_regions_index; i++)
718 {
719 if (d_regions[i]->valid() && d_regions[i]->getNumReps() > d_cardinality)
720 {
721 int sp = addSplit(d_regions[i]);
722 if (sp == 1)
723 {
724 addedLemma = true;
725 }
726 else if (sp == -1)
727 {
728 check(level);
729 return;
730 }
731 }
732 }
733 }
734 // If no added lemmas, force continuation via combination of regions.
735 if (level != Theory::EFFORT_FULL || addedLemma)
736 {
737 return;
738 }
739 // check at full effort
740 Trace("uf-ss-debug") << "No splits added. " << d_cardinality << std::endl;
741 Trace("uf-ss-si") << "Must combine region" << std::endl;
742 bool recheck = false;
743 SortInference* si = d_thss->getSortInference();
744 if (si != nullptr)
745 {
746 // If sort inference is enabled, search for regions with same sort.
747 std::map<int, int> sortsFound;
748 for (size_t i = 0; i < d_regions_index; i++)
749 {
750 if (d_regions[i]->valid())
751 {
752 Node op = d_regions[i]->frontKey();
753 int sort_id = si->getSortId(op);
754 if (sortsFound.find(sort_id) != sortsFound.end())
755 {
756 Debug("fmf-full-check") << "Combined regions " << i << " "
757 << sortsFound[sort_id] << std::endl;
758 combineRegions(sortsFound[sort_id], i);
759 recheck = true;
760 break;
761 }
762 else
763 {
764 sortsFound[sort_id] = i;
765 }
766 }
767 }
768 }
769 if (!recheck)
770 {
771 // naive strategy, force region combination involving the first
772 // valid region
773 for (size_t i = 0; i < d_regions_index; i++)
774 {
775 if (d_regions[i]->valid())
776 {
777 int fcr = forceCombineRegion(i, false);
778 Debug("fmf-full-check")
779 << "Combined regions " << i << " " << fcr << std::endl;
780 Trace("uf-ss-debug")
781 << "Combined regions " << i << " " << fcr << std::endl;
782 recheck = true;
783 break;
784 }
785 }
786 }
787 if (recheck)
788 {
789 Trace("uf-ss-debug") << "Must recheck." << std::endl;
790 check(level);
791 }
792 }
793
794 void SortModel::presolve() {
795 d_initialized = false;
796 }
797
798 int SortModel::getNumDisequalitiesToRegion( Node n, int ri ){
799 int ni = d_regions_map[n];
800 int counter = 0;
801 DiseqList* del = d_regions[ni]->getRegionInfo(n)->get(0);
802 for( DiseqList::iterator it = del->begin(); it != del->end(); ++it ){
803 if( (*it).second ){
804 if( d_regions_map[ (*it).first ]==ri ){
805 counter++;
806 }
807 }
808 }
809 return counter;
810 }
811
812 void SortModel::getDisequalitiesToRegions(int ri,
813 std::map< int, int >& regions_diseq)
814 {
815 Region* region = d_regions[ri];
816 for(Region::iterator it = region->begin(); it != region->end(); ++it ){
817 if( it->second->valid() ){
818 DiseqList* del = it->second->get(0);
819 for( DiseqList::iterator it2 = del->begin(); it2 != del->end(); ++it2 ){
820 if( (*it2).second ){
821 Assert(isValid(d_regions_map[(*it2).first]));
822 //Notice() << "Found disequality with " << (*it2).first << ", region = " << d_regions_map[ (*it2).first ] << std::endl;
823 regions_diseq[ d_regions_map[ (*it2).first ] ]++;
824 }
825 }
826 }
827 }
828 }
829
830 void SortModel::setSplitScore( Node n, int s ){
831 if( d_split_score.find( n )!=d_split_score.end() ){
832 int ss = d_split_score[ n ];
833 d_split_score[ n ] = s>ss ? s : ss;
834 }else{
835 d_split_score[ n ] = s;
836 }
837 for( int i=0; i<(int)n.getNumChildren(); i++ ){
838 setSplitScore( n[i], s+1 );
839 }
840 }
841
842 void SortModel::assertCardinality(uint32_t c, bool val)
843 {
844 if (!d_state.isInConflict())
845 {
846 Trace("uf-ss-assert")
847 << "Assert cardinality " << d_type << " " << c << " " << val
848 << " level = "
849 << d_thss->getTheory()->getValuation().getAssertionLevel() << std::endl;
850 Assert(c > 0);
851 Node cl = getCardinalityLiteral( c );
852 if( val ){
853 bool doCheckRegions = !d_hasCard;
854 bool prevHasCard = d_hasCard;
855 d_hasCard = true;
856 if (!prevHasCard || c < d_cardinality)
857 {
858 d_cardinality = c;
859 simpleCheckCardinality();
860 if (d_state.isInConflict())
861 {
862 return;
863 }
864 }
865 //should check all regions now
866 if (doCheckRegions)
867 {
868 for (size_t i = 0; i < d_regions_index; i++)
869 {
870 if( d_regions[i]->valid() ){
871 checkRegion( i );
872 if (d_state.isInConflict())
873 {
874 return;
875 }
876 }
877 }
878 }
879 // we assert it positively, if its beyond the bound, abort
880 if (options::ufssAbortCardinality() >= 0
881 && c >= static_cast<uint32_t>(options::ufssAbortCardinality()))
882 {
883 std::stringstream ss;
884 ss << "Maximum cardinality (" << options::ufssAbortCardinality()
885 << ") for finite model finding exceeded." << std::endl;
886 throw LogicException(ss.str());
887 }
888 }
889 else
890 {
891 if (c > d_maxNegCard.get())
892 {
893 Trace("uf-ss-com-card-debug") << "Maximum negative cardinality for "
894 << d_type << " is now " << c << std::endl;
895 d_maxNegCard.set(c);
896 simpleCheckCardinality();
897 }
898 }
899 }
900 }
901
902 void SortModel::checkRegion( int ri, bool checkCombine ){
903 if( isValid(ri) && d_hasCard ){
904 Assert(d_cardinality > 0);
905 if( checkCombine && d_regions[ri]->getMustCombine( d_cardinality ) ){
906 int riNew = forceCombineRegion( ri, true );
907 if( riNew>=0 ){
908 checkRegion( riNew, checkCombine );
909 }
910 }
911 //now check if region is in conflict
912 std::vector< Node > clique;
913 if( d_regions[ri]->check( Theory::EFFORT_STANDARD, d_cardinality, clique ) ){
914 //explain clique
915 addCliqueLemma(clique);
916 }
917 }
918 }
919
920 int SortModel::forceCombineRegion( int ri, bool useDensity ){
921 if( !useDensity ){
922 for( int i=0; i<(int)d_regions_index; i++ ){
923 if( ri!=i && d_regions[i]->valid() ){
924 return combineRegions( ri, i );
925 }
926 }
927 return -1;
928 }else{
929 //this region must merge with another
930 if( Debug.isOn("uf-ss-check-region") ){
931 Debug("uf-ss-check-region") << "We must combine Region #" << ri << ". " << std::endl;
932 d_regions[ri]->debugPrint("uf-ss-check-region");
933 }
934 //take region with maximum disequality density
935 double maxScore = 0;
936 int maxRegion = -1;
937 std::map< int, int > regions_diseq;
938 getDisequalitiesToRegions( ri, regions_diseq );
939 for( std::map< int, int >::iterator it = regions_diseq.begin(); it != regions_diseq.end(); ++it ){
940 Debug("uf-ss-check-region") << it->first << " : " << it->second << std::endl;
941 }
942 for( std::map< int, int >::iterator it = regions_diseq.begin(); it != regions_diseq.end(); ++it ){
943 Assert(it->first != ri);
944 Assert(isValid(it->first));
945 Assert(d_regions[it->first]->getNumReps() > 0);
946 double tempScore = double(it->second)/double(d_regions[it->first]->getNumReps() );
947 if( tempScore>maxScore ){
948 maxRegion = it->first;
949 maxScore = tempScore;
950 }
951 }
952 if( maxRegion!=-1 ){
953 if( Debug.isOn("uf-ss-check-region") ){
954 Debug("uf-ss-check-region") << "Combine with region #" << maxRegion << ":" << std::endl;
955 d_regions[maxRegion]->debugPrint("uf-ss-check-region");
956 }
957 return combineRegions( ri, maxRegion );
958 }
959 return -1;
960 }
961 }
962
963
964 int SortModel::combineRegions( int ai, int bi ){
965 Debug("uf-ss-region") << "uf-ss: Combine Region #" << bi << " with Region #" << ai << std::endl;
966 Assert(isValid(ai) && isValid(bi));
967 Region* region_bi = d_regions[bi];
968 for(Region::iterator it = region_bi->begin(); it != region_bi->end(); ++it){
969 Region::RegionNodeInfo* rni = it->second;
970 if( rni->valid() ){
971 d_regions_map[ it->first ] = ai;
972 }
973 }
974 //update regions disequal DO_THIS?
975 d_regions[ai]->combine( d_regions[bi] );
976 d_regions[bi]->setValid( false );
977 return ai;
978 }
979
980 void SortModel::moveNode( Node n, int ri ){
981 Debug("uf-ss-region") << "uf-ss: Move node " << n << " to Region #" << ri << std::endl;
982 Assert(isValid(d_regions_map[n]));
983 Assert(isValid(ri));
984 //move node to region ri
985 d_regions[ri]->takeNode( d_regions[ d_regions_map[n] ], n );
986 d_regions_map[n] = ri;
987 }
988
989 int SortModel::addSplit(Region* r)
990 {
991 Node s;
992 if( r->hasSplits() ){
993 //take the first split you find
994 for( Region::split_iterator it = r->begin_splits();
995 it != r->end_splits(); ++it ){
996 if( (*it).second ){
997 s = (*it).first;
998 break;
999 }
1000 }
1001 Assert(s != Node::null());
1002 }
1003 if (!s.isNull() ){
1004 //add lemma to output channel
1005 Assert(s.getKind() == EQUAL);
1006 Node ss = Rewriter::rewrite( s );
1007 if( ss.getKind()!=EQUAL ){
1008 Node b_t = NodeManager::currentNM()->mkConst( true );
1009 Node b_f = NodeManager::currentNM()->mkConst( false );
1010 if( ss==b_f ){
1011 Trace("uf-ss-lemma") << "....Assert disequal directly : "
1012 << s[0] << " " << s[1] << std::endl;
1013 assertDisequal( s[0], s[1], b_t );
1014 return -1;
1015 }else{
1016 Trace("uf-ss-warn") << "Split on unknown literal : " << ss << std::endl;
1017 }
1018 if (ss == b_t)
1019 {
1020 CVC4Message() << "Bad split " << s << std::endl;
1021 AlwaysAssert(false);
1022 }
1023 }
1024 SortInference* si = d_thss->getSortInference();
1025 if (si != nullptr)
1026 {
1027 for( int i=0; i<2; i++ ){
1028 int sid = si->getSortId(ss[i]);
1029 Trace("uf-ss-split-si") << sid << " ";
1030 }
1031 Trace("uf-ss-split-si") << std::endl;
1032 }
1033 //Trace("uf-ss-lemma") << d_th->getEqualityEngine()->areEqual( s[0], s[1] ) << " ";
1034 //Trace("uf-ss-lemma") << d_th->getEqualityEngine()->areDisequal( s[0], s[1] ) << std::endl;
1035 //Trace("uf-ss-lemma") << s[0].getType() << " " << s[1].getType() << std::endl;
1036 //Notice() << "*** Split on " << s << std::endl;
1037 //split on the equality s
1038 Node lem = NodeManager::currentNM()->mkNode( kind::OR, ss, ss.negate() );
1039 // send lemma, with caching
1040 if (d_im.lemma(lem))
1041 {
1042 Trace("uf-ss-lemma") << "*** Split on " << s << std::endl;
1043 //tell the sat solver to explore the equals branch first
1044 d_im.requirePhase(ss, true);
1045 ++( d_thss->d_statistics.d_split_lemmas );
1046 }
1047 return 1;
1048 }else{
1049 return 0;
1050 }
1051 }
1052
1053 void SortModel::addCliqueLemma(std::vector<Node>& clique)
1054 {
1055 Assert(d_hasCard);
1056 Assert(d_cardinality > 0);
1057 while (clique.size() > d_cardinality + 1)
1058 {
1059 clique.pop_back();
1060 }
1061 // add as lemma
1062 std::vector<Node> eqs;
1063 for (unsigned i = 0, size = clique.size(); i < size; i++)
1064 {
1065 for (unsigned j = 0; j < i; j++)
1066 {
1067 eqs.push_back(clique[i].eqNode(clique[j]));
1068 }
1069 }
1070 eqs.push_back(d_cardinality_literal[d_cardinality].notNode());
1071 Node lem = NodeManager::currentNM()->mkNode(OR, eqs);
1072 // send lemma, with caching
1073 if (d_im.lemma(lem))
1074 {
1075 Trace("uf-ss-lemma") << "*** Add clique lemma " << lem << std::endl;
1076 ++(d_thss->d_statistics.d_clique_lemmas);
1077 }
1078 }
1079
1080 void SortModel::simpleCheckCardinality() {
1081 if( d_maxNegCard.get()!=0 && d_hasCard.get() && d_cardinality.get()<d_maxNegCard.get() ){
1082 Node lem = NodeManager::currentNM()->mkNode( AND, getCardinalityLiteral( d_cardinality.get() ),
1083 getCardinalityLiteral( d_maxNegCard.get() ).negate() );
1084 Trace("uf-ss-lemma") << "*** Simple cardinality conflict : " << lem << std::endl;
1085 d_im.conflict(lem);
1086 }
1087 }
1088
1089 void SortModel::debugPrint( const char* c ){
1090 if( Debug.isOn( c ) ){
1091 Debug( c ) << "Number of reps = " << d_reps << std::endl;
1092 Debug( c ) << "Cardinality req = " << d_cardinality << std::endl;
1093 unsigned debugReps = 0;
1094 for( unsigned i=0; i<d_regions_index; i++ ){
1095 Region* region = d_regions[i];
1096 if( region->valid() ){
1097 Debug( c ) << "Region #" << i << ": " << std::endl;
1098 region->debugPrint( c, true );
1099 Debug( c ) << std::endl;
1100 for( Region::iterator it = region->begin(); it != region->end(); ++it ){
1101 if( it->second->valid() ){
1102 if( d_regions_map[ it->first ]!=(int)i ){
1103 Debug( c ) << "***Bad regions map : " << it->first
1104 << " " << d_regions_map[ it->first ].get() << std::endl;
1105 }
1106 }
1107 }
1108 debugReps += region->getNumReps();
1109 }
1110 }
1111
1112 if( debugReps!=d_reps ){
1113 Debug( c ) << "***Bad reps: " << d_reps << ", "
1114 << "actual = " << debugReps << std::endl;
1115 }
1116 }
1117 }
1118
/**
 * Last-call check of the model for this sort.
 *
 * When tracing "uf-ss-warn", reports equivalence classes of type d_type in
 * the model's equality engine that have no entry in d_regions_map. It then
 * compares the number of model representatives of d_type with
 * d_maxNegCard + 1. If the model has at most d_maxNegCard representatives,
 * fresh skolems are allocated (and cached in d_fresh_aloc_reps):
 * - if d_maxNegCard is 0, the first fresh skolem is added directly as a
 *   representative of d_type in the model's rep set;
 * - otherwise, a lemma is sent stating that either the cardinality literal
 *   for d_maxNegCard holds, or d_maxNegCard + 1 fresh skolems are pairwise
 *   distinct.
 *
 * @return false if the enforcement lemma was sent (CardinalityExtension::check
 *         then stops visiting further sorts), true otherwise.
 */
bool SortModel::checkLastCall()
{
  TheoryModel* m = d_state.getModel();
  // Diagnostic only: look for equivalence classes of this sort that the
  // region structure never saw.
  if( Trace.isOn("uf-ss-warn") ){
    std::vector< Node > eqcs;
    eq::EqClassesIterator eqcs_i =
        eq::EqClassesIterator(m->getEqualityEngine());
    while( !eqcs_i.isFinished() ){
      Node eqc = (*eqcs_i);
      if( eqc.getType()==d_type ){
        if( std::find( eqcs.begin(), eqcs.end(), eqc )==eqcs.end() ){
          eqcs.push_back( eqc );
          //we must ensure that this equivalence class has been accounted for
          if( d_regions_map.find( eqc )==d_regions_map.end() ){
            Trace("uf-ss-warn") << "WARNING : equivalence class " << eqc << " unaccounted for." << std::endl;
            Trace("uf-ss-warn") << " type : " << d_type << std::endl;
            Trace("uf-ss-warn") << " kind : " << eqc.getKind() << std::endl;
          }
        }
      }
      ++eqcs_i;
    }
  }
  RepSet* rs = m->getRepSetPtr();
  size_t nReps = rs->getNumRepresentatives(d_type);
  if (nReps != d_maxNegCard + 1)
  {
    Trace("uf-ss-warn") << "WARNING : Model does not have same # "
                           "representatives as cardinality for "
                        << d_type << "." << std::endl;
    Trace("uf-ss-warn") << " Max neg cardinality : " << d_maxNegCard
                        << std::endl;
    Trace("uf-ss-warn") << " # Reps : " << nReps << std::endl;
    // Too few representatives to satisfy the negated cardinality constraint.
    if (d_maxNegCard >= nReps)
    {
      // Make sure at least d_maxNegCard + 1 fresh skolems exist (reused
      // across calls via d_fresh_aloc_reps).
      while (d_fresh_aloc_reps.size() <= d_maxNegCard)
      {
        std::stringstream ss;
        ss << "r_" << d_type << "_";
        Node nn = NodeManager::currentNM()->mkSkolem(
            ss.str(), d_type, "enumeration to meet negative card constraint");
        d_fresh_aloc_reps.push_back( nn );
      }
      if (d_maxNegCard == 0)
      {
        // Here nReps == 0: a single fresh representative is enough.
        rs->d_type_reps[d_type].push_back(d_fresh_aloc_reps[0]);
      }
      else
      {
        //must add lemma
        // force_cl collects the pairwise disequalities between the first
        // d_maxNegCard + 1 fresh skolems.
        std::vector< Node > force_cl;
        for (size_t i = 0; i <= d_maxNegCard; i++)
        {
          for (size_t j = (i + 1); j <= d_maxNegCard; j++)
          {
            force_cl.push_back(
                d_fresh_aloc_reps[i].eqNode(d_fresh_aloc_reps[j]).negate());
          }
        }
        Node cl = getCardinalityLiteral( d_maxNegCard );
        Node lem = NodeManager::currentNM()->mkNode(
            OR, cl, NodeManager::currentNM()->mkAnd(force_cl));
        Trace("uf-ss-lemma") << "*** Enforce negative cardinality constraint lemma : " << lem << std::endl;
        d_im.lemma(lem, LemmaProperty::NONE, false);
        return false;
      }
    }
  }
  return true;
}
1189
1190 int SortModel::getNumRegions(){
1191 int count = 0;
1192 for( int i=0; i<(int)d_regions_index; i++ ){
1193 if( d_regions[i]->valid() ){
1194 count++;
1195 }
1196 }
1197 return count;
1198 }
1199
1200 Node SortModel::getCardinalityLiteral(uint32_t c)
1201 {
1202 Assert(c > 0);
1203 std::map<uint32_t, Node>::iterator itcl = d_cardinality_literal.find(c);
1204 if (itcl != d_cardinality_literal.end())
1205 {
1206 return itcl->second;
1207 }
1208 // get the literal from the decision strategy
1209 Node lit = d_c_dec_strat->getLiteral(c - 1);
1210 d_cardinality_literal[c] = lit;
1211
1212 // return the literal
1213 return lit;
1214 }
1215
1216 CardinalityExtension::CardinalityExtension(TheoryState& state,
1217 TheoryInferenceManager& im,
1218 TheoryUF* th)
1219 : d_state(state),
1220 d_im(im),
1221 d_th(th),
1222 d_rep_model(),
1223 d_min_pos_com_card(state.getSatContext(), 0),
1224 d_min_pos_com_card_set(state.getSatContext(), false),
1225 d_cc_dec_strat(nullptr),
1226 d_initializedCombinedCardinality(state.getUserContext(), false),
1227 d_card_assertions_eqv_lemma(state.getUserContext()),
1228 d_min_pos_tn_master_card(state.getSatContext(), 0),
1229 d_min_pos_tn_master_card_set(state.getSatContext(), false),
1230 d_rel_eqc(state.getSatContext())
1231 {
1232 if (options::ufssMode() == options::UfssMode::FULL && options::ufssFairness())
1233 {
1234 // Register the strategy with the decision manager of the theory.
1235 // We are guaranteed that the decision manager is ready since we
1236 // construct this module during TheoryUF::finishInit.
1237 d_cc_dec_strat.reset(new CombinedCardinalityDecisionStrategy(
1238 state.getSatContext(), th->getValuation()));
1239 }
1240 }
1241
1242 CardinalityExtension::~CardinalityExtension()
1243 {
1244 for (std::map<TypeNode, SortModel*>::iterator it = d_rep_model.begin();
1245 it != d_rep_model.end(); ++it) {
1246 delete it->second;
1247 }
1248 }
1249
1250 SortInference* CardinalityExtension::getSortInference()
1251 {
1252 if (!options::sortInference())
1253 {
1254 return nullptr;
1255 }
1256 QuantifiersEngine* qe = d_th->getQuantifiersEngine();
1257 if (qe != nullptr)
1258 {
1259 return qe->getTheoryEngine()->getSortInference();
1260 }
1261 return nullptr;
1262 }
1263
1264 /** ensure eqc */
1265 void CardinalityExtension::ensureEqc(SortModel* c, Node a)
1266 {
1267 if( !hasEqc( a ) ){
1268 d_rel_eqc[a] = true;
1269 Trace("uf-ss-solver") << "CardinalityExtension: New eq class " << a << " : "
1270 << a.getType() << std::endl;
1271 c->newEqClass( a );
1272 Trace("uf-ss-solver") << "CardinalityExtension: Done New eq class."
1273 << std::endl;
1274 }
1275 }
1276
1277 void CardinalityExtension::ensureEqcRec(Node n)
1278 {
1279 if( !hasEqc( n ) ){
1280 SortModel* c = getSortModel( n );
1281 if( c ){
1282 ensureEqc( c, n );
1283 }
1284 for( unsigned i=0; i<n.getNumChildren(); i++ ){
1285 ensureEqcRec( n[i] );
1286 }
1287 }
1288 }
1289
1290 /** has eqc */
1291 bool CardinalityExtension::hasEqc(Node a)
1292 {
1293 NodeBoolMap::iterator it = d_rel_eqc.find( a );
1294 return it!=d_rel_eqc.end() && (*it).second;
1295 }
1296
1297 /** new node */
1298 void CardinalityExtension::newEqClass(Node a)
1299 {
1300 SortModel* c = getSortModel( a );
1301 if( c ){
1302 Trace("uf-ss-solver") << "CardinalityExtension: New eq class " << a << " : "
1303 << a.getType() << std::endl;
1304 c->newEqClass( a );
1305 Trace("uf-ss-solver") << "CardinalityExtension: Done New eq class."
1306 << std::endl;
1307 }
1308 }
1309
1310 /** merge */
1311 void CardinalityExtension::merge(Node a, Node b)
1312 {
1313 //TODO: ensure they are relevant
1314 SortModel* c = getSortModel( a );
1315 if( c ){
1316 Trace("uf-ss-solver") << "CardinalityExtension: Merge " << a << " " << b
1317 << " : " << a.getType() << std::endl;
1318 c->merge( a, b );
1319 Trace("uf-ss-solver") << "CardinalityExtension: Done Merge." << std::endl;
1320 }
1321 }
1322
1323 /** assert terms are disequal */
1324 void CardinalityExtension::assertDisequal(Node a, Node b, Node reason)
1325 {
1326 SortModel* c = getSortModel( a );
1327 if( c ){
1328 Trace("uf-ss-solver") << "CardinalityExtension: Assert disequal " << a
1329 << " " << b << " : " << a.getType() << std::endl;
1330 c->assertDisequal( a, b, reason );
1331 Trace("uf-ss-solver") << "CardinalityExtension: Done Assert disequal."
1332 << std::endl;
1333 }
1334 }
1335
/**
 * Assert literal n (possibly negated) to the cardinality extension.
 *
 * In FULL mode:
 * - A CARDINALITY_CONSTRAINT literal on the sort model's cardinality term is
 *   passed to SortModel::assertCardinality, after optional master/slave
 *   monotonicity bookkeeping (ufssFairnessMonotone), and the combined
 *   cardinality is re-checked. A CARDINALITY_CONSTRAINT on any other term is
 *   made equivalent, via a lemma sent once per literal, to the same
 *   constraint on the canonical cardinality term.
 * - A positive COMBINED_CARDINALITY_CONSTRAINT lowers the minimal positive
 *   combined cardinality (if it improves on it) and re-checks the combined
 *   cardinality.
 * - Any other literal is only inspected for tracing.
 * In other modes, cardinality literals are not handled and the inference
 * manager is marked incomplete.
 *
 * @param n the asserted literal
 * @param isDecision whether n was asserted as a decision (used only for
 *        tracing)
 */
void CardinalityExtension::assertNode(Node n, bool isDecision)
{
  Trace("uf-ss") << "Assert " << n << " " << isDecision << std::endl;
  bool polarity = n.getKind() != kind::NOT;
  TNode lit = polarity ? n : n[0];
  if (options::ufssMode() == options::UfssMode::FULL)
  {
    if( lit.getKind()==CARDINALITY_CONSTRAINT ){
      TypeNode tn = lit[0].getType();
      Assert(tn.isSort());
      Assert(d_rep_model[tn]);
      uint32_t nCard =
          lit[1].getConst<Rational>().getNumerator().getUnsignedInt();
      Node ct = d_rep_model[tn]->getCardinalityTerm();
      Trace("uf-ss-debug") << "...check cardinality terms : " << lit[0] << " " << ct << std::endl;
      if( lit[0]==ct ){
        if( options::ufssFairnessMonotone() ){
          Trace("uf-ss-com-card-debug") << "...set master/slave" << std::endl;
          // Classify tn the first time it is seen: the first monotonic sort
          // becomes the master, later monotonic sorts become slaves, and
          // non-monotonic sorts are recorded with false.
          if( tn!=d_tn_mono_master ){
            std::map< TypeNode, bool >::iterator it = d_tn_mono_slave.find( tn );
            if( it==d_tn_mono_slave.end() ){
              bool isMonotonic;
              SortInference* si = getSortInference();
              if (si != nullptr)
              {
                isMonotonic = si->isMonotonic(tn);
              }else{
                //if ground, everything is monotonic
                isMonotonic = true;
              }
              if( isMonotonic ){
                if( d_tn_mono_master.isNull() ){
                  Trace("uf-ss-com-card-debug") << "uf-ss-fair-monotone: Set master : " << tn << std::endl;
                  d_tn_mono_master = tn;
                }else{
                  Trace("uf-ss-com-card-debug") << "uf-ss-fair-monotone: Set slave : " << tn << std::endl;
                  d_tn_mono_slave[tn] = true;
                }
              }else{
                Trace("uf-ss-com-card-debug") << "uf-ss-fair-monotone: Set non-monotonic : " << tn << std::endl;
                d_tn_mono_slave[tn] = false;
              }
            }
          }
          //set the minimum positive cardinality for master if necessary
          if( polarity && tn==d_tn_mono_master ){
            Trace("uf-ss-com-card-debug") << "...set min positive cardinality" << std::endl;
            if (!d_min_pos_tn_master_card_set.get()
                || nCard < d_min_pos_tn_master_card.get())
            {
              d_min_pos_tn_master_card_set.set(true);
              d_min_pos_tn_master_card.set( nCard );
            }
          }
        }
        Trace("uf-ss-com-card-debug") << "...assert cardinality" << std::endl;
        d_rep_model[tn]->assertCardinality(nCard, polarity);
        //check if combined cardinality is violated
        checkCombinedCardinality();
      }else{
        //otherwise, make equal via lemma
        // (lit <=> CARDINALITY_CONSTRAINT(ct, lit[1])), sent at most once
        // per literal thanks to d_card_assertions_eqv_lemma
        if( d_card_assertions_eqv_lemma.find( lit )==d_card_assertions_eqv_lemma.end() ){
          Node eqv_lit = NodeManager::currentNM()->mkNode( CARDINALITY_CONSTRAINT, ct, lit[1] );
          eqv_lit = lit.eqNode( eqv_lit );
          Trace("uf-ss-lemma") << "*** Cardinality equiv lemma : " << eqv_lit << std::endl;
          d_im.lemma(eqv_lit, LemmaProperty::NONE, false);
          d_card_assertions_eqv_lemma[lit] = true;
        }
      }
    }else if( lit.getKind()==COMBINED_CARDINALITY_CONSTRAINT ){
      // Only positive combined constraints are tracked.
      if( polarity ){
        //safe to assume int here
        uint32_t nCard =
            lit[0].getConst<Rational>().getNumerator().getUnsignedInt();
        if (!d_min_pos_com_card_set.get() || nCard < d_min_pos_com_card.get())
        {
          d_min_pos_com_card_set.set(true);
          d_min_pos_com_card.set( nCard );
          checkCombinedCardinality();
        }
      }
    }else{
      // Diagnostic only: warn about decisions made before a cardinality was
      // asserted for some sort.
      if( Trace.isOn("uf-ss-warn") ){
        ////FIXME: this is too strict: theory propagations are showing up as isDecision=true, but
        //// a theory propagation is not a decision.
        if( isDecision ){
          for( std::map< TypeNode, SortModel* >::iterator it = d_rep_model.begin(); it != d_rep_model.end(); ++it ){
            if( !it->second->hasCardinalityAsserted() ){
              Trace("uf-ss-warn") << "WARNING: Assert " << n << " as a decision before cardinality for " << it->first << "." << std::endl;
              // CVC4Message() << "Error: constraint asserted before cardinality
              // for " << it->first << std::endl; Unimplemented();
            }
          }
        }
      }
    }
  }
  else
  {
    if( lit.getKind()==CARDINALITY_CONSTRAINT || lit.getKind()==COMBINED_CARDINALITY_CONSTRAINT ){
      // cardinality constraint from user input, set incomplete
      Trace("uf-ss") << "Literal " << lit << " not handled when uf ss mode is not FULL, set incomplete." << std::endl;
      d_im.setIncomplete();
    }
  }
  Trace("uf-ss") << "Assert: done " << n << " " << isDecision << std::endl;
}
1444
1445 bool CardinalityExtension::areDisequal(Node a, Node b)
1446 {
1447 if( a==b ){
1448 return false;
1449 }
1450 eq::EqualityEngine* ee = d_th->getEqualityEngine();
1451 a = ee->getRepresentative(a);
1452 b = ee->getRepresentative(b);
1453 if (ee->areDisequal(a, b, false))
1454 {
1455 return true;
1456 }
1457 SortModel* c = getSortModel(a);
1458 if (c)
1459 {
1460 return c->areDisequal(a, b);
1461 }
1462 return false;
1463 }
1464
/**
 * Main check of the cardinality extension at the given effort level.
 *
 * At last-call effort, SortModel::checkLastCall is run for each sort,
 * stopping at the first one that returns false (i.e. that sent a lemma).
 * Otherwise, and only when not already in conflict:
 * - in FULL mode, state is optionally dumped on debug/trace channels at full
 *   effort, then each sort model's check runs, stopping on conflict;
 * - in NO_MINIMAL mode, at full effort, for each uninterpreted sort at most
 *   one split lemma (eq OR NOT eq) is sent on an equality between two
 *   equivalence classes not known to be disequal;
 * - any other mode is unexpected (Assert).
 *
 * @param level the effort level of this check
 */
void CardinalityExtension::check(Theory::Effort level)
{
  if (level == Theory::EFFORT_LAST_CALL)
  {
    // if last call, call last call check for each sort
    for (std::pair<const TypeNode, SortModel*>& r : d_rep_model)
    {
      if (!r.second->checkLastCall())
      {
        break;
      }
    }
    return;
  }
  if (!d_state.isInConflict())
  {
    if (options::ufssMode() == options::UfssMode::FULL)
    {
      Trace("uf-ss-solver")
          << "CardinalityExtension: check " << level << std::endl;
      if (level == Theory::EFFORT_FULL)
      {
        if (Debug.isOn("uf-ss-debug"))
        {
          debugPrint("uf-ss-debug");
        }
        if (Trace.isOn("uf-ss-state"))
        {
          Trace("uf-ss-state")
              << "CardinalityExtension::check " << level << std::endl;
          for (std::pair<const TypeNode, SortModel*>& rm : d_rep_model)
          {
            Trace("uf-ss-state") << " " << rm.first << " has cardinality "
                                 << rm.second->getCardinality() << std::endl;
          }
        }
      }
      // Run each sort model's check; stop as soon as one causes a conflict.
      for( std::map< TypeNode, SortModel* >::iterator it = d_rep_model.begin(); it != d_rep_model.end(); ++it ){
        it->second->check(level);
        if (d_state.isInConflict())
        {
          break;
        }
      }
    }
    else if (options::ufssMode() == options::UfssMode::NO_MINIMAL)
    {
      if( level==Theory::EFFORT_FULL ){
        // split on an equality between two equivalence classes (at most one per type)
        // eqc_list: equivalence classes seen so far per sort;
        // type_proc: sorts for which a split has already been sent.
        std::map< TypeNode, std::vector< Node > > eqc_list;
        std::map< TypeNode, bool > type_proc;
        eq::EqClassesIterator eqcs_i(d_th->getEqualityEngine());
        while( !eqcs_i.isFinished() ){
          Node a = *eqcs_i;
          TypeNode tn = a.getType();
          if( tn.isSort() ){
            if( type_proc.find( tn )==type_proc.end() ){
              std::map< TypeNode, std::vector< Node > >::iterator itel = eqc_list.find( tn );
              if( itel!=eqc_list.end() ){
                for( unsigned j=0; j<itel->second.size(); j++ ){
                  Node b = itel->second[j];
                  if( !d_th->getEqualityEngine()->areDisequal( a, b, false ) ){
                    Node eq = Rewriter::rewrite( a.eqNode( b ) );
                    Node lem = NodeManager::currentNM()->mkNode( kind::OR, eq, eq.negate() );
                    Trace("uf-ss-lemma") << "*** Split (no-minimal) : " << lem << std::endl;
                    d_im.lemma(lem, LemmaProperty::NONE, false);
                    d_im.requirePhase(eq, true);
                    type_proc[tn] = true;
                    break;
                  }
                }
              }
              eqc_list[tn].push_back( a );
            }
          }
          ++eqcs_i;
        }
      }
    }
    else
    {
      // unhandled uf ss mode
      Assert(false);
    }
    Trace("uf-ss-solver") << "Done CardinalityExtension: check " << level
                          << std::endl;
  }
}
1554
1555 void CardinalityExtension::presolve()
1556 {
1557 d_initializedCombinedCardinality = false;
1558 for( std::map< TypeNode, SortModel* >::iterator it = d_rep_model.begin(); it != d_rep_model.end(); ++it ){
1559 it->second->presolve();
1560 it->second->initialize();
1561 }
1562 }
1563
/**
 * Construct the combined cardinality decision strategy; all state is held by
 * the DecisionStrategyFmf base, built from the SAT context and valuation.
 */
CardinalityExtension::CombinedCardinalityDecisionStrategy::
    CombinedCardinalityDecisionStrategy(context::Context* satContext,
                                        Valuation valuation)
    : DecisionStrategyFmf(satContext, valuation)
{
}
1570 Node CardinalityExtension::CombinedCardinalityDecisionStrategy::mkLiteral(
1571 unsigned i)
1572 {
1573 NodeManager* nm = NodeManager::currentNM();
1574 return nm->mkNode(COMBINED_CARDINALITY_CONSTRAINT, nm->mkConst(Rational(i)));
1575 }
1576
1577 std::string
1578 CardinalityExtension::CombinedCardinalityDecisionStrategy::identify() const
1579 {
1580 return std::string("uf_combined_card");
1581 }
1582
1583 void CardinalityExtension::preRegisterTerm(TNode n)
1584 {
1585 if (options::ufssMode() == options::UfssMode::FULL)
1586 {
1587 //initialize combined cardinality
1588 initializeCombinedCardinality();
1589
1590 Trace("uf-ss-register") << "Preregister " << n << "." << std::endl;
1591 //shouldn't have to preregister this type (it may be that there are no quantifiers over tn)
1592 TypeNode tn = n.getType();
1593 std::map< TypeNode, SortModel* >::iterator it = d_rep_model.find( tn );
1594 if( it==d_rep_model.end() ){
1595 SortModel* rm = NULL;
1596 if( tn.isSort() ){
1597 Trace("uf-ss-register") << "Create sort model " << tn << "." << std::endl;
1598 rm = new SortModel(n, d_state, d_im, this);
1599 }
1600 if( rm ){
1601 rm->initialize();
1602 d_rep_model[tn] = rm;
1603 //d_rep_model_init[tn] = true;
1604 }
1605 }else{
1606 //ensure sort model is initialized
1607 it->second->initialize();
1608 }
1609 }
1610 }
1611
1612 SortModel* CardinalityExtension::getSortModel(Node n)
1613 {
1614 TypeNode tn = n.getType();
1615 std::map< TypeNode, SortModel* >::iterator it = d_rep_model.find( tn );
1616 //pre-register the type if not done already
1617 if( it==d_rep_model.end() ){
1618 preRegisterTerm( n );
1619 it = d_rep_model.find( tn );
1620 }
1621 if( it!=d_rep_model.end() ){
1622 return it->second;
1623 }else{
1624 return NULL;
1625 }
1626 }
1627
1628 /** get cardinality for sort */
1629 int CardinalityExtension::getCardinality(Node n)
1630 {
1631 SortModel* c = getSortModel( n );
1632 if( c ){
1633 return c->getCardinality();
1634 }else{
1635 return -1;
1636 }
1637 }
1638
1639 int CardinalityExtension::getCardinality(TypeNode tn)
1640 {
1641 std::map< TypeNode, SortModel* >::iterator it = d_rep_model.find( tn );
1642 if( it!=d_rep_model.end() && it->second ){
1643 return it->second->getCardinality();
1644 }
1645 return -1;
1646 }
1647
1648 //print debug
1649 void CardinalityExtension::debugPrint(const char* c)
1650 {
1651 for( std::map< TypeNode, SortModel* >::iterator it = d_rep_model.begin(); it != d_rep_model.end(); ++it ){
1652 Debug( c ) << "Conflict find structure for " << it->first << ": " << std::endl;
1653 it->second->debugPrint( c );
1654 Debug( c ) << std::endl;
1655 }
1656 }
1657
1658 /** initialize */
1659 void CardinalityExtension::initializeCombinedCardinality()
1660 {
1661 if (d_cc_dec_strat.get() != nullptr
1662 && !d_initializedCombinedCardinality.get())
1663 {
1664 d_initializedCombinedCardinality = true;
1665 d_th->getDecisionManager()->registerStrategy(
1666 DecisionManager::STRAT_UF_COMBINED_CARD, d_cc_dec_strat.get());
1667 }
1668 }
1669
1670 /** check */
1671 void CardinalityExtension::checkCombinedCardinality()
1672 {
1673 Assert(options::ufssMode() == options::UfssMode::FULL);
1674 if( options::ufssFairness() ){
1675 Trace("uf-ss-com-card-debug") << "Check combined cardinality, get maximum negative cardinalities..." << std::endl;
1676 uint32_t totalCombinedCard = 0;
1677 uint32_t maxMonoSlave = 0;
1678 TypeNode maxSlaveType;
1679 for( std::map< TypeNode, SortModel* >::iterator it = d_rep_model.begin(); it != d_rep_model.end(); ++it ){
1680 uint32_t max_neg = it->second->getMaximumNegativeCardinality();
1681 if( !options::ufssFairnessMonotone() ){
1682 totalCombinedCard += max_neg;
1683 }else{
1684 std::map< TypeNode, bool >::iterator its = d_tn_mono_slave.find( it->first );
1685 if( its==d_tn_mono_slave.end() || !its->second ){
1686 totalCombinedCard += max_neg;
1687 }else{
1688 if( max_neg>maxMonoSlave ){
1689 maxMonoSlave = max_neg;
1690 maxSlaveType = it->first;
1691 }
1692 }
1693 }
1694 }
1695 Trace("uf-ss-com-card-debug") << "Check combined cardinality, total combined card : " << totalCombinedCard << std::endl;
1696 if( options::ufssFairnessMonotone() ){
1697 Trace("uf-ss-com-card-debug") << "Max slave monotonic negated cardinality : " << maxMonoSlave << std::endl;
1698 if (!d_min_pos_tn_master_card_set.get()
1699 && maxMonoSlave > d_min_pos_tn_master_card.get())
1700 {
1701 uint32_t mc = d_min_pos_tn_master_card.get();
1702 std::vector< Node > conf;
1703 conf.push_back( d_rep_model[d_tn_mono_master]->getCardinalityLiteral( mc ) );
1704 conf.push_back( d_rep_model[maxSlaveType]->getCardinalityLiteral( maxMonoSlave ).negate() );
1705 Node cf = NodeManager::currentNM()->mkNode( AND, conf );
1706 Trace("uf-ss-lemma") << "*** Combined monotone cardinality conflict"
1707 << " : " << cf << std::endl;
1708 Trace("uf-ss-com-card") << "*** Combined monotone cardinality conflict"
1709 << " : " << cf << std::endl;
1710 d_im.conflict(cf);
1711 return;
1712 }
1713 }
1714 uint32_t cc = d_min_pos_com_card.get();
1715 if (d_min_pos_com_card_set.get() && totalCombinedCard > cc)
1716 {
1717 //conflict
1718 Node com_lit = d_cc_dec_strat->getLiteral(cc);
1719 std::vector< Node > conf;
1720 conf.push_back( com_lit );
1721 uint32_t totalAdded = 0;
1722 for( std::map< TypeNode, SortModel* >::iterator it = d_rep_model.begin();
1723 it != d_rep_model.end(); ++it ){
1724 bool doAdd = true;
1725 if( options::ufssFairnessMonotone() ){
1726 std::map< TypeNode, bool >::iterator its =
1727 d_tn_mono_slave.find( it->first );
1728 if( its!=d_tn_mono_slave.end() && its->second ){
1729 doAdd = false;
1730 }
1731 }
1732 if( doAdd ){
1733 uint32_t c = it->second->getMaximumNegativeCardinality();
1734 if( c>0 ){
1735 conf.push_back( it->second->getCardinalityLiteral( c ).negate() );
1736 totalAdded += c;
1737 }
1738 if( totalAdded>cc ){
1739 break;
1740 }
1741 }
1742 }
1743 Node cf = NodeManager::currentNM()->mkNode( AND, conf );
1744 Trace("uf-ss-lemma") << "*** Combined cardinality conflict : " << cf
1745 << std::endl;
1746 Trace("uf-ss-com-card") << "*** Combined cardinality conflict : " << cf
1747 << std::endl;
1748 d_im.conflict(cf);
1749 }
1750 }
1751 }
1752
/**
 * Construct the cardinality extension statistics and register each of them
 * with the SMT statistics registry (they are unregistered in the
 * destructor).
 */
CardinalityExtension::Statistics::Statistics()
    : d_clique_conflicts("CardinalityExtension::Clique_Conflicts", 0),
      d_clique_lemmas("CardinalityExtension::Clique_Lemmas", 0),
      d_split_lemmas("CardinalityExtension::Split_Lemmas", 0),
      // initial value 1 (the other counters start at 0)
      d_max_model_size("CardinalityExtension::Max_Model_Size", 1)
{
  smtStatisticsRegistry()->registerStat(&d_clique_conflicts);
  smtStatisticsRegistry()->registerStat(&d_clique_lemmas);
  smtStatisticsRegistry()->registerStat(&d_split_lemmas);
  smtStatisticsRegistry()->registerStat(&d_max_model_size);
}
1764
/**
 * Unregister every statistic registered by the constructor from the SMT
 * statistics registry.
 */
CardinalityExtension::Statistics::~Statistics()
{
  smtStatisticsRegistry()->unregisterStat(&d_clique_conflicts);
  smtStatisticsRegistry()->unregisterStat(&d_clique_lemmas);
  smtStatisticsRegistry()->unregisterStat(&d_split_lemmas);
  smtStatisticsRegistry()->unregisterStat(&d_max_model_size);
}
1772
1773 }/* CVC4::theory namespace::uf */
1774 }/* CVC4::theory namespace */
1775 }/* CVC4 namespace */