change to use tuple element representatives to build TC graph for full
authorPaulMeng <baolmeng@gmail.com>
Thu, 5 May 2016 19:58:10 +0000 (14:58 -0500)
committerPaulMeng <baolmeng@gmail.com>
Thu, 5 May 2016 19:58:10 +0000 (14:58 -0500)
effort

src/theory/sets/theory_sets_rels.cpp
src/theory/sets/theory_sets_rels.h

index 428027acc7447eabd9e55e3b5d8c90a55d02ced5..ccb917d5f6453a2eee43539c5591100fc4108310 100644 (file)
@@ -51,6 +51,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       Assert(d_pending_facts.empty());
     } else {
       doPendingMerge();
+      doPendingLemmas();
     }
     Trace("rels") << "\n[sets-rels] ******************************* Done with the relational solver *******************************\n" << std::endl;
   }
@@ -208,13 +209,15 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     if(mem_it != d_membership_db.end()) {
       for(std::vector<Node>::iterator pair_it = mem_it->second.begin();
           pair_it != mem_it->second.end(); pair_it++) {
-        TC_PAIR_IT pair_set_it = tc_graph.find(RelsUtils::nthElementOfTuple(*pair_it, 0));
+        Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(*pair_it, 0));
+        Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(*pair_it, 1));
+        TC_PAIR_IT pair_set_it = tc_graph.find(fst_rep);
         if( pair_set_it != tc_graph.end() ) {
-          pair_set_it->second.insert(RelsUtils::nthElementOfTuple(*pair_it, 1));
+          pair_set_it->second.insert(snd_rep);
         } else {
-          std::hash_set< Node, NodeHashFunction > snd_pair_set;
-          snd_pair_set.insert(RelsUtils::nthElementOfTuple(*pair_it, 1));
-          tc_graph[RelsUtils::nthElementOfTuple(*pair_it, 0)] = snd_pair_set;
+          std::hash_set< Node, NodeHashFunction > snd_set;
+          snd_set.insert(snd_rep);
+          tc_graph[fst_rep] = snd_set;
         }
       }
     }
@@ -243,14 +246,16 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     // insert tup_rep in the tc_graph if it is not in the graph already
     TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep);
     if(polarity) {
+      Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 0));
+      Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 1));
       if(tc_graph_it != d_membership_tc_cache.end()) {
-        TC_PAIR_IT pair_set_it = tc_graph_it->second.find(RelsUtils::nthElementOfTuple(tup_rep, 0));
+        TC_PAIR_IT pair_set_it = tc_graph_it->second.find(fst_rep);
         if(pair_set_it != tc_graph_it->second.end()) {
-          pair_set_it->second.insert(RelsUtils::nthElementOfTuple(tup_rep, 1));
+          pair_set_it->second.insert(snd_rep);
         } else {
           std::hash_set< Node, NodeHashFunction > pair_set;
-          pair_set.insert(RelsUtils::nthElementOfTuple(tup_rep, 1));
-          tc_graph_it->second[RelsUtils::nthElementOfTuple(tup_rep, 0)] = pair_set;
+          pair_set.insert(snd_rep);
+          tc_graph_it->second[fst_rep] = pair_set;
         }
         Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
         std::map< Node, Node >::iterator exp_it = d_membership_tc_exp_cache.find(tc_rep);
@@ -259,9 +264,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
         }
       } else {
         std::map< Node, std::hash_set< Node, NodeHashFunction > > pair_set;
-        std::hash_set< Node, NodeHashFunction > set;
-        set.insert(RelsUtils::nthElementOfTuple(tup_rep, 1));
-        pair_set[RelsUtils::nthElementOfTuple(tup_rep, 0)] = set;
+        std::hash_set< Node, NodeHashFunction > snd_set;
+        snd_set.insert(snd_rep);
+        pair_set[fst_rep] = snd_set;
         d_membership_tc_cache[tc_rep] = pair_set;
         Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
         if(!reason.isNull()) {
@@ -300,7 +305,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     // check if tup_rep already exists in TC graph for conflict
     } else {
       if(tc_graph_it != d_membership_tc_cache.end()) {
-        Trace("rels-debug") << "********** tc reach here 0" << std::endl;
         checkTCGraphForConflict(atom, tc_rep, d_trueNode, RelsUtils::nthElementOfTuple(tup_rep, 0),
                                 RelsUtils::nthElementOfTuple(tup_rep, 1), tc_graph_it->second);
       }
@@ -310,11 +314,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   void TheorySetsRels::checkTCGraphForConflict (Node atom, Node tc_rep, Node exp, Node a, Node b,
                                                 std::map< Node, std::hash_set< Node, NodeHashFunction > >& pair_set) {
     TC_PAIR_IT pair_set_it = pair_set.find(a);
-    Trace("rels-debug") << "********** tc reach here 1" << " a = " << a << " b = " << b << std::endl;
     if(pair_set_it != pair_set.end()) {
-      Trace("rels-debug") << "********** tc reach here 2" << std::endl;
       if(pair_set_it->second.find(b) != pair_set_it->second.end()) {
-        Trace("rels-debug") << "********** tc reach here 3" << std::endl;
         Node reason = AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, b)));
         if(atom[1] != tc_rep) {
           reason = AND(exp, explain(EQUAL(atom[1], tc_rep)));
@@ -326,20 +327,17 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
 //                            << AND(reason.negate(), atom) << std::endl;
 //        d_sets_theory.d_out->conflict(AND(reason.negate(), atom));
       } else {
-        Trace("rels-debug") << "********** tc reach here 4" << std::endl;
         std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
         while(set_it != pair_set_it->second.end()) {
           // need to check if *set_it has been looked already
           if(!areEqual(*set_it, a)) {
             checkTCGraphForConflict(atom, tc_rep, AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, *set_it))),
                                     *set_it, b, pair_set);  
-          }         
-          Trace("rels-debug") << "********** looping here 6 *set_it = " << *set_it << std::endl;                                  
+          }
           set_it++;
         }
       }
     }
