Adding an option to the equality engine constructor to treat all constants as
[cvc5.git] / src / theory / sets / theory_sets_private.cpp
index 57b761500e87189b258cfac907d5d61699c3a44a..1892ecceb4626e1ad43a100eec9581406338ab22 100644 (file)
@@ -42,8 +42,6 @@ const char* element_of_str = " \u2208 ";
 
 void TheorySetsPrivate::check(Theory::Effort level) {
 
-  CodeTimer checkCodeTimer(d_statistics.d_checkTime);
-
   while(!d_external.done() && !d_conflict) {
     // Get all the assertions
     Assertion assertion = d_external.get();
@@ -734,7 +732,7 @@ Node TheorySetsPrivate::elementsToShape(Elements elements, TypeNode setType) con
   NodeManager* nm = NodeManager::currentNM();
 
   if(elements.size() == 0) {
-    return nm->mkConst(EmptySet(nm->toType(setType)));
+    return nm->mkConst<EmptySet>(EmptySet(nm->toType(setType)));
   } else {
     Elements::iterator it = elements.begin();
     Node cur = SINGLETON(*it);
@@ -749,7 +747,7 @@ Node TheorySetsPrivate::elementsToShape(set<Node> elements, TypeNode setType) co
   NodeManager* nm = NodeManager::currentNM();
 
   if(elements.size() == 0) {
-    return nm->mkConst(EmptySet(nm->toType(setType)));
+    return nm->mkConst<EmptySet>(EmptySet(nm->toType(setType)));
   } else {
     typeof(elements.begin()) it = elements.begin();
     Node cur = SINGLETON(*it);
@@ -767,18 +765,20 @@ void TheorySetsPrivate::collectModelInfo(TheoryModel* m, bool fullModel)
 
   set<Node> terms;
 
+  if(Trace.isOn("sets-assertions")) {
+    dumpAssertionsHumanified();
+  }
+
   // Compute terms appearing assertions and shared terms
   d_external.computeRelevantTerms(terms);
 
   // Compute for each setterm elements that it contains
   SettermElementsMap settermElementsMap;
-  TNode true_atom = NodeManager::currentNM()->mkConst<bool>(true);
-  TNode false_atom = NodeManager::currentNM()->mkConst<bool>(false);
-  for(eq::EqClassIterator it_eqclasses(true_atom, &d_equalityEngine);
+  for(eq::EqClassIterator it_eqclasses(d_trueNode, &d_equalityEngine);
       ! it_eqclasses.isFinished() ; ++it_eqclasses) {
     TNode n = (*it_eqclasses);
     if(n.getKind() == kind::MEMBER) {
-      Assert(d_equalityEngine.areEqual(n, true_atom));
+      Assert(d_equalityEngine.areEqual(n, d_trueNode));
       TNode x = d_equalityEngine.getRepresentative(n[0]);
       TNode S = d_equalityEngine.getRepresentative(n[1]);
       settermElementsMap[S].insert(x);
@@ -795,7 +795,7 @@ void TheorySetsPrivate::collectModelInfo(TheoryModel* m, bool fullModel)
   }
 
   if(Debug.isOn("sets-model-details")) {
-    for(eq::EqClassIterator it_eqclasses(false_atom, &d_equalityEngine);
+    for(eq::EqClassIterator it_eqclasses(d_trueNode, &d_equalityEngine);
         ! it_eqclasses.isFinished() ; ++it_eqclasses) {
       TNode n = (*it_eqclasses);
       vector<TNode> explanation;
@@ -856,8 +856,8 @@ void TheorySetsPrivate::collectModelInfo(TheoryModel* m, bool fullModel)
       checkPassed &= checkModel(settermElementsMap, term);
     }
   }
-  if(Debug.isOn("sets-checkmodel-ignore")) {
-    Debug("sets-checkmodel-ignore") << "[sets-checkmodel-ignore] checkPassed value was " << checkPassed << std::endl;
+  if(Trace.isOn("sets-checkmodel-ignore")) {
+    Trace("sets-checkmodel-ignore") << "[sets-checkmodel-ignore] checkPassed value was " << checkPassed << std::endl;
   } else {
     Assert( checkPassed,
             "THEORY_SETS check-model failed. Run with -d sets-model for details." );
@@ -916,12 +916,10 @@ Node mkAnd(const std::vector<TNode>& conjunctions) {
 
 
 TheorySetsPrivate::Statistics::Statistics() :
-  d_checkTime("theory::sets::time")
-  , d_getModelValueTime("theory::sets::getModelValueTime")
+    d_getModelValueTime("theory::sets::getModelValueTime")
   , d_memberLemmas("theory::sets::lemmas::member", 0)
   , d_disequalityLemmas("theory::sets::lemmas::disequality", 0)
 {
-  StatisticsRegistry::registerStat(&d_checkTime);
   StatisticsRegistry::registerStat(&d_getModelValueTime);
   StatisticsRegistry::registerStat(&d_memberLemmas);
   StatisticsRegistry::registerStat(&d_disequalityLemmas);
@@ -929,7 +927,6 @@ TheorySetsPrivate::Statistics::Statistics() :
 
 
 TheorySetsPrivate::Statistics::~Statistics() {
-  StatisticsRegistry::unregisterStat(&d_checkTime);
   StatisticsRegistry::unregisterStat(&d_getModelValueTime);
   StatisticsRegistry::unregisterStat(&d_memberLemmas);
   StatisticsRegistry::unregisterStat(&d_disequalityLemmas);
@@ -942,7 +939,7 @@ bool TheorySetsPrivate::present(TNode atom) {
 
 
 bool TheorySetsPrivate::holds(TNode atom, bool polarity) {
-  Node polarity_atom = NodeManager::currentNM()->mkConst<bool>(polarity);
+  TNode polarity_atom = polarity ? d_trueNode : d_falseNode;
 
   Node atomModEq = NodeManager::currentNM()->mkNode
     (atom.getKind(), d_equalityEngine.getRepresentative(atom[0]),
@@ -998,21 +995,44 @@ void TheorySetsPrivate::finishPropagation()
 
 void TheorySetsPrivate::addToPending(Node n) {
   Debug("sets-pending") << "[sets-pending] addToPending " << n << std::endl;
-  if(d_pendingEverInserted.find(n) == d_pendingEverInserted.end()) {
-    if(n.getKind() == kind::MEMBER) {
-      Debug("sets-pending") << "[sets-pending] \u2514 added to member queue"
-                            << std::endl;
-      ++d_statistics.d_memberLemmas;
-      d_pending.push(n);
-    } else {
-      Debug("sets-pending") << "[sets-pending] \u2514 added to equality queue"
-                            << std::endl;
-      Assert(n.getKind() == kind::EQUAL);
-      ++d_statistics.d_disequalityLemmas;
-      d_pendingDisequal.push(n);
+
+  if(d_pendingEverInserted.find(n) != d_pendingEverInserted.end()) {
+    Debug("sets-pending") << "[sets-pending] \u2514 skipping " << n
+                         << " as lemma already generated." << std::endl;
+    return;
+  }
+
+  if(n.getKind() == kind::MEMBER) {
+
+    Node nRewritten = theory::Rewriter::rewrite(n);
+
+    if(nRewritten.isConst()) {
+      Debug("sets-pending") << "[sets-pending] \u2514 skipping " << n
+                           << " as we can learn one of the sides." << std::endl;
+      Assert(nRewritten == d_trueNode || nRewritten == d_falseNode);
+
+      bool polarity = (nRewritten == d_trueNode);
+      learnLiteral(n, polarity, d_trueNode);
+      return;
     }
-    d_external.d_out->lemma(getLemma());
+
+    Debug("sets-pending") << "[sets-pending] \u2514 added to member queue"
+                         << std::endl;
+    ++d_statistics.d_memberLemmas;
+    d_pending.push(n);
+    d_external.d_out->splitLemma(getLemma());
     Assert(isComplete());
+
+  } else {
+
+    Debug("sets-pending") << "[sets-pending] \u2514 added to equality queue"
+                         << std::endl;
+    Assert(n.getKind() == kind::EQUAL);
+    ++d_statistics.d_disequalityLemmas;
+    d_pendingDisequal.push(n);
+    d_external.d_out->splitLemma(getLemma());
+    Assert(isComplete());
+
   }
 }
 
@@ -1047,13 +1067,15 @@ Node TheorySetsPrivate::getLemma() {
     d_pendingEverInserted.insert(n);
 
     Assert(n.getKind() == kind::EQUAL && n[0].getType().isSet());
-    Node x = NodeManager::currentNM()->mkSkolem("sde_", n[0].getType().getSetElementType() );
+    TypeNode elementType = n[0].getType().getSetElementType();
+    Node x = NodeManager::currentNM()->mkSkolem("sde_",  elementType);
     Node l1 = MEMBER(x, n[0]), l2 = MEMBER(x, n[1]);
 
     lemma = OR(n, AND(l1, NOT(l2)), AND(NOT(l1), l2));
   }
 
-  Debug("sets-lemma") << "[sets-lemma] Generating for " << n << ", lemma: " << lemma << std::endl;
+  Debug("sets-lemma") << "[sets-lemma] Generating for " << n
+                      << ", lemma: " << lemma << std::endl;
 
   return  lemma;
 }
@@ -1064,7 +1086,9 @@ TheorySetsPrivate::TheorySetsPrivate(TheorySets& external,
                                      context::UserContext* u):
   d_external(external),
   d_notify(*this),
-  d_equalityEngine(d_notify, c, "theory::sets::TheorySetsPrivate"),
+  d_equalityEngine(d_notify, c, "theory::sets::TheorySetsPrivate", false),
+  d_trueNode(NodeManager::currentNM()->mkConst<bool>(true)),
+  d_falseNode(NodeManager::currentNM()->mkConst<bool>(false)),
   d_conflict(c),
   d_termInfoManager(NULL),
   d_propagationQueue(c),
@@ -1114,7 +1138,7 @@ void TheorySetsPrivate::propagate(Theory::Effort effort) {
   }
 
   const CDNodeSet& terms = (d_termInfoManager->d_terms);
-  for(typeof(terms.begin()) it = terms.begin(); it != terms.end(); ++it) {
+  for(typeof(terms.key_begin()) it = terms.key_begin(); it != terms.key_end(); ++it) {
     Node node = (*it);
     Kind k = node.getKind();
     if(k == kind::UNION && node[0].getKind() == kind::SINGLETON ) {
@@ -1219,12 +1243,10 @@ void TheorySetsPrivate::preRegisterTerm(TNode node)
   default:
     d_termInfoManager->addTerm(node);
     d_equalityEngine.addTriggerTerm(node, THEORY_SETS);
-    // d_equalityEngine.addTerm(node);
   }
+
   if(node.getKind() == kind::SINGLETON) {
-    Node true_node = NodeManager::currentNM()->mkConst<bool>(true);
-    learnLiteral(MEMBER(node[0], node), true, true_node);
-    //intentional fallthrough
+    learnLiteral(MEMBER(node[0], node), true, d_trueNode);
   }
 }
 
@@ -1361,25 +1383,40 @@ const CDTNodeList* TheorySetsPrivate::TermInfoManager::getNonMembers(TNode S) {
 }
 
 void TheorySetsPrivate::TermInfoManager::addTerm(TNode n) {
-  unsigned numChild = n.getNumChildren();
+  if(d_terms.contains(n)) {
+    return;
+  }
+  d_terms.insert(n);
 
-  if(!d_terms.contains(n)) {
-    d_terms.insert(n);
-    d_info[n] = new TheorySetsTermInfo(d_context);
+  if(d_info.find(n) == d_info.end()) {
+    d_info.insert(make_pair(n, new TheorySetsTermInfo(d_context)));
   }
 
   if(n.getKind() == kind::UNION ||
      n.getKind() == kind::INTERSECTION ||
      n.getKind() == kind::SETMINUS) {
 
+    unsigned numChild = n.getNumChildren();
+
     for(unsigned i = 0; i < numChild; ++i) {
+      Assert(d_terms.contains(n[i]));
       if(d_terms.contains(n[i])) {
         Debug("sets-parent") << "Adding " << n << " to parent list of "
                              << n[i] << std::endl;
         d_info[n[i]]->parents->push_back(n);
+
+        typeof(d_info.begin()) ita = d_info.find(d_eqEngine->getRepresentative(n[i]));
+        Assert(ita != d_info.end());
+        CDTNodeList* l = (*ita).second->elementsNotInThisSet;
+        for(typeof(l->begin()) it = l->begin(); it != l->end(); ++it) {
+          d_theory.d_settermPropagationQueue.push_back( std::make_pair( (*it), n ) );
+        }
+        l = (*ita).second->elementsInThisSet;
+        for(typeof(l->begin()) it = l->begin(); it != l->end(); ++it) {
+          d_theory.d_settermPropagationQueue.push_back( std::make_pair( (*it), n ) );
+        }
       }
     }
-
   }
 }