- added standard effort for transpose
authorPaulMeng <baolmeng@gmail.com>
Thu, 7 Apr 2016 18:41:21 +0000 (13:41 -0500)
committerPaulMeng <baolmeng@gmail.com>
Thu, 7 Apr 2016 18:41:21 +0000 (13:41 -0500)
- implement transitive closure rule for concrete input

src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_private.h
src/theory/sets/theory_sets_rels.cpp
src/theory/sets/theory_sets_rels.h
src/theory/sets/theory_sets_rewriter.cpp

index 4cb82b66d79b6c39b06a7bbd2de3a5a659a8d4f7..10fc9f1953150674684775cc5969208b15467042 100644 (file)
@@ -95,10 +95,10 @@ void TheorySetsPrivate::check(Theory::Effort level) {
     Debug("sets") << "[sets]  is complete = " << isComplete() << std::endl;
 
   }
+
+  // invoke the relational solver
   d_rels->check(level);
-//  if( level == Theory::EFFORT_FULL ) {
-//    d_rels->doPendingLemmas();
-//  }
+
   if( (level == Theory::EFFORT_FULL || options::setsEagerLemmas() ) && !isComplete()) {
     d_external.d_out->lemma(getLemma());
     return;
@@ -119,29 +119,31 @@ void TheorySetsPrivate::assertEquality(TNode fact, TNode reason, bool learnt)
 
   bool polarity = fact.getKind() != kind::NOT;
   TNode atom = polarity ? fact : fact[0];
-
+  Debug("sets-assert") << "\n finish assert equality!!!!!!!*************** 0" <<std::endl;
   // fact already holds
   if( holds(atom, polarity) ) {
     Debug("sets-assert") << "[sets-assert]   already present, skipping" << std::endl;
     return;
   }
-
+  Debug("sets-assert") << "\n finish assert equality!!!!!!!*************** 1" <<std::endl;
   // assert fact & check for conflict
   if(learnt) {
+    Debug("sets-assert") << "\n finish assert equality!!!!!!!*************** 5" <<std::endl;
     registerReason(reason, /*save=*/ true);
   }
+  Debug("sets-assert") << "\n finish assert equality!!!!!!!*************** 4" <<std::endl;
   d_equalityEngine.assertEquality(atom, polarity, reason);
-
+  Debug("sets-assert") << "\n finish assert equality!!!!!!!*************** 2" <<std::endl;
   if(!d_equalityEngine.consistent()) {
     Debug("sets-assert") << "[sets-assert]   running into a conflict" << std::endl;
     d_conflict = true;
     return;
   }
-
+  Debug("sets-assert") << "\n finish assert equality!!!!!!!*************** 3" <<std::endl;
   if(!polarity && atom[0].getType().isSet()) {
     addToPending(atom);
   }
-
+  Debug("sets-assert") << "\n finish assert equality!!!!!!!*************** " <<std::endl;
 }/* TheorySetsPrivate::assertEquality() */
 
 
@@ -1308,21 +1310,23 @@ void TheorySetsPrivate::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t
   Debug("sets-eq") << "[sets-eq] eqNotifyConstantTermMerge " << " t1 = " << t1 << " t2 = " << t2 << std::endl;
   d_theory.conflict(t1, t2);
 }
-
-// void TheorySetsPrivate::NotifyClass::eqNotifyNewClass(TNode t)
-// {
-//   Debug("sets-eq") << "[sets-eq] eqNotifyNewClass:" << " t = " << t << std::endl;
-// }
+//
+ void TheorySetsPrivate::NotifyClass::eqNotifyNewClass(TNode t)
+ {
+   Debug("sets-eq") << "[sets-eq] eqNotifyNewClass:" << " t = " << t << std::endl;
+   d_theory.d_rels->eqNotifyNewClass(t);
+ }
 
 // void TheorySetsPrivate::NotifyClass::eqNotifyPreMerge(TNode t1, TNode t2)
 // {
 //   Debug("sets-eq") << "[sets-eq] eqNotifyPreMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl;
 // }
-
-// void TheorySetsPrivate::NotifyClass::eqNotifyPostMerge(TNode t1, TNode t2)
-// {
-//   Debug("sets-eq") << "[sets-eq] eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl;
-// }
+//
+ void TheorySetsPrivate::NotifyClass::eqNotifyPostMerge(TNode t1, TNode t2)
+ {
+   Debug("sets-eq") << "[sets-eq] eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl;
+   d_theory.d_rels->eqNotifyPostMerge(t1, t2);
+ }
 
 // void TheorySetsPrivate::NotifyClass::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
 // {
index ad04ff2736a396995652fceb70aa71dab4eb856d..ce81fac27c20c0071208db264561e4403140b2eb 100644 (file)
@@ -87,14 +87,14 @@ private:
     TheorySetsPrivate& d_theory;
 
   public:
-    NotifyClass(TheorySetsPrivate& theory): d_theory(theory) {}
+    NotifyClass(TheorySetsPrivate& theory): d_theory(theory){}
     bool eqNotifyTriggerEquality(TNode equality, bool value);
     bool eqNotifyTriggerPredicate(TNode predicate, bool value);
     bool eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value);
     void eqNotifyConstantTermMerge(TNode t1, TNode t2);
-    void eqNotifyNewClass(TNode t) {}
+    void eqNotifyNewClass(TNode t);
     void eqNotifyPreMerge(TNode t1, TNode t2) {}
-    void eqNotifyPostMerge(TNode t1, TNode t2) {}
+    void eqNotifyPostMerge(TNode t1, TNode t2);
     void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) {}
   } d_notify;
 
@@ -199,8 +199,11 @@ private:
   // more debugging stuff
   friend class TheorySetsScrutinize;
   TheorySetsScrutinize* d_scrutinize;
-  TheorySetsRels* d_rels;
   void dumpAssertionsHumanified() const;  /** do some formatting to make them more readable */
+
+  // relational solver
+  TheorySetsRels* d_rels;
+
 };/* class TheorySetsPrivate */
 
 
index 7baf5976a56653029858efa90b50a1e50057d5eb..d9b975eb997b68474b20f5ad81558acb384a8b77 100644 (file)
@@ -36,43 +36,49 @@ namespace CVC4 {
 namespace theory {
 namespace sets {
 
-typedef std::map<Node, std::map<kind::Kind_t, std::vector<Node> > >::iterator term_it;
-typedef std::map<Node, std::vector<Node> >::iterator mem_it;
+typedef std::map<Node, std::map<kind::Kind_t, std::vector<Node> > >::iterator TERM_IT;
+typedef std::map<Node, std::map<Node, std::hash_set<Node, NodeHashFunction> > >::iterator TC_IT;
+typedef std::map<Node, std::vector<Node> >::iterator MEM_IT;
+typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_PAIR_IT;
 
   void TheorySetsRels::check(Theory::Effort level) {
     Trace("rels") << "\n[sets-rels] ******************************* Start the relational solver *******************************\n" << std::endl;
-    collectRelsInfo();
-    check();
-    doPendingLemmas();
-    Assert(d_lemma_cache.empty());
-    Assert(d_pending_facts.empty());
+    if(Theory::fullEffort(level)) {
+      collectRelsInfo();
+      check();
+      doPendingLemmas();
+      Assert(d_lemma_cache.empty());
+      Assert(d_pending_facts.empty());
+    } else {
+      doPendingMerge();
+    }
     Trace("rels") << "\n[sets-rels] ******************************* Done with the relational solver *******************************\n" << std::endl;
   }
 
   void TheorySetsRels::check() {
-    mem_it m_it = d_membership_cache.begin();
-    while(m_it != d_membership_cache.end()) {
+    MEM_IT m_it = d_membership_constraints_cache.begin();
+    while(m_it != d_membership_constraints_cache.end()) {
       Node rel_rep = m_it->first;
+      Trace("rels-debug") << "[sets-rels] Processing rel_rep = " << rel_rep << std::endl;
 
       // No relational terms found with rel_rep as its representative
+      // But TRANSPOSE(rel_rep) may occur in the context
       if(d_terms_cache.find(rel_rep) == d_terms_cache.end()) {
-        // TRANSPOSE(rel_rep) may occur in the context
         Node tp_rel = NodeManager::currentNM()->mkNode(kind::TRANSPOSE, rel_rep);
         Node tp_rel_rep = getRepresentative(tp_rel);
         if(d_terms_cache.find(tp_rel_rep) != d_terms_cache.end()) {
           for(unsigned int i = 0; i < m_it->second.size(); i++) {
-            Node exp = tp_rel == tp_rel_rep ? d_membership_exp_cache[rel_rep][i]
-                                            : AND(d_membership_exp_cache[rel_rep][i], EQUAL(tp_rel, tp_rel_rep));
+//            Node exp = tp_rel == tp_rel_rep ? d_membership_exp_cache[rel_rep][i]
+//                                            : AND(d_membership_exp_cache[rel_rep][i], EQUAL(tp_rel, tp_rel_rep));
             // Lazily apply transpose-occur rule.
             // Need to eagerly apply if we don't send facts as lemmas
-            applyTransposeRule(exp, tp_rel_rep, true);
+            applyTransposeRule(d_membership_exp_cache[rel_rep][i], tp_rel_rep, true);
           }
         }
       } else {
         for(unsigned int i = 0; i < m_it->second.size(); i++) {
           Node exp = d_membership_exp_cache[rel_rep][i];
           std::map<kind::Kind_t, std::vector<Node> > kind_terms = d_terms_cache[rel_rep];
-
           if(kind_terms.find(kind::TRANSPOSE) != kind_terms.end()) {
             std::vector<Node> tp_terms = kind_terms[kind::TRANSPOSE];
             // exp is a membership term and tp_terms contains all
@@ -95,10 +101,17 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
               applyProductRule(exp, product_terms[j]);
             }
           }
+          if(kind_terms.find(kind::TRANSCLOSURE) != kind_terms.end()) {
+            std::vector<Node> tc_terms = kind_terms[kind::TRANSCLOSURE];
+            for(unsigned int j = 0; j < tc_terms.size(); j++) {
+              applyTCRule(exp, tc_terms[j]);
+            }
+          }
         }
       }
       m_it++;
     }
+    finalizeTCInfer();
   }
 
   /*
@@ -125,10 +138,11 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
             if(n[0].isVar()){
               reduceTupleVar(n);
             } else {
-              if(safeAddToMap(d_membership_cache, rel_rep, tup_rep)) {
+              if(safeAddToMap(d_membership_constraints_cache, rel_rep, tup_rep)) {
                 bool true_eq = areEqual(r, d_trueNode);
                 Node reason = true_eq ? n : n.negate();
                 addToMap(d_membership_exp_cache, rel_rep, reason);
+                Trace("rels-mem") << "[******] exp: " << reason << " for " << rel_rep << std::endl;
                 if(true_eq) {
                   addToMembershipDB(rel_rep, tup_rep, reason);
                 }
@@ -161,7 +175,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
         // need to add all tuple elements as shared terms
         } else if(n.getType().isTuple() && !n.isConst() && !n.isVar()) {
           for(unsigned int i = 0; i < n.getType().getTupleLength(); i++) {
-            Node element = selectElement(n, i);
+            Node element = nthElementOfTuple(n, i);
             if(!element.isConst()) {
               makeSharedTerm(element);
             }
@@ -174,8 +188,124 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     Trace("rels-debug") << "[sets-rels] Done with collecting relational terms!" << std::endl;
   }
 
- /*  product-split rule:    (a, b) IS_IN (X PRODUCT Y)
-  *                  ----------------------------------
+  /*
+   *
+   *
+   * transitive closure rule 1:   y = (TRANSCLOSURE x)
+   *                           ---------------------------------------------
+   *                              y = x | x.x | x.x.x | ... (| is union)
+   *
+   *
+   *
+   * transitive closure rule 2:   TRANSCLOSURE(x)
+   *                            -----------------------------------------------------------
+   *                              x <= TRANSCLOSURE(x) && (x JOIN x) <= TRANSCLOSURE(x) ....
+   *
+   *                              TC(x) = TC(y) => x = y
+   *
+   */
+
+  void TheorySetsRels::buildTCGraph(Node tc_r_rep, Node tc_rep, Node tc_term) {
+    std::map< Node, std::hash_set< Node, NodeHashFunction > > tc_graph;
+    MEM_IT mem_it = d_membership_db.find(tc_r_rep);
+    if(mem_it != d_membership_db.end()) {
+      for(std::vector<Node>::iterator pair_it = mem_it->second.begin();
+          pair_it != mem_it->second.end(); pair_it++) {
+        TC_PAIR_IT pair_set_it = tc_graph.find(nthElementOfTuple(*pair_it, 0));
+        if( pair_set_it != tc_graph.end() ) {
+          pair_set_it->second.insert(nthElementOfTuple(*pair_it, 1));
+        } else {
+          std::hash_set< Node, NodeHashFunction > snd_pair_set;
+          snd_pair_set.insert(nthElementOfTuple(*pair_it, 1));
+          tc_graph[nthElementOfTuple(*pair_it, 0)] = snd_pair_set;
+        }
+      }
+    }
+    Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
+    if(!reason.isNull()) {
+      d_membership_tc_exp_cache[tc_rep] = reason;
+    }
+    d_membership_tc_cache[tc_rep] = tc_graph;
+  }
+
+  void TheorySetsRels::applyTCRule(Node exp, Node tc_term) {
+    Trace("rels-debug") << "\n[sets-rels] *********** Applying TRANSITIVE CLOSURE rule on  "
+                        << tc_term << " with explanation " << exp << std::endl;
+    bool polarity = exp.getKind() != kind::NOT;
+    Node atom = polarity ? exp : exp[0];
+    Node tc_rep = getRepresentative(tc_term);
+    Node tc_r_rep = getRepresentative(tc_term[0]);
+
+    // build the TC graph for tc_rep if it was not created before
+    if( d_membership_tc_cache.find(tc_rep) == d_membership_tc_cache.end() ) {
+      buildTCGraph(tc_r_rep, tc_rep, tc_term);
+    }
+    // insert atom[0] in the tc_graph
+    TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep);
+    if(polarity) {
+      if(tc_graph_it != d_membership_tc_cache.end()) {
+        TC_PAIR_IT pair_set_it = tc_graph_it->second.find(nthElementOfTuple(atom[0], 0));
+        if(pair_set_it != tc_graph_it->second.end()) {
+          pair_set_it->second.insert(nthElementOfTuple(atom[0], 1));
+        } else {
+          std::hash_set< Node, NodeHashFunction > pair_set;
+          pair_set.insert(nthElementOfTuple(atom[0], 1));
+          tc_graph_it->second[nthElementOfTuple(atom[0], 0)] = pair_set;
+        }
+        Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
+        std::map< Node, Node >::iterator exp_it = d_membership_tc_exp_cache.find(tc_rep);
+        if(!reason.isNull() && exp_it->second != reason) {
+          d_membership_tc_exp_cache[tc_rep] = Rewriter::rewrite(AND(exp_it->second, reason));
+        }
+      } else {
+        std::map< Node, std::hash_set< Node, NodeHashFunction > > pair_set;
+        std::hash_set< Node, NodeHashFunction > set;
+        set.insert(nthElementOfTuple(atom[0], 1));
+        pair_set[nthElementOfTuple(atom[0], 0)] = set;
+        d_membership_tc_cache[tc_rep] = pair_set;
+        Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
+        if(!reason.isNull()) {
+          d_membership_tc_exp_cache[tc_rep] = reason;
+        }
+      }
+    // check if atom[0] exists in TC graph for conflict
+    } else {
+      if(tc_graph_it != d_membership_tc_cache.end()) {
+        checkTCGraphForConflict(atom, tc_rep, d_trueNode, nthElementOfTuple(atom[0], 0),
+                                nthElementOfTuple(atom[0], 1), tc_graph_it->second);
+      }
+    }
+  }
+
+  void TheorySetsRels::checkTCGraphForConflict (Node atom, Node tc_rep, Node exp, Node a, Node b,
+                                                std::map< Node, std::hash_set< Node, NodeHashFunction > >& pair_set) {
+    TC_PAIR_IT pair_set_it = pair_set.find(a);
+    if(pair_set_it != pair_set.end()) {
+      if(pair_set_it->second.find(b) != pair_set_it->second.end()) {
+        Node reason = AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, b)));
+        if(atom[1] != tc_rep) {
+          reason = AND(exp, EQUAL(atom[1], tc_rep));
+        }
+        Trace("rels-debug") << "[sets-rels] found a conflict and send out lemma : "
+                            <<  NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, atom) << std::endl;
+        d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, atom));
+//        Trace("rels-debug") << "[sets-rels] found a conflict and send out lemma : "
+//                            << AND(reason.negate(), atom) << std::endl;
+//        d_sets_theory.d_out->conflict(AND(reason.negate(), atom));
+      } else {
+        std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
+        while(set_it != pair_set_it->second.end()) {
+          checkTCGraphForConflict(atom, tc_rep, AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, *set_it))),
+                                  *set_it, b, pair_set);
+          set_it++;
+        }
+      }
+    }
+  }
+
+
+ /*  product-split rule:  (a, b) IS_IN (X PRODUCT Y)
+  *                     ----------------------------------
   *                       a IS_IN X  && b IS_IN Y
   *
   *  product-compose rule: (a, b) IS_IN X    (c, d) IS_IN Y  NOT (r, s, t, u) IS_IN (X PRODUCT Y)
@@ -184,12 +314,13 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
   */
 
   void TheorySetsRels::applyProductRule(Node exp, Node product_term) {
+    Trace("rels-debug") << "\n[sets-rels] *********** Applying PRODUCT rule  " << std::endl;
     bool polarity = exp.getKind() != kind::NOT;
     Node atom = polarity ? exp : exp[0];
     Node r1_rep = getRepresentative(product_term[0]);
     Node r2_rep = getRepresentative(product_term[1]);
 
-    if(polarity & d_lemma.find(exp) != d_lemma.end()) {
+    if(polarity) {
       Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-SPLIT rule on term: " << product_term
                           << " with explanation: " << exp << std::endl;
       std::vector<Node> r1_element;
@@ -203,13 +334,13 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
 
       r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
       for(; i < s1_len; ++i) {
-        r1_element.push_back(selectElement(atom[0], i));
+        r1_element.push_back(nthElementOfTuple(atom[0], i));
       }
 
       dt = r2_rep.getType().getSetElementType().getDatatype();
       r2_element.push_back(Node::fromExpr(dt[0].getConstructor()));
       for(; i < tup_len; ++i) {
-        r2_element.push_back(selectElement(atom[0], i));
+        r2_element.push_back(nthElementOfTuple(atom[0], i));
       }
 
       Node fact;
@@ -217,7 +348,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
       Node t1 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element));
       Node t2 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element));
 
-      if(!hasTuple(r1_rep, t1)) {
+      if(!hasMember(r1_rep, t1)) {
         fact = MEMBER( t1, r1_rep );
         if(r1_rep != product_term[0])
           reason = Rewriter::rewrite(AND(reason, EQUAL(r1_rep, product_term[0])));
@@ -226,7 +357,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
         sendInfer(fact, reason, "product-split");
       }
 
-      if(!hasTuple(r2_rep, t2)) {
+      if(!hasMember(r2_rep, t2)) {
         fact = MEMBER( t2, r2_rep );
         if(r2_rep != product_term[1])
           reason = Rewriter::rewrite(AND(reason, EQUAL(r2_rep, product_term[1])));
@@ -252,12 +383,14 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
    *                                      (a, c) IS_IN (X JOIN Y)
    */
   void TheorySetsRels::applyJoinRule(Node exp, Node join_term) {
+    Trace("rels-debug") << "\n[sets-rels] *********** Applying JOIN rule  " << std::endl;
     bool polarity = exp.getKind() != kind::NOT;
     Node atom = polarity ? exp : exp[0];
     Node r1_rep = getRepresentative(join_term[0]);
     Node r2_rep = getRepresentative(join_term[1]);
 
-    if(polarity && d_lemma.find(exp) == d_lemma.end()) {
+    if(polarity) {
+
       Trace("rels-debug") <<  "\n[sets-rels] Apply JOIN-SPLIT rule on term: " << join_term
                           << " with explanation: " << exp << std::endl;
 
@@ -273,7 +406,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
 
       r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
       for(; i < s1_len-1; ++i) {
-        r1_element.push_back(selectElement(atom[0], i));
+        r1_element.push_back(nthElementOfTuple(atom[0], i));
       }
       r1_element.push_back(shared_x);
 
@@ -281,7 +414,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
       r2_element.push_back(Node::fromExpr(dt[0].getConstructor()));
       r2_element.push_back(shared_x);
       for(; i < tup_len; ++i) {
-        r2_element.push_back(selectElement(atom[0], i));
+        r2_element.push_back(nthElementOfTuple(atom[0], i));
       }
 
       Node t1 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element);
@@ -300,19 +433,21 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
       }
 
       Node fact;
-      Node reason = atom[1] == join_term ? exp : AND(exp, EQUAL(atom[1], join_term));
+      Node reason = atom[1] == join_term ? exp : AND(exp, explain(EQUAL(atom[1], join_term)));
       Node reasons = reason;
 
       fact = MEMBER(t1, r1_rep);
-      if(r1_rep != join_term[0])
-        reasons = Rewriter::rewrite(AND(reason, EQUAL(r1_rep, join_term[0])));
+      if(r1_rep != join_term[0]) {
+        reasons = Rewriter::rewrite(AND(reason, explain(EQUAL(r1_rep, join_term[0]))));
+      }
       addToMembershipDB(r1_rep, t1, reasons);
       sendInfer(fact, reasons, "join-split");
 
       reasons = reason;
       fact = MEMBER(t2, r2_rep);
-      if(r2_rep != join_term[1])
-        reasons = Rewriter::rewrite(AND(reason, EQUAL(r2_rep, join_term[1])));
+      if(r2_rep != join_term[1]) {
+        reasons = Rewriter::rewrite(AND(reason, explain(EQUAL(r2_rep, join_term[1]))));
+      }
       addToMembershipDB(r2_rep, t2, reasons);
       sendInfer(fact, reasons, "join-split");
 
@@ -329,20 +464,26 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
   /*
    * transpose-occur rule:   [NOT] (a, b) IS_IN X   (TRANSPOSE X) occurs
    *                         -------------------------------------------------------
-   *                                   [NOT] (b, a) IS_IN (TRANSPOSE X)
+   *                         [NOT] (b, a) IS_IN (TRANSPOSE X)
    *
-   * transpose rule:       [NOT] (a, b) IS_IN (TRANSPOSE X)
-   *                ------------------------------------------------
+   * transpose-reverse rule:    [NOT] (a, b) IS_IN (TRANSPOSE X)
+   *                         ------------------------------------------------
    *                            [NOT] (b, a) IS_IN X
+   *
+   *
+   * transpose-equal rule:   [NOT]  (TRANSPOSE X) = (TRANSPOSE Y)
+   *                         -----------------------------------------------
+   *                         [NOT]  (X = Y)
    */
   void TheorySetsRels::applyTransposeRule(Node exp, Node tp_term, bool tp_occur) {
-    Trace("rels-debug") << "\n[sets-rels] Apply transpose rule on term: " << tp_term
-                        << " with explanation: " << exp << std::endl;
+    Trace("rels-debug") << "\n[sets-rels] *********** Applying TRANSPOSE rule  " << std::endl;
     bool polarity = exp.getKind() != kind::NOT;
     Node atom = polarity ? exp : exp[0];
     Node reversedTuple = getRepresentative(reverseTuple(atom[0]));
 
     if(tp_occur) {
+      Trace("rels-debug") << "\n[sets-rels] Apply TRANSPOSE-OCCUR rule on term: " << tp_term
+                             << " with explanation: " << exp << std::endl;
       Node fact = polarity ? MEMBER(reversedTuple, tp_term) : MEMBER(reversedTuple, tp_term).negate();
       if(holds(fact)) {
         Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds. Skip...." << std::endl;
@@ -387,6 +528,71 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     }
   }
 
+  // Todo: need to add equality between two pair's left and right elements as explanation
+  void TheorySetsRels::inferTC( Node exp, Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph,
+                                Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& elements, bool first_round ) {
+    Node pair = constructPair(tc_rep, start_node, cur_node);
+    if(safeAddToMap(d_membership_db, tc_rep, pair)) {
+      addToMap(d_membership_exp_db, tc_rep, exp);
+      sendLemma( MEMBER(pair, tc_rep), exp, "Transitivity" );
+    }
+
+    if(!first_round) {
+      std::hash_set< Node, NodeHashFunction >::iterator ele_it = elements.begin();
+      while(ele_it != elements.end()) {
+        if(areEqual(cur_node, *ele_it)) {
+          return;
+        }
+        ele_it++;
+      }
+    }
+    std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin();
+    while(pair_set_it != tc_graph.end()) {
+      if(areEqual(pair_set_it->first, cur_node)) {
+        break;
+      }
+      pair_set_it++;
+    }
+    if(pair_set_it != tc_graph.end()) {
+      for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
+          set_it != pair_set_it->second.end(); set_it++) {
+        Node p = constructPair( tc_rep, cur_node, *set_it );
+        Node reason = AND( findMemExp(tc_rep, p), exp );
+        Assert(!reason.isNull());
+        elements.insert(*set_it);
+        inferTC( reason, tc_rep, tc_graph, start_node, *set_it, elements, false );
+      }
+    }
+  }
+
+  void TheorySetsRels::inferTC(Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph) {
+    Trace("rels-debug") << "[sets-rels] Build TC graph for tc_rep = " << tc_rep << std::endl;
+    for(std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin();
+        pair_set_it != tc_graph.end(); pair_set_it++) {
+      for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
+          set_it != pair_set_it->second.end(); set_it++) {
+        std::hash_set<Node, NodeHashFunction> elements;
+        Node pair = constructPair(tc_rep, pair_set_it->first, *set_it);
+        Node exp = findMemExp(tc_rep, pair);
+        Trace("rels-debug") << "[sets-rels] pair = " << pair << std::endl;
+        if(d_membership_tc_exp_cache.find(tc_rep) != d_membership_tc_exp_cache.end()) {
+          exp = AND(d_membership_tc_exp_cache[tc_rep], exp);
+        }
+        Assert(!exp.isNull());
+        elements.insert(pair_set_it->first);
+        elements.insert(*set_it);
+        inferTC( exp, tc_rep, tc_graph, pair_set_it->first, *set_it, elements, true );
+      }
+    }
+  }
+
+  void TheorySetsRels::finalizeTCInfer() {
+    Trace("rels-debug") << "[sets-rels] Finalizing transitive closure inferences!" << std::endl;
+    for(TC_IT tc_it = d_membership_tc_cache.begin(); tc_it != d_membership_tc_cache.end(); tc_it++) {
+      inferTC(tc_it->first, tc_it->second);
+    }
+  }
+
   // Bottom-up fashion to compute relations
   void TheorySetsRels::computeRels(Node n) {
     Trace("rels-debug") << "\n[sets-rels] computeJoinOrProductRelations for relation  " << n << std::endl;
@@ -417,7 +623,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     if(d_membership_db.find(getRepresentative(n[0])) == d_membership_db.end() ||
        d_membership_db.find(getRepresentative(n[1])) == d_membership_db.end())
           return;
-    composeTuplesForRels(n);
+    composeTupleMemForRels(n);
   }
 
   void TheorySetsRels::computeTransposeRelations(Node n) {
@@ -459,7 +665,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
    * e.g. If (a, b) in X and (b, c) in Y, (a, c) in (X JOIN Y)
    *
    */
-  void TheorySetsRels::composeTuplesForRels( Node n ) {
+  void TheorySetsRels::composeTupleMemForRels( Node n ) {
     Node r1 = n[0];
     Node r2 = n[1];
     Node r1_rep = getRepresentative(r1);
@@ -489,8 +695,8 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
       for(unsigned int j = 0; j < r2_elements.size(); j++) {
         std::vector<Node> composed_tuple;
         TypeNode tn = n.getType().getSetElementType();
-        Node r2_lmost = selectElement(r2_elements[j], 0);
-        Node r1_rmost = selectElement(r1_elements[i], t1_len-1);
+        Node r2_lmost = nthElementOfTuple(r2_elements[j], 0);
+        Node r1_rmost = nthElementOfTuple(r1_elements[i], t1_len-1);
         composed_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor()));
 
         if((areEqual(r1_rmost, r2_lmost) && n.getKind() == kind::JOIN) ||
@@ -499,14 +705,14 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
           unsigned int k = 0;
           unsigned int l = 1;
           for(; k < t1_len - 1; ++k) {
-            composed_tuple.push_back(selectElement(r1_elements[i], k));
+            composed_tuple.push_back(nthElementOfTuple(r1_elements[i], k));
           }
           if(isProduct) {
-            composed_tuple.push_back(selectElement(r1_elements[i], k));
-            composed_tuple.push_back(selectElement(r2_elements[j], 0));
+            composed_tuple.push_back(nthElementOfTuple(r1_elements[i], k));
+            composed_tuple.push_back(nthElementOfTuple(r2_elements[j], 0));
           }
           for(; l < t2_len; ++l) {
-            composed_tuple.push_back(selectElement(r2_elements[j], l));
+            composed_tuple.push_back(nthElementOfTuple(r2_elements[j], l));
           }
           Node composed_tuple_rep = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, composed_tuple));
           Node fact = MEMBER(composed_tuple_rep, new_rel_rep);
@@ -554,7 +760,8 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
         }
         Trace("rels-lemma") << "[sets-rels-lemma] Process pending lemma : "
                             << d_lemma_cache[i] << std::endl;
-        d_sets.d_out->lemma( d_lemma_cache[i] );
+        d_sets_theory.d_out->lemma( d_lemma_cache[i] );
+//        d_sets_theory.d_out->conflict()
       }
       for( std::map<Node, Node>::iterator child_it = d_pending_facts.begin();
            child_it != d_pending_facts.end(); child_it++ ) {
@@ -565,11 +772,13 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
         }
         Trace("rels-lemma") << "[sets-rels-fact-lemma] Process pending fact as lemma : "
                             << child_it->first << " with reason " << child_it->second << std::endl;
-        d_sets.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, child_it->second, child_it->first));
+        d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, child_it->second, child_it->first));
       }
     }
     d_pending_facts.clear();
-    d_membership_cache.clear();
+    d_membership_constraints_cache.clear();
+    d_membership_tc_cache.clear();
+    d_membership_tc_exp_cache.clear();
     d_membership_exp_cache.clear();
     d_membership_db.clear();
     d_membership_exp_db.clear();
@@ -621,7 +830,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
       map_it++;
     }
     d_pending_facts.clear();
-    d_membership_cache.clear();
+    d_membership_constraints_cache.clear();
     d_membership_db.clear();
     d_membership_exp_cache.clear();
     d_terms_cache.clear();
@@ -659,7 +868,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     Datatype dt = tn.getDatatype();
     elements.push_back( Node::fromExpr(dt[0].getConstructor() ) );
     for(int i = tuple_types.size() - 1; i >= 0; --i) {
-      elements.push_back( selectElement(tuple, i) );
+      elements.push_back( nthElementOfTuple(tuple, i) );
     }
     return NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, elements );
   }
@@ -681,49 +890,22 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
   }
 
   bool TheorySetsRels::areEqual( Node a, Node b ){
+    Trace("rels-debug") << "[sets-rels] areEqual( a = " << a << ", b = " << b << ")" << std::endl;
     if(a == b) {
       return true;
-    } else if(a.isConst() && b.isConst()) {
-      return a == b;
     } else if( hasTerm( a ) && hasTerm( b ) ){
-//      if( d_eqEngine->isTriggerTerm(a, THEORY_SETS) &&
-//          d_eqEngine->isTriggerTerm(b, THEORY_SETS) ) {
-//        // Get representative trigger terms
-//          TNode x_shared = d_eqEngine->getTriggerTermRepresentative(a, THEORY_SETS);
-//          TNode y_shared = d_eqEngine->getTriggerTermRepresentative(b, THEORY_SETS);
-//          EqualityStatus eqStatusDomain = d_sets.d_valuation.getEqualityStatus(x_shared, y_shared);
-//          switch (eqStatusDomain) {
-//            case EQUALITY_TRUE_AND_PROPAGATED:
-//              // Should have been propagated to us
-//              Trace("rels-debug") << "EQUALITY_TRUE_AND_PROPAGATED ****  equality( a, b ) = true" << std::endl;
-//              return true;
-//              break;
-//            case EQUALITY_TRUE:
-//              // Missed propagation - need to add the pair so that theory engine can force propagation
-//              Trace("rels-debug") << "EQUALITY_TRUE **** equality( a, b ) = true" << std::endl;
-//              return true;
-//              break;
-//            case EQUALITY_FALSE_AND_PROPAGATED:
-//              // Should have been propagated to us
-//              Trace("rels-debug") << "EQUALITY_FALSE_AND_PROPAGATED ******** equality( a, b ) = false" << std::endl;
-//              return false;
-//              break;
-//            case EQUALITY_FALSE:
-//              Trace("rels-debug") << "EQUALITY_FALSE **** equality( a, b ) = false" << std::endl;
-//              return false;
-//              break;
-//
-//            default:
-//              // Covers EQUALITY_TRUE_IN_MODEL (common case) and EQUALITY_UNKNOWN
-//              break;
-//          }
-//      }
       return d_eqEngine->areEqual( a, b );
-    } else {
+    } else if(a.getType().isTuple()) {
+      bool equal = true;
+      for(unsigned int i = 0; i < a.getType().getTupleLength(); i++) {
+        equal = equal && areEqual(nthElementOfTuple(a, i), nthElementOfTuple(b, i));
+      }
+      return equal;
+    } else if(!a.getType().isBoolean()){
       makeSharedTerm(a);
       makeSharedTerm(b);
-      return false;
     }
+    return false;
   }
 
   bool TheorySetsRels::checkCycles(Node join_term) {
@@ -753,7 +935,56 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     }
   }
 
-  inline Node TheorySetsRels::selectElement( Node tuple, int n_th ) {
+  inline Node TheorySetsRels::getReason(Node tc_rep, Node tc_term, Node tc_r_rep, Node tc_r) {
+    if(tc_term != tc_rep) {
+      Node reason = explain(EQUAL(tc_term, tc_rep));
+      if(tc_term[0] != tc_r_rep) {
+        return AND(reason, explain(EQUAL(tc_term[0], tc_r_rep)));
+      }
+    }
+    return Node::null();
+  }
+
+  // tuple might be a member of tc_rep; or it might be a member of tc_terms
+  Node TheorySetsRels::findMemExp(Node tc_rep, Node tuple) {
+    Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_rep = " << tc_rep << ", tuple = " << tuple << ")" << std::endl;
+    std::vector<Node> tc_terms = d_terms_cache.find(tc_rep)->second[kind::TRANSCLOSURE];
+    Assert(tc_terms.size() > 0);
+    for(unsigned int i = 0; i < tc_terms.size(); i++) {
+      Node r_rep = getRepresentative(tc_terms[i][0]);
+      Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << r_rep << ", tuple = " << tuple << ")" << std::endl;
+      std::map< Node, std::vector< Node > >::iterator tc_r_mems = d_membership_db.find(r_rep);
+      if(tc_r_mems != d_membership_db.end()) {
+        for(unsigned int i = 0; i < tc_r_mems->second.size(); i++) {
+          if(areEqual(tc_r_mems->second[i], tuple)) {
+            return explain(d_membership_exp_db[r_rep][i]);
+          }
+        }
+      }
+
+      Node tc_term_rep = getRepresentative(tc_terms[i]);
+      std::map< Node, std::vector< Node > >::iterator tc_t_mems = d_membership_db.find(tc_term_rep);
+      Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_t_rep = " << tc_term_rep << ", tuple = " << tuple << ")" << std::endl;
+      if(tc_t_mems != d_membership_db.end()) {
+        for(unsigned int i = 0; i < tc_t_mems->second.size(); i++) {
+          if(areEqual(tc_t_mems->second[i], tuple)) {
+            return explain(d_membership_exp_db[tc_term_rep][i]);
+          }
+        }
+      }
+    }
+//    std::map< Node, std::vector< Node > >::iterator tc_mems = d_membership_db.find(tc_rep);
+//    if(tc_mems != d_membership_db.end()) {
+//      for(unsigned int i = 0; i < tc_mems->second.size(); i++) {
+//        if(tc_mems->second[i] == tuple) {
+//          return explain(d_membership_exp_db[tc_rep][i]);
+//        }
+//      }
+//    }
+    return Node::null();
+  }
+
+  inline Node TheorySetsRels::nthElementOfTuple( Node tuple, int n_th ) {
     if(tuple.isConst() || (!tuple.isVar() && !tuple.isConst()))
       return tuple[n_th];
     Datatype dt = tuple.getType().getDatatype();
@@ -762,11 +993,11 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
 
   void TheorySetsRels::addSharedTerm( TNode n ) {
     Trace("rels-debug") << "[sets-rels] Add a shared term:  " << n << std::endl;
-    d_sets.addSharedTerm(n);
+    d_sets_theory.addSharedTerm(n);
     d_eqEngine->addTriggerTerm(n, THEORY_SETS);
   }
 
-  bool TheorySetsRels::hasTuple( Node rel_rep, Node tuple ){
+  bool TheorySetsRels::hasMember( Node rel_rep, Node tuple ){
     if(d_membership_db.find(rel_rep) == d_membership_db.end())
       return false;
     return std::find(d_membership_db[rel_rep].begin(),
@@ -774,7 +1005,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
   }
 
   void TheorySetsRels::makeSharedTerm( Node n ) {
-    if(d_shared_terms.find(n) == d_shared_terms.end()) {
+    if(d_shared_terms.find(n) == d_shared_terms.end() && !n.getType().isBoolean()) {
       Node skolem = NodeManager::currentNM()->mkSkolem( "sde", n.getType() );
       sendLemma(MEMBER(skolem, SINGLETON(n)), d_trueNode, "share-term");
       d_shared_terms.insert(n);
@@ -788,8 +1019,9 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     if(d_eqEngine->hasTerm(node)) {
       return areEqual(node, polarity_atom);
     } else {
-      Node atom_mod = NodeManager::currentNM()->mkNode(atom.getKind(), getRepresentative(atom[0]),
-                                                       getRepresentative(atom[1]) );
+      Node atom_mod = NodeManager::currentNM()->mkNode(atom.getKind(),
+                                                       getRepresentative(atom[0]),
+                                                       getRepresentative(atom[1]));
       if(d_eqEngine->hasTerm(atom_mod)) {
         return areEqual(node, polarity_atom);
       }
@@ -800,7 +1032,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
   void TheorySetsRels::computeTupleReps( Node n ) {
     if( d_tuple_reps.find( n ) == d_tuple_reps.end() ){
       for( unsigned i = 0; i < n.getType().getTupleLength(); i++ ){
-        d_tuple_reps[n].push_back( getRepresentative( selectElement(n, i) ) );
+        d_tuple_reps[n].push_back( getRepresentative( nthElementOfTuple(n, i) ) );
       }
     }
   }
@@ -812,13 +1044,18 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     d_membership_trie[rel].addTerm(member, d_tuple_reps[member]);
   }
 
+  inline Node TheorySetsRels::constructPair(Node tc_rep, Node a, Node b) {
+    Datatype dt = tc_rep.getType().getSetElementType().getDatatype();
+    return NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, Node::fromExpr(dt[0].getConstructor()), a, b);
+  }
+
   void TheorySetsRels::reduceTupleVar(Node n) {
     if(d_symbolic_tuples.find(n) == d_symbolic_tuples.end()) {
       Trace("rels-debug") << "Reduce tuple var: " << n[0] << " to concrete one " << std::endl;
       std::vector<Node> tuple_elements;
       tuple_elements.push_back(Node::fromExpr((n[0].getType().getDatatype())[0].getConstructor()));
       for(unsigned int i = 0; i < n[0].getType().getTupleLength(); i++) {
-        Node element = selectElement(n[0], i);
+        Node element = nthElementOfTuple(n[0], i);
         makeSharedTerm(element);
         tuple_elements.push_back(element);
       }
@@ -835,9 +1072,11 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
                                  eq::EqualityEngine* eq,
                                  context::CDO<bool>* conflict,
                                  TheorySets& d_set):
-    d_sets(d_set),
+    d_c(c),
+    d_sets_theory(d_set),
     d_trueNode(NodeManager::currentNM()->mkConst<bool>(true)),
     d_falseNode(NodeManager::currentNM()->mkConst<bool>(false)),
+    d_pending_merge(c),
     d_infer(c),
     d_infer_exp(c),
     d_lemma(u),
@@ -914,6 +1153,238 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     }
   }
 
+  Node TheorySetsRels::explain(Node literal)
+  {
+    Trace("rels-debug") << "[sets-rels] TheorySetsRels::explain(" << literal << ")"<< std::endl;
+
+    bool polarity = literal.getKind() != kind::NOT;
+    TNode atom = polarity ? literal : literal[0];
+    std::vector<TNode> assumptions;
+
+    if(atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) {
+      d_eqEngine->explainEquality(atom[0], atom[1], polarity, assumptions);
+    } else if(atom.getKind() == kind::MEMBER) {
+      if( !d_eqEngine->hasTerm(atom)) {
+        d_eqEngine->addTerm(atom);
+      }
+      d_eqEngine->explainPredicate(atom, polarity, assumptions);
+    } else {
+      Trace("rels-debug") << "unhandled: " << literal << "; (" << atom << ", "
+                    << polarity << "); kind" << atom.getKind() << std::endl;
+      Unhandled();
+    }
+    Trace("rels-debug") << "[sets-rels] ****** done with TheorySetsRels::explain(" << literal << ")"<< std::endl;
+    return mkAnd(assumptions);
+  }
+
+  TheorySetsRels::EqcInfo::EqcInfo( context::Context* c ) :
+  d_mem(c), d_not_mem(c), d_tp(c) {}
+
+  void TheorySetsRels::eqNotifyNewClass( Node n ) {
+    Trace("rels-std") << "[sets-rels] eqNotifyNewClass:" << " t = " << n << std::endl;
+    if(isRel(n) && n.getKind() == kind::TRANSPOSE) {
+      getOrMakeEqcInfo( n, true );
+    }
+  }
+
+  void TheorySetsRels::eqNotifyPostMerge( Node t1, Node t2 ) {
+    Trace("rels-std") << "[sets-rels] eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl;
+
+    // Merge membership constraint with "true" or "false" eqc
+    // Todo: t1 might not be "true" or "false" rep
+    if((t1 == d_trueNode || t1 == d_falseNode) &&
+        t2.getKind() == kind::MEMBER &&
+        t2[0].getType().isTuple()) {
+
+      Assert(t1 == d_trueNode || t1 == d_falseNode);
+      bool polarity = t1 == d_trueNode;
+      Node t2_1rep = getRepresentative(t2[1]);
+      EqcInfo* ei = getOrMakeEqcInfo( t2_1rep );
+
+      if(ei == NULL) {
+        ei = getOrMakeEqcInfo( t2_1rep, true );
+      }
+      // might not need to store the membership info
+      // if we don't need to consider the eqc merge?
+      if(polarity) {
+        ei->d_mem.insert(t2[0]);
+      } else {
+        ei->d_not_mem.insert(t2[0]);
+      }
+      if(!ei->d_tp.get().isNull()) {
+        Node exp = polarity ? explain(t2) : explain(t2.negate());
+        if(ei->d_tp.get() != t2[1])
+          exp = AND( explain(EQUAL( ei->d_tp.get(), t2[1]) ), exp );
+        sendInferTranspose( polarity, t2[0], ei->d_tp.get(), exp, true );
+      }
+    // Merge two relation eqcs
+    } else if(t1.getType().isSet() &&
+              t2.getType().isSet() &&
+              t1.getType().getSetElementType().isTuple()) {
+
+      EqcInfo* t1_ei = getOrMakeEqcInfo(t1);
+      EqcInfo* t2_ei = getOrMakeEqcInfo(t2);
+      if(t1_ei != NULL && t2_ei != NULL) {
+        // TP(t1) = TP(t2) -> t1 = t2;
+        if(!t1_ei->d_tp.get().isNull() && !t2_ei->d_tp.get().isNull()) {
+          sendInferTranspose( true, t1_ei->d_tp.get(), t2_ei->d_tp.get(), explain(EQUAL(t1, t2)) );
+        }
+        // Apply transpose rule on (non)members of t2 and t1->tp
+        if(!t1_ei->d_tp.get().isNull()) {
+          for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) {
+            if(!t1_ei->d_mem.contains(*itr)) {
+              sendInferTranspose( true, *itr, t1_ei->d_tp.get(), AND(explain(EQUAL(t1_ei->d_tp.get(), t2)), explain(MEMBER(*itr, t2))) );
+            }
+          }
+          for(NodeSet::key_iterator itr = t2_ei->d_not_mem.key_begin(); itr != t2_ei->d_not_mem.key_end(); itr++) {
+            if(!t1_ei->d_not_mem.contains(*itr)) {
+              sendInferTranspose( false, *itr, t1_ei->d_tp.get(), AND(explain(EQUAL(t1_ei->d_tp.get(), t2)), explain(MEMBER(*itr, t2).negate())) );
+            }
+          }
+          // Apply transpose rule on (non)members of t1 and t2->tp
+        } else if(!t2_ei->d_tp.get().isNull()) {
+          t1_ei->d_tp.set(t2_ei->d_tp);
+          for(NodeSet::key_iterator itr = t1_ei->d_mem.key_begin(); itr != t1_ei->d_mem.key_end(); itr++) {
+            if(!t2_ei->d_mem.contains(*itr)) {
+              sendInferTranspose( true, *itr, t2_ei->d_tp.get(), AND(explain(EQUAL(t1, t2_ei->d_tp.get())), explain(MEMBER(*itr, t1))) );
+            }
+          }
+          for(NodeSet::key_iterator itr = t1_ei->d_not_mem.key_begin(); itr != t1_ei->d_not_mem.key_end(); itr++) {
+            if(!t2_ei->d_not_mem.contains(*itr)) {
+              sendInferTranspose( false, *itr, t2_ei->d_tp.get(), AND(explain(EQUAL(t1, t2_ei->d_tp.get())), explain(MEMBER(*itr, t1).negate())) );
+            }
+          }
+        }
+      // t1 was created already and t2 was not
+      } else if(t1_ei != NULL) {
+        if(t1_ei->d_tp.get().isNull() && t2.getKind() == kind::TRANSPOSE) {
+          t1_ei->d_tp.set( t2 );
+        }
+      } else if(t2_ei != NULL){
+        t1_ei = getOrMakeEqcInfo(t1, true);
+        for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) {
+          t1_ei->d_mem.insert(*itr);
+        }
+        for(NodeSet::key_iterator itr = t2_ei->d_not_mem.key_begin(); itr != t2_ei->d_not_mem.key_end(); itr++) {
+          t1_ei->d_not_mem.insert(*itr);
+        }
+        if(t1_ei->d_tp.get().isNull() && !t2_ei->d_tp.get().isNull()) {
+          t1_ei->d_tp.set(t2_ei->d_tp);
+        }
+      }
+    }
+
+    Trace("rels-std") << "[sets-rels] done with eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl;
+  }
+
+  void TheorySetsRels::doPendingMerge() {
+    for(NodeList::const_iterator itr = d_pending_merge.begin(); itr != d_pending_merge.end(); itr++) {
+      Trace("rels-std") << "[sets-rels-lemma] Process pending merge fact : "
+                        << *itr << std::endl;
+      d_sets_theory.d_out->lemma(*itr);
+    }
+  }
+
+  void TheorySetsRels::sendInferTranspose( bool polarity, Node t1, Node t2, Node exp, bool reverseOnly ) {
+    Assert(t2.getKind() == kind::TRANSPOSE);
+    if(polarity && isRel(t1) && isRel(t2)) {
+      Assert(t1.getKind() == kind::TRANSPOSE);
+      Node n = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, EQUAL(t1[0], t2[0]) );
+      Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying transpose rule: "
+                        << n << std::endl;
+      d_pending_merge.push_back(n);
+      d_lemma.insert(n);
+      return;
+    }
+
+    Node n1;
+    if(reverseOnly) {
+      if(polarity) {
+        n1 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(reverseTuple(t1), t2[0]) );
+      } else {
+        n1 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(reverseTuple(t1), t2[0]).negate() );
+      }
+    } else {
+      Node n2;
+      if(polarity) {
+        n1 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(t1, t2) );
+        n2 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(reverseTuple(t1), t2[0]) );
+      } else {
+        n1 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(t1, t2).negate() );
+        n2 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(reverseTuple(t1), t2[0]).negate() );
+      }
+      Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying transpose rule: "
+                        << n2 << std::endl;
+      d_pending_merge.push_back(n2);
+      d_lemma.insert(n2);
+    }
+    Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying transpose rule: "
+                      << n1 << std::endl;
+    d_pending_merge.push_back(n1);
+    d_lemma.insert(n1);
+
+  }
+
+  TheorySetsRels::EqcInfo* TheorySetsRels::getOrMakeEqcInfo( Node n, bool doMake ){
+    std::map< Node, EqcInfo* >::iterator eqc_i = d_eqc_info.find( n );
+    if(eqc_i == d_eqc_info.end()){
+      if( doMake ){
+        EqcInfo* ei;
+        if(eqc_i!=d_eqc_info.end()){
+          ei = eqc_i->second;
+        }else{
+          ei = new EqcInfo(d_sets_theory.getSatContext());
+          d_eqc_info[n] = ei;
+        }
+        if(n.getKind() == kind::TRANSPOSE){
+          ei->d_tp = n;
+        }
+        return ei;
+      }else{
+        return NULL;
+      }
+    }else{
+      return (*eqc_i).second;
+    }
+  }
+
+
+  Node TheorySetsRels::mkAnd( std::vector<TNode>& conjunctions ) {
+    Assert(conjunctions.size() > 0);
+    std::set<TNode> all;
+
+    for (unsigned i = 0; i < conjunctions.size(); ++i) {
+      TNode t = conjunctions[i];
+      if (t.getKind() == kind::AND) {
+        for(TNode::iterator child_it = t.begin();
+            child_it != t.end(); ++child_it) {
+          Assert((*child_it).getKind() != kind::AND);
+          all.insert(*child_it);
+        }
+      }
+      else {
+        all.insert(t);
+      }
+    }
+
+    Assert(all.size() > 0);
+
+    if (all.size() == 1) {
+      // All the same, or just one
+      return conjunctions[0];
+    }
+
+    NodeBuilder<> conjunction(kind::AND);
+    std::set<TNode>::const_iterator it = all.begin();
+    std::set<TNode>::const_iterator it_end = all.end();
+    while (it != it_end) {
+      conjunction << *it;
+      ++ it;
+    }
+
+    return conjunction;
+  }/* mkAnd() */
+
 }
 }
 }
index 500c1db5bb1ad57e99a5a809b868d357bac44ccd..d0f5e8cbd74729637917a13b4cd49f2d7430479b 100644 (file)
@@ -36,7 +36,6 @@ public:
 public:
   Node existsTerm( std::vector< Node >& reps, int argIndex = 0 );
   std::vector<Node> findTerms( std::vector< Node >& reps, int argIndex = 0 );
-//  void findTerms( std::vector< Node >& reps, std::vector< Node >& elements, int argIndex = 0 );
   bool addTerm( Node n, std::vector< Node >& reps, int argIndex = 0 );
   void debugPrint( const char * c, Node n, unsigned depth = 0 );
   void clear() { d_data.clear(); }
@@ -56,12 +55,32 @@ public:
 
   ~TheorySetsRels();
   void check(Theory::Effort);
-
   void doPendingLemmas();
+  context::Context * d_c;
+
+private:
+  /** equivalence class info
+   * d_mem tuples that are members of this equivalence class
+   * d_not_mem tuples that are not members of this equivalence class
+   * d_tp is a node of kind TRANSPOSE (if any) in this equivalence class,
+   */
+  class EqcInfo
+  {
+  public:
+    EqcInfo( context::Context* c );
+    ~EqcInfo(){}
+    NodeSet d_mem;
+    NodeSet d_not_mem;
+    context::CDO< Node > d_tp;
+  };
+
+  /** has eqc info */
+  bool hasEqcInfo( TNode n ) { return d_eqc_info.find( n )!=d_eqc_info.end(); }
+
 
 private:
 
-  TheorySets& d_sets;
+  TheorySets& d_sets_theory;
 
   /** True and false constant nodes */
   Node d_trueNode;
@@ -72,6 +91,8 @@ private:
   std::map< Node, Node > d_pending_split_facts;
   std::vector< Node > d_lemma_cache;
 
+  NodeList d_pending_merge;
+
   /** inferences: maintained to ensure ref count for internally introduced nodes */
   NodeList d_infer;
   NodeList d_infer_exp;
@@ -81,25 +102,48 @@ private:
   std::map< Node, std::vector<Node> > d_tuple_reps;
   std::map< Node, TupleTrie > d_membership_trie;
   std::hash_set< Node, NodeHashFunction > d_symbolic_tuples;
-  std::map< Node, std::vector<Node> > d_membership_cache;
-  std::map< Node, std::vector<Node> > d_membership_db;
-  std::map< Node, std::vector<Node> > d_membership_exp_db;
+  std::map< Node, std::vector<Node> > d_membership_constraints_cache;
   std::map< Node, std::vector<Node> > d_membership_exp_cache;
   std::map< Node, std::map<kind::Kind_t, std::vector<Node> > > d_terms_cache;
+  std::map< Node, std::vector<Node> > d_membership_db;
+  std::map< Node, std::vector<Node> > d_membership_exp_db;
+  std::map< Node, std::map< Node, std::hash_set<Node, NodeHashFunction> > > d_membership_tc_cache;
+  std::map< Node, Node > d_membership_tc_exp_cache;
 
   eq::EqualityEngine *d_eqEngine;
   context::CDO<bool> *d_conflict;
 
+  /** information necessary for equivalence classes */
+public:
+  void eqNotifyNewClass(Node t);
+  void eqNotifyPostMerge(Node t1, Node t2);
+
+private:
+
+  std::map< Node, EqcInfo* > d_eqc_info;
+  void doPendingMerge();
+  EqcInfo* getOrMakeEqcInfo( Node n, bool doMake = false );
+  void sendInferTranspose(bool, Node, Node, Node, bool reverseOnly = false);
+
+
   void check();
   void collectRelsInfo();
   void assertMembership( Node fact, Node reason, bool polarity );
-  void composeTuplesForRels( Node );
+  void composeTupleMemForRels( Node );
   void applyTransposeRule( Node, Node, bool tp_occur_rule = false );
   void applyJoinRule( Node, Node );
   void applyProductRule( Node, Node );
+  void applyTCRule( Node, Node );
+  void buildTCGraph( Node, Node, Node );
   void computeRels( Node );
   void computeTransposeRelations( Node );
   Node reverseTuple( Node );
+  void finalizeTCInfer();
+  void inferTC( Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >& );
+  void inferTC( Node, Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >&,
+                Node, Node, std::hash_set< Node, NodeHashFunction >&, bool first_round = false);
+
+  Node explain(Node);
 
   void sendInfer( Node fact, Node exp, const char * c );
   void sendLemma( Node fact, Node reason, const char * c );
@@ -107,13 +151,17 @@ private:
   void doPendingFacts();
   void doPendingSplitFacts();
   void addSharedTerm( TNode n );
+  void checkTCGraphForConflict( Node, Node, Node, Node, Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >& );
   bool checkCycles( Node );
 
   // Helper functions
-  inline Node selectElement( Node, int);
+  inline Node nthElementOfTuple( Node, int);
+  inline Node getReason(Node tc_rep, Node tc_term, Node tc_r_rep, Node tc_r);
+  inline Node constructPair(Node tc_rep, Node a, Node b);
+  Node findMemExp(Node r, Node tuple);
   bool safeAddToMap( std::map< Node, std::vector<Node> >&, Node, Node );
   void addToMap( std::map< Node, std::vector<Node> >&, Node, Node );
-  bool hasTuple( Node, Node );
+  bool hasMember( Node, Node );
   Node getRepresentative( Node t );
   bool hasTerm( Node a );
   bool areEqual( Node a, Node b );
@@ -123,6 +171,8 @@ private:
   void makeSharedTerm( Node );
   void reduceTupleVar( Node );
   inline void addToMembershipDB( Node, Node, Node  );
+  bool isRel( Node n ) {return n.getType().isSet() && n.getType().getSetElementType().isTuple();}
+  Node mkAnd( std::vector< TNode >& assumptions );
 
 };
 
index 5c59d96ce1cf2f7bda94c779cc1a55de76d0a32e..8d76748bb6077497989f07956f965ea92e14d506 100644 (file)
@@ -208,8 +208,6 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
   }//kind::UNION
 
   case kind::TRANSPOSE: {
-    if(node[0].getType().isSet() && !node[0].getType().getSetElementType().isTuple())
-      return RewriteResponse(REWRITE_AGAIN, node[0]);
     if(node[0].getKind() != kind::TRANSPOSE) {
       Trace("sets-postrewrite") << "Sets::postRewrite returning " << node << std::endl;
       return RewriteResponse(REWRITE_DONE, node);
@@ -220,6 +218,17 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
     break;
   }
 
+  case kind::TRANSCLOSURE: {
+    if(node[0].getKind() != kind::TRANSCLOSURE) {
+      Trace("sets-postrewrite") << "Sets::postRewrite returning " << node << std::endl;
+      return RewriteResponse(REWRITE_DONE, node);
+    }
+    if(node[0].getKind() == kind::TRANSCLOSURE) {
+      return RewriteResponse(REWRITE_AGAIN, node[0]);
+    }
+    break;
+  }
+
   default:
     break;
   }//switch(node.getKind())