-    Trace("rels-debug") << "********** tc reach here 5" << std::endl;
   }
 
 
@@ -576,43 +574,38 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
 
 
   void TheorySetsRels::inferTC( Node exp, Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph,
-                                Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& elements, bool first_round ) {
+                                Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& traversed ) {
     Node pair = constructPair(tc_rep, start_node, cur_node);
-    if(safeAddToMap(d_membership_db, tc_rep, pair)) {
-      addToMap(d_membership_exp_cache, tc_rep, Rewriter::rewrite(exp));
-      sendLemma( MEMBER(pair, tc_rep), Rewriter::rewrite(exp), "Transitivity" );
+    std::map<Node, std::vector<Node> >::iterator mem_it = d_membership_db.find(tc_rep);
+    if(mem_it != d_membership_db.end()) {
+      if(std::find(mem_it->second.begin(), mem_it->second.end(), pair) == mem_it->second.end()) {
+        sendLemma( MEMBER(pair, tc_rep), Rewriter::rewrite(exp), "Transitivity" );
+      }
     }
-
     // check if cur_node has been traversed or not
-    if(!first_round) {
-      std::hash_set< Node, NodeHashFunction >::iterator ele_it = elements.begin();
-      while(ele_it != elements.end()) {
-        if(areEqual(cur_node, *ele_it)) {
-          return;
-        }
-        ele_it++;
-      }
+    if(traversed.find(cur_node) != traversed.end()) {
+      return;
     }
-    std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin();
+    traversed.insert(cur_node);
     Node reason = exp;
-    while(pair_set_it != tc_graph.end()) {
-      if(areEqual(pair_set_it->first, cur_node)) {
-        reason = AND(exp, EQUAL(pair_set_it->first, cur_node));
-        break;
-      }
-      pair_set_it++;
-    }
-    if(pair_set_it != tc_graph.end()) {
-      for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
-          set_it != pair_set_it->second.end(); set_it++) {
-        Node p = constructPair( tc_rep, cur_node, *set_it );
+    std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator cur_set = tc_graph.find(cur_node);
+    if(cur_set != tc_graph.end()) {
+      for(std::hash_set< Node, NodeHashFunction >::iterator set_it = cur_set->second.begin();
+          set_it != cur_set->second.end(); set_it++) {
+        Node new_pair = constructPair( tc_rep, cur_node, *set_it );
         Assert(!reason.isNull());
-        elements.insert(*set_it);
-        inferTC( AND( findMemExp(tc_rep, p), reason ), tc_rep, tc_graph, start_node, *set_it, elements, false );
+        inferTC( AND( findMemExp(tc_rep, new_pair), reason ), tc_rep, tc_graph, start_node, *set_it, traversed );
       }
     }
   }
 
+  void TheorySetsRels::finalizeTCInfer() {
+    Trace("rels-debug") << "[sets-rels] Finalizing transitive closure inferences!" << std::endl;
+    for(TC_IT tc_it = d_membership_tc_cache.begin(); tc_it != d_membership_tc_cache.end(); tc_it++) {
+      inferTC(tc_it->first, tc_it->second);
+    }
+  }
+
   void TheorySetsRels::inferTC(Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph) {
     Trace("rels-debug") << "[sets-rels] Build TC graph for tc_rep = " << tc_rep << std::endl;
     for(std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin();
@@ -627,19 +620,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
         }
         Assert(!exp.isNull());
         elements.insert(pair_set_it->first);
-        elements.insert(*set_it);
-        inferTC( exp, tc_rep, tc_graph, pair_set_it->first, *set_it, elements, true );
+        inferTC( exp, tc_rep, tc_graph, pair_set_it->first, *set_it, elements );
       }
     }
   }
 
-  void TheorySetsRels::finalizeTCInfer() {
-    Trace("rels-debug") << "[sets-rels] Finalizing transitive closure inferences!" << std::endl;
-    for(TC_IT tc_it = d_membership_tc_cache.begin(); tc_it != d_membership_tc_cache.end(); tc_it++) {
-      inferTC(tc_it->first, tc_it->second);
-    }
-  }
-
   // Bottom-up fashion to compute relations
   void TheorySetsRels::computeRels(Node n) {
     Trace("rels-debug") << "\n[sets-rels] computeJoinOrProductRelations for relation  " << n << std::endl;
@@ -770,11 +755,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
             reasons.push_back(explain(r1_exps[i]));            
             reasons.push_back(explain(r2_exps[j]));
             if(r1_exps[i].getKind() == kind::MEMBER && r1_exps[i][0] != r1_elements[i]) {
-              Trace("rels-debug") << "************* $ r1 ele = " << r1_elements[i] << " r1 exp ele = " << r1_exps[i][0] << std::endl;
               reasons.push_back(explain(EQUAL(r1_elements[i], r1_exps[i][0])));            
             }
             if(r2_exps[j].getKind() == kind::MEMBER && r2_exps[j][0] != r2_elements[j]) {
-              Trace("rels-debug") << "************* $ r2 ele = " << r2_elements[j] << " r2 exp ele = " << r2_exps[j][0] << std::endl;
               reasons.push_back(explain(EQUAL(r2_elements[j], r2_exps[j][0])));            
             }
             if(!isProduct) {              
@@ -935,6 +918,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   }
 
   bool TheorySetsRels::areEqual( Node a, Node b ){
+    Assert(a.getType() == b.getType());
+    Trace("rels-eq") << "[sets-rels]**** checking equality between " << a << " and " << b << std::endl;
     if(a == b) {
       return true;
     } else if( hasTerm( a ) && hasTerm( b ) ){
@@ -1001,19 +986,23 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
 
   // tuple might be a member of tc_rep; or it might be a member of rels or tc_terms such that
   // tc_terms are transitive closure of rels and are modulo equal to tc_rep
-  Node TheorySetsRels::findMemExp(Node tc_rep, Node tuple) {
-    Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_rep = " << tc_rep << ", tuple = " << tuple << ")" << std::endl;
+  Node TheorySetsRels::findMemExp(Node tc_rep, Node pair) {
+    Node fst = RelsUtils::nthElementOfTuple(pair, 0);
+    Node snd = RelsUtils::nthElementOfTuple(pair, 1);
+    Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_rep = " << tc_rep << ", pair = " << pair << ")" << std::endl;
     std::vector<Node> tc_terms = d_terms_cache.find(tc_rep)->second[kind::TCLOSURE];
     Assert(tc_terms.size() > 0);
     for(unsigned int i = 0; i < tc_terms.size(); i++) {
       Node tc_term = tc_terms[i];
       Node tc_r_rep = getRepresentative(tc_term[0]);
 
-      Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << tc_r_rep << ", tuple = " << tuple << ")" << std::endl;
+      Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << tc_r_rep << ", pair = " << pair << ")" << std::endl;
       std::map< Node, std::vector< Node > >::iterator tc_r_mems = d_membership_db.find(tc_r_rep);
       if(tc_r_mems != d_membership_db.end()) {
         for(unsigned int i = 0; i < tc_r_mems->second.size(); i++) {
-          if(areEqual(tc_r_mems->second[i], tuple)) {
+          Node fst_mem = RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 0);
+          Node snd_mem = RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 1);
+          if(areEqual(fst_mem, fst) && areEqual(snd_mem, snd)) {
             Node exp = MEMBER(tc_r_mems->second[i], tc_r_mems->first);
             if(tc_r_rep != tc_term[0]) {
               exp = explain(EQUAL(tc_r_rep, tc_term[0]));
@@ -1021,14 +1010,14 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
             if(tc_rep != tc_term) {
               exp = AND(exp, explain(EQUAL(tc_rep, tc_term)));
             }
-            if(tc_r_mems->second[i] != tuple) {
-              if(RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 0) != RelsUtils::nthElementOfTuple(tuple, 0)) {
-                exp = AND(exp, explain(EQUAL(RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 0), RelsUtils::nthElementOfTuple(tuple, 0))));
+            if(tc_r_mems->second[i] != pair) {
+              if(fst_mem != fst) {
+                exp = AND(exp, explain(EQUAL(fst_mem, fst)));
               }
-              if(RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 1) != RelsUtils::nthElementOfTuple(tuple, 1)) {
-                exp = AND(exp, explain(EQUAL(RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 1), RelsUtils::nthElementOfTuple(tuple, 1))));
+              if(snd_mem != snd) {
+                exp = AND(exp, explain(EQUAL(snd_mem, snd)));
               }
-              exp = AND(exp, EQUAL(tc_r_mems->second[i], tuple));
+              exp = AND(exp, EQUAL(tc_r_mems->second[i], pair));
             }
             return Rewriter::rewrite(AND(exp, explain(d_membership_exp_db[tc_r_rep][i])));
           }
@@ -1037,10 +1026,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
 
       Node tc_term_rep = getRepresentative(tc_terms[i]);
       std::map< Node, std::vector< Node > >::iterator tc_t_mems = d_membership_db.find(tc_term_rep);
-      Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_t_rep = " << tc_term_rep << ", tuple = " << tuple << ")" << std::endl;
       if(tc_t_mems != d_membership_db.end()) {
         for(unsigned int j = 0; j < tc_t_mems->second.size(); j++) {
-          if(areEqual(tc_t_mems->second[j], tuple)) {
+          Node fst_mem = RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 0);
+          Node snd_mem = RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 1);
+          if(areEqual(fst_mem, fst) && areEqual(snd_mem, snd)) {
             Node exp = MEMBER(tc_t_mems->second[j], tc_t_mems->first);
             if(tc_rep != tc_terms[i]) {
               exp = AND(exp, explain(EQUAL(tc_rep, tc_terms[i])));
@@ -1048,14 +1038,14 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
             if(tc_term_rep != tc_terms[i]) {
               exp = AND(exp, explain(EQUAL(tc_term_rep, tc_terms[i])));
             }
-            if(tc_t_mems->second[j] != tuple) {
-              if(RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 0) != RelsUtils::nthElementOfTuple(tuple, 0)) {
-                exp = AND(exp, explain(EQUAL(RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 0), RelsUtils::nthElementOfTuple(tuple, 0))));
+            if(tc_t_mems->second[j] != pair) {
+              if(fst_mem != fst) {
+                exp = AND(exp, explain(EQUAL(fst_mem, fst)));
               }
-              if(RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 1) != RelsUtils::nthElementOfTuple(tuple, 1)) {
-                exp = AND(exp, explain(EQUAL(RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 1), RelsUtils::nthElementOfTuple(tuple, 1))));
+              if(snd_mem != snd) {
+                exp = AND(exp, explain(EQUAL(snd_mem, snd)));
               }
-              exp = AND(exp, EQUAL(tc_t_mems->second[j], tuple));
+              exp = AND(exp, EQUAL(tc_t_mems->second[j], pair));
             }
             return Rewriter::rewrite(AND(exp, explain(d_membership_exp_db[tc_term_rep][j])));
           }
@@ -1087,7 +1077,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   }
 
   void TheorySetsRels::makeSharedTerm( Node n ) {
-    if(d_shared_terms.find(n) == d_shared_terms.end() && !n.getType().isBoolean()) {
+    Trace("rels-share") << " [sets-rels] making shared term " << n << std::endl;
+    if(d_shared_terms.find(n) == d_shared_terms.end()) {
       Node skolem = NodeManager::currentNM()->mkSkolem( "sde", n.getType() );
       sendLemma(MEMBER(skolem, SINGLETON(n)), d_trueNode, "share-term");
       d_shared_terms.insert(n);
@@ -1262,7 +1253,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   }
 
   TheorySetsRels::EqcInfo::EqcInfo( context::Context* c ) :
-  d_mem(c), d_not_mem(c), d_in(c), d_out(c), d_tc_mem_exp(c), d_tp(c), d_pt(c), d_join(c), d_tc(c) {}
+  counter(0), d_mem(c), d_not_mem(c), d_in(c), d_out(c), d_tc_mem_exp(c),
+  d_tp(c), d_pt(c), d_join(c), d_tc(c), d_id_in(c), d_id_out(c) {}
 
   void TheorySetsRels::eqNotifyNewClass( Node n ) {
     Trace("rels-std") << "[sets-rels] eqNotifyNewClass:" << " t = " << n << std::endl;
@@ -1272,7 +1264,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
                     n.getKind() == kind::TCLOSURE)) {
       getOrMakeEqcInfo( n, true );
     }
-    Trace("rels-std") << "[sets-rels] eqNotifyNewClass*****:" << " t = " << n << std::endl;
   }
   void TheorySetsRels::addTCMem(EqcInfo* tc_ei, Node mem) {
     Node fst = RelsUtils::nthElementOfTuple(mem, 0);
@@ -1321,8 +1312,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       if(ei == NULL) {
         ei = getOrMakeEqcInfo( t2_1rep, true );
       }
-      // might not need to store the membership info
-      // if we don't need to consider the eqc merge?
       if(polarity) {
         ei->d_mem.insert(t2[0]);
       } else {
@@ -1345,8 +1334,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       if(polarity) {
         if(!ei->d_tc.get().isNull()) {
           addTCMem(ei, t2[0]);
-          ei->d_tc_mem_exp.insert(t2[0], t2);
-          sendInferTC(ei, t2[0], t2);
+          ei->d_tc_mem_exp.insert(t2[0], explain(t2));
+          sendInferTC(ei, t2[0], explain(t2));
         } else {
           std::vector<TypeNode> tup_types = t2[1].getType().getSetElementType().getTupleTypes();
           if( tup_types.size() == 2 && tup_types[0] == tup_types[1] ) {
@@ -1354,7 +1343,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
             EqcInfo* tc_ei = getOrMakeEqcInfo( tc_n );
             if(tc_ei != NULL) {
               addTCMem(tc_ei, t2[0]);
-              Node exp = (tc_n == tc_ei->d_tc.get()) ? t2 : AND(EQUAL(tc_n, tc_ei->d_tc.get()), t2);
+              Node exp = (tc_n == tc_ei->d_tc.get()) ? explain(t2) : AND(EQUAL(tc_n, tc_ei->d_tc.get()), explain(t2));
               tc_ei->d_tc_mem_exp.insert(t2[0], exp);
               sendInferTC(tc_ei, t2[0], exp);
             }
@@ -1386,26 +1375,27 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     seen.insert(RelsUtils::nthElementOfTuple(mem, 0));
     sendInferInTC(tc_ei, RelsUtils::nthElementOfTuple(mem, 0), RelsUtils::nthElementOfTuple(mem, 1), seen, exp);
     sendInferOutTC(tc_ei, RelsUtils::nthElementOfTuple(mem, 0), RelsUtils::nthElementOfTuple(mem, 1), seen, exp);
+    Trace("rels-std") << "[sets-rels] *** done with sendInferTC member = " << mem << " with explanation = " << exp << std::endl;
   }
 
   void TheorySetsRels::sendInferInTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set<Node, NodeHashFunction> seen, Node exp) {
     for(NodeListMap::iterator nl_it = tc_ei->d_in.begin(); nl_it != tc_ei->d_in.end(); nl_it++) {
-      if(areEqual((*nl_it).first, fst)) {
-        for(NodeList::const_iterator itr = (*nl_it).second->begin(); itr != (*nl_it).second->end(); itr++) {
-          Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, snd);
+      if((*nl_it).first == fst) {
+        for(NodeList::const_iterator in_itr = (*nl_it).second->begin(); in_itr != (*nl_it).second->end(); in_itr++) {
+          Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, snd);
           if(!tc_ei->d_mem.contains(pair)) {
             Node reason = ((*nl_it).first == fst) ?
-                          Rewriter::rewrite(AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst)))):
-                          Rewriter::rewrite(AND(EQUAL((*nl_it).first, fst), AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst)))));
+                          Rewriter::rewrite(AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, fst)))):
+                          Rewriter::rewrite(AND(EQUAL((*nl_it).first, fst), AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, fst)))));
             Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(pair, tc_ei->d_tc.get()));
             d_pending_merge.push_back(tc_lemma);
             d_lemma.insert(tc_lemma);
             tc_ei->d_mem.insert(pair);
             tc_ei->d_tc_mem_exp.insert(pair, reason);
           }
-          if(seen.find(*itr) == seen.end()) {
-            seen.insert(*itr);
-            sendInferInTC(tc_ei, *itr, snd, seen, AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst))));
+          if(seen.find(*in_itr) == seen.end()) {
+            seen.insert(*in_itr);
+            sendInferInTC(tc_ei, *in_itr, snd, seen, AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, fst))));
           }
         }
       }
@@ -1414,7 +1404,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
 
   void TheorySetsRels::sendInferOutTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set<Node, NodeHashFunction> seen, Node exp) {
     for(NodeListMap::iterator nl_it = tc_ei->d_out.begin(); nl_it != tc_ei->d_out.end(); nl_it++) {
-      if(areEqual((*nl_it).first, snd)) {
+      if((*nl_it).first == snd) {
         for(NodeList::const_iterator itr = (*nl_it).second->begin(); itr != (*nl_it).second->end(); itr++) {
           Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), fst, *itr);
           if(!tc_ei->d_mem.contains(pair)) {
index 0d24c65b385f83122dc3f325579f05b86b2c86f5..faee651b7b1a600e733ad1c3e0771285921e5e7a 100644 (file)
@@ -45,9 +45,11 @@ public:
 class TheorySetsRels {
 
   typedef context::CDChunkList<Node> NodeList;
+  typedef context::CDChunkList<int> IdList;
   typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
   typedef context::CDHashMap<Node, bool, NodeHashFunction> NodeBoolMap;
   typedef context::CDHashMap<Node, NodeList*, NodeHashFunction> NodeListMap;
+  typedef context::CDHashMap<int, IdList*> IdListMap;
   typedef context::CDHashMap<Node, NodeSet*, NodeHashFunction> NodeSetMap;
   typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeMap;
 
@@ -76,6 +78,7 @@ private:
   public:
     EqcInfo( context::Context* c );
     ~EqcInfo(){}
+    int counter;
     NodeSet d_mem;
     NodeSet d_not_mem;
     NodeListMap d_in;
@@ -85,6 +88,10 @@ private:
     context::CDO< Node > d_pt;
     context::CDO< Node > d_join;
     context::CDO< Node > d_tc;
+    IdListMap d_id_in;
+    IdListMap d_id_out;
+    std::hash_map<int, Node> d_id_node;
+    std::hash_map<Node, int> d_node_id;
   };
 
   /** has eqc info */
@@ -168,7 +175,7 @@ private:
   void finalizeTCInfer();
   void inferTC( Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >& );
   void inferTC( Node, Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >&,
-                Node, Node, std::hash_set< Node, NodeHashFunction >&, bool first_round = false);
+                Node, Node, std::hash_set< Node, NodeHashFunction >&);
 
   Node explain(Node);
 
@@ -184,7 +191,7 @@ private:
   // Helper functions
   inline Node getReason(Node tc_rep, Node tc_term, Node tc_r_rep, Node tc_r);
   inline Node constructPair(Node tc_rep, Node a, Node b);
-  Node findMemExp(Node r, Node tuple);
+  Node findMemExp(Node r, Node pair);
   bool safeAddToMap( std::map< Node, std::vector<Node> >&, Node, Node );
   void addToMap( std::map< Node, std::vector<Node> >&, Node, Node );
   bool hasMember( Node, Node );