implemented TC for standard effort
authorPaulMeng <baolmeng@gmail.com>
Wed, 4 May 2016 15:21:18 +0000 (10:21 -0500)
committerPaulMeng <baolmeng@gmail.com>
Wed, 4 May 2016 15:21:18 +0000 (10:21 -0500)
src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_rels.cpp
src/theory/sets/theory_sets_rels.h

index bc9227e5474e52006ea7bc9c71c6b8448eb29e82..aec2c119cb21a9a5702096d98587bd55a4a9960b 100644 (file)
@@ -696,6 +696,11 @@ const TheorySetsPrivate::Elements& TheorySetsPrivate::getElements
                           std::inserter(cur, cur.begin()) );
       break;
     }
+    case kind::JOIN: 
+    case kind::TCLOSURE:
+    case kind::TRANSPOSE:
+    case kind::PRODUCT:
+      break;
     default:
       Assert(theory::kindToTheoryId(k) != theory::THEORY_SETS,
              (std::string("Kind belonging to set theory not explicitly handled: ") + kindToString(k)).c_str());
index 75e3d483183593cbd0c3fc7296a14ac91a2d935c..428027acc7447eabd9e55e3b5d8c90a55d02ced5 100644 (file)
@@ -230,6 +230,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
                         << tc_term << " with explanation " << exp << std::endl;
     bool polarity = exp.getKind() != kind::NOT;
     Node atom = polarity ? exp : exp[0];
+    Node tup_rep = getRepresentative(atom[0]);
     Node tc_rep = getRepresentative(tc_term);
     Node tc_r_rep = getRepresentative(tc_term[0]);
 
@@ -239,17 +240,17 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       buildTCGraph(tc_r_rep, tc_rep, tc_term);
       d_rel_nodes.insert(tc_rep);
     }
-    // insert atom[0] in the tc_graph if it is not in the graph already
+    // insert tup_rep in the tc_graph if it is not in the graph already
     TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep);
     if(polarity) {
       if(tc_graph_it != d_membership_tc_cache.end()) {
-        TC_PAIR_IT pair_set_it = tc_graph_it->second.find(RelsUtils::nthElementOfTuple(atom[0], 0));
+        TC_PAIR_IT pair_set_it = tc_graph_it->second.find(RelsUtils::nthElementOfTuple(tup_rep, 0));
         if(pair_set_it != tc_graph_it->second.end()) {
-          pair_set_it->second.insert(RelsUtils::nthElementOfTuple(atom[0], 1));
+          pair_set_it->second.insert(RelsUtils::nthElementOfTuple(tup_rep, 1));
         } else {
           std::hash_set< Node, NodeHashFunction > pair_set;
-          pair_set.insert(RelsUtils::nthElementOfTuple(atom[0], 1));
-          tc_graph_it->second[RelsUtils::nthElementOfTuple(atom[0], 0)] = pair_set;
+          pair_set.insert(RelsUtils::nthElementOfTuple(tup_rep, 1));
+          tc_graph_it->second[RelsUtils::nthElementOfTuple(tup_rep, 0)] = pair_set;
         }
         Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
         std::map< Node, Node >::iterator exp_it = d_membership_tc_exp_cache.find(tc_rep);
@@ -259,19 +260,49 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       } else {
         std::map< Node, std::hash_set< Node, NodeHashFunction > > pair_set;
         std::hash_set< Node, NodeHashFunction > set;
-        set.insert(RelsUtils::nthElementOfTuple(atom[0], 1));
-        pair_set[RelsUtils::nthElementOfTuple(atom[0], 0)] = set;
+        set.insert(RelsUtils::nthElementOfTuple(tup_rep, 1));
+        pair_set[RelsUtils::nthElementOfTuple(tup_rep, 0)] = set;
         d_membership_tc_cache[tc_rep] = pair_set;
         Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
         if(!reason.isNull()) {
           d_membership_tc_exp_cache[tc_rep] = reason;
         }
       }
-    // check if atom[0] already exists in TC graph for conflict
+      // if(!d_tc_saver.contains(exp) && 
+      //    atom[0][0].getKind() != kind::SKOLEM && 
+      //    atom[0][1].getKind() != kind::SKOLEM) {
+           
+      //   TypeNode k_type = tup_rep.getType().getTupleTypes()[1];
+      //   Node k_0 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type);
+      //   Node k_1 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type);
+      //   Node k_2 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type);
+      //   Node k_3 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type);
+      //   Node fact = NodeManager::currentNM()->mkNode( kind::AND, MEMBER(RelsUtils::constructPair(tc_rep, tup_rep[0], k_0), tc_r_rep),
+      //                                                           MEMBER(RelsUtils::constructPair(tc_rep, k_0, k_1), tc_r_rep),
+      //                                                           MEMBER(RelsUtils::constructPair(tc_rep, k_1, k_2), tc_r_rep),
+      //                                                           MEMBER(RelsUtils::constructPair(tc_rep, k_2, k_3), tc_r_rep),
+      //                                                           MEMBER(RelsUtils::constructPair(tc_rep, k_3, tup_rep[1]), tc_r_rep) );
+      //   Node reason = exp;
+      //   if(tc_rep != tc_term) {
+      //     reason = AND(reason, explain(EQUAL(tc_rep, tc_term)));
+      //   }
+      //   if(tc_r_rep != tc_term[0]) {
+      //     reason = AND(reason, explain(EQUAL(tc_r_rep, tc_term[0])));
+      //   }                                                           
+        
+      //   makeSharedTerm(k_0);
+      //   makeSharedTerm(k_1);
+      //   makeSharedTerm(k_2);
+      //   makeSharedTerm(k_3);
+      //   sendLemma( fact, reason, "tc-decompose" );
+      //   d_tc_saver.insert(exp);  
+      // }     
+    // check if tup_rep already exists in TC graph for conflict
     } else {
       if(tc_graph_it != d_membership_tc_cache.end()) {
-        checkTCGraphForConflict(atom, tc_rep, d_trueNode, RelsUtils::nthElementOfTuple(atom[0], 0),
-                                RelsUtils::nthElementOfTuple(atom[0], 1), tc_graph_it->second);
+        Trace("rels-debug") << "********** tc reach here 0" << std::endl;
+        checkTCGraphForConflict(atom, tc_rep, d_trueNode, RelsUtils::nthElementOfTuple(tup_rep, 0),
+                                RelsUtils::nthElementOfTuple(tup_rep, 1), tc_graph_it->second);
       }
     }
   }
@@ -279,8 +310,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   void TheorySetsRels::checkTCGraphForConflict (Node atom, Node tc_rep, Node exp, Node a, Node b,
                                                 std::map< Node, std::hash_set< Node, NodeHashFunction > >& pair_set) {
     TC_PAIR_IT pair_set_it = pair_set.find(a);
+    Trace("rels-debug") << "********** tc reach here 1" << " a = " << a << " b = " << b << std::endl;
     if(pair_set_it != pair_set.end()) {
+      Trace("rels-debug") << "********** tc reach here 2" << std::endl;
       if(pair_set_it->second.find(b) != pair_set_it->second.end()) {
+        Trace("rels-debug") << "********** tc reach here 3" << std::endl;
         Node reason = AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, b)));
         if(atom[1] != tc_rep) {
           reason = AND(exp, explain(EQUAL(atom[1], tc_rep)));
@@ -292,14 +326,20 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
 //                            << AND(reason.negate(), atom) << std::endl;
 //        d_sets_theory.d_out->conflict(AND(reason.negate(), atom));
       } else {
+        Trace("rels-debug") << "********** tc reach here 4" << std::endl;
         std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
         while(set_it != pair_set_it->second.end()) {
-          checkTCGraphForConflict(atom, tc_rep, AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, *set_it))),
-                                  *set_it, b, pair_set);
+          // need to check if *set_it has been looked already
+          if(!areEqual(*set_it, a)) {
+            checkTCGraphForConflict(atom, tc_rep, AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, *set_it))),
+                                    *set_it, b, pair_set);  
+          }         
+          Trace("rels-debug") << "********** looping here 6 *set_it = " << *set_it << std::endl;                                  
           set_it++;
         }
       }
     }
+    Trace("rels-debug") << "********** tc reach here 5" << std::endl;
   }
 
 
@@ -368,14 +408,15 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       reason_2 = Rewriter::rewrite(AND(reason_2, explain(EQUAL(t2, t2_rep))));
     }
     if(polarity) {
-      if(safeAddToMap(d_membership_db, r1_rep, t1_rep)) {
-        addToMap(d_membership_exp_db, r1_rep, reason_1);
-        sendInfer(fact_1, reason_1, "product-split");
-      }
-      if(safeAddToMap(d_membership_db, r2_rep, t2_rep)) {
-        addToMap(d_membership_exp_db, r2_rep, reason_2);
-        sendInfer(fact_2, reason_2, "product-split");
-      }
+      sendInfer(fact_1, reason_1, "product-split");
+      sendInfer(fact_2, reason_2, "product-split");
+      // if(safeAddToMap(d_membership_db, r1_rep, t1_rep)) {
+      //   addToMap(d_membership_exp_db, r1_rep, reason_1);        
+      // }
+      // if(safeAddToMap(d_membership_db, r2_rep, t2_rep)) {
+      //   addToMap(d_membership_exp_db, r2_rep, reason_2);
+        
+      // }
 
     } else {
       sendInfer(fact_1.negate(), reason_1, "product-split");
@@ -457,8 +498,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       fact = MEMBER(t1, r1_rep);
       if(r1_rep != join_term[0]) {
         reasons = Rewriter::rewrite(AND(reason, explain(EQUAL(r1_rep, join_term[0]))));
-      }
-      addToMembershipDB(r1_rep, t1, reasons);
+      }    
       sendInfer(fact, reasons, "join-split");
 
       reasons = reason;
@@ -466,7 +506,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       if(r2_rep != join_term[1]) {
         reasons = Rewriter::rewrite(AND(reason, explain(EQUAL(r2_rep, join_term[1]))));
       }
-      addToMembershipDB(r2_rep, t2, reasons);
       sendInfer(fact, reasons, "join-split");
 
       // Need to make the skolem "shared_x" as shared term
@@ -502,14 +541,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       Trace("rels-debug") << "\n[sets-rels] Apply TRANSPOSE-OCCUR rule on term: " << tp_term
                              << " with explanation: " << exp << std::endl;
       Node fact = polarity ? MEMBER(reversedTuple, tp_term) : MEMBER(reversedTuple, tp_term).negate();
-      if(holds(fact)) {
-        Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds. Skip...." << std::endl;
-      } else {
-        sendInfer(fact, exp, "transpose-occur");
-        if(polarity) {
-          addToMembershipDB(tp_term, reversedTuple, exp);
-        }
-      }
+      sendInfer(fact, exp, "transpose-occur");
       return;
     }
 
@@ -539,14 +571,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       }
       fact = fact.negate();
     }
-    if(holds(fact)) {
-      Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds. Skip...." << std::endl;
-    } else {
-      sendInfer(fact, reason, "transpose-rule");
-      if(polarity) {
-        addToMembershipDB(tp_t0_rep, reversedTuple, reason);
-      }
-    }
+    sendInfer(fact, reason, "transpose-rule");    
   }
 
 
@@ -676,7 +701,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       if(holds(fact)) {
         Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds! Skip..." << std::endl;
       } else {
-        addToMembershipDB(n_rep, rev_tup, Rewriter::rewrite(reason));
         sendInfer(fact, Rewriter::rewrite(reason), "transpose-rule");
       }
     }
@@ -717,8 +741,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       for(unsigned int j = 0; j < r2_elements.size(); j++) {
         std::vector<Node> composed_tuple;
         TypeNode tn = n.getType().getSetElementType();
-        Node r2_lmost = RelsUtils::nthElementOfTuple(r2_elements[j], 0);
         Node r1_rmost = RelsUtils::nthElementOfTuple(r1_elements[i], t1_len-1);
+        Node r2_lmost = RelsUtils::nthElementOfTuple(r2_elements[j], 0);        
         composed_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor()));
 
         if((areEqual(r1_rmost, r2_lmost) && n.getKind() == kind::JOIN) ||
@@ -742,24 +766,35 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
             Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds! Skip..." << std::endl;
           } else {
             std::vector<Node> reasons;
-            reasons.push_back(r1_exps[i]);
-            reasons.push_back(r2_exps[j]);
-            if(!isProduct)
-              reasons.push_back(EQUAL(r1_rmost, r2_lmost));
+            //Todo: need more explanation
+            reasons.push_back(explain(r1_exps[i]));            
+            reasons.push_back(explain(r2_exps[j]));
+            if(r1_exps[i].getKind() == kind::MEMBER && r1_exps[i][0] != r1_elements[i]) {
+              Trace("rels-debug") << "************* $ r1 ele = " << r1_elements[i] << " r1 exp ele = " << r1_exps[i][0] << std::endl;
+              reasons.push_back(explain(EQUAL(r1_elements[i], r1_exps[i][0])));            
+            }
+            if(r2_exps[j].getKind() == kind::MEMBER && r2_exps[j][0] != r2_elements[j]) {
+              Trace("rels-debug") << "************* $ r2 ele = " << r2_elements[j] << " r2 exp ele = " << r2_exps[j][0] << std::endl;
+              reasons.push_back(explain(EQUAL(r2_elements[j], r2_exps[j][0])));            
+            }
+            if(!isProduct) {              
+              if(r1_rmost != r2_lmost) {
+                reasons.push_back(explain(EQUAL(r1_rmost, r2_lmost)));
+              }
+            }
             if(r1 != r1_rep) {
-              reasons.push_back(EQUAL(r1, r1_rep));
+              reasons.push_back(explain(EQUAL(r1, r1_rep)));
             }
             if(r2 != r2_rep) {
-              reasons.push_back(EQUAL(r2, r2_rep));
+              reasons.push_back(explain(EQUAL(r2, r2_rep)));
             }
 
             Node reason = Rewriter::rewrite(nm->mkNode(kind::AND, reasons));
-            addToMembershipDB(new_rel_rep, composed_tuple_rep, reason);
-
-            if(isProduct)
+            if(isProduct) {
               sendInfer( fact, reason, "product-compose" );
-            else
+            } else {
               sendInfer( fact, reason, "join-compose" );
+            }
 
             Trace("rels-debug") << "[sets-rels] Compose tuples: " << r1_elements[i]
                                << " and " << r2_elements[j]
@@ -1120,8 +1155,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
                                  context::UserContext* u,
                                  eq::EqualityEngine* eq,
                                  context::CDO<bool>* conflict,
-                                 TheorySets& d_set):
-    d_c(c),
+                                 TheorySets& d_set):    
     d_sets_theory(d_set),
     d_trueNode(NodeManager::currentNM()->mkConst<bool>(true)),
     d_falseNode(NodeManager::currentNM()->mkConst<bool>(false)),
@@ -1130,6 +1164,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     d_infer_exp(c),
     d_lemma(u),
     d_shared_terms(u),
+    d_tc_saver(u),
     d_eqEngine(eq),
     d_conflict(conflict)
   {
@@ -1227,15 +1262,48 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   }
 
   TheorySetsRels::EqcInfo::EqcInfo( context::Context* c ) :
-  d_mem(c), d_not_mem(c), d_tp(c), d_pt(c) {}
+  d_mem(c), d_not_mem(c), d_in(c), d_out(c), d_tc_mem_exp(c), d_tp(c), d_pt(c), d_join(c), d_tc(c) {}
 
   void TheorySetsRels::eqNotifyNewClass( Node n ) {
     Trace("rels-std") << "[sets-rels] eqNotifyNewClass:" << " t = " << n << std::endl;
-    if(isRel(n) && (n.getKind() == kind::TRANSPOSE || n.getKind() == kind::PRODUCT)) {
+    if(isRel(n) && (n.getKind() == kind::TRANSPOSE || 
+                    n.getKind() == kind::PRODUCT ||
+                    n.getKind() == kind::JOIN ||
+                    n.getKind() == kind::TCLOSURE)) {
       getOrMakeEqcInfo( n, true );
     }
+    Trace("rels-std") << "[sets-rels] eqNotifyNewClass*****:" << " t = " << n << std::endl;
   }
+  void TheorySetsRels::addTCMem(EqcInfo* tc_ei, Node mem) {
+    Node fst = RelsUtils::nthElementOfTuple(mem, 0);
+    Node snd = RelsUtils::nthElementOfTuple(mem, 1);
+
+    NodeList* in_lst;
+    NodeList* out_lst;
+    NodeListMap::iterator tc_in_mem_it = tc_ei->d_in.find(snd);
+    if(tc_in_mem_it == tc_ei->d_in.end()) {
+      in_lst = new(d_sets_theory.getSatContext()->getCMM()) NodeList( true, d_sets_theory.getSatContext(), false,
+                                                                      context::ContextMemoryAllocator<TNode>(d_sets_theory.getSatContext()->getCMM()) );
+      tc_ei->d_in.insertDataFromContextMemory(snd, in_lst);
+      Trace("rels-std") << "Create cache for " << snd << std::endl;
+    } else {
+      in_lst = (*tc_in_mem_it).second;
+    }
+    Trace("rels-std") << "Add in membership arrow for " << snd << " : " << fst << std::endl;
+    in_lst->push_back( fst );
 
+    NodeListMap::iterator tc_out_mem_it = tc_ei->d_out.find(fst);
+    if(tc_out_mem_it == tc_ei->d_out.end()) {
+      out_lst = new(d_sets_theory.getSatContext()->getCMM()) NodeList( true, d_sets_theory.getSatContext(), false,
+                                                                      context::ContextMemoryAllocator<TNode>(d_sets_theory.getSatContext()->getCMM()) );
+      tc_ei->d_out.insertDataFromContextMemory(fst, out_lst);
+      Trace("rels-std") << "Create cache for " << fst << std::endl;
+    } else {
+      out_lst = (*tc_out_mem_it).second;
+    }
+    Trace("rels-std") << "Add out membership arrow for " << fst << " : " << snd << std::endl;
+    out_lst->push_back( snd );
+  }
   void TheorySetsRels::eqNotifyPostMerge( Node t1, Node t2 ) {
     Trace("rels-std") << "[sets-rels] eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl;
 
@@ -1274,18 +1342,170 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
         }
         sendInferProduct(polarity, t2[0], ei->d_pt.get(), exp);
       }
+      if(polarity) {
+        if(!ei->d_tc.get().isNull()) {
+          addTCMem(ei, t2[0]);
+          ei->d_tc_mem_exp.insert(t2[0], t2);
+          sendInferTC(ei, t2[0], t2);
+        } else {
+          std::vector<TypeNode> tup_types = t2[1].getType().getSetElementType().getTupleTypes();
+          if( tup_types.size() == 2 && tup_types[0] == tup_types[1] ) {
+            Node tc_n = NodeManager::currentNM()->mkNode(kind::TCLOSURE, t2[1]);
+            EqcInfo* tc_ei = getOrMakeEqcInfo( tc_n );
+            if(tc_ei != NULL) {
+              addTCMem(tc_ei, t2[0]);
+              Node exp = (tc_n == tc_ei->d_tc.get()) ? t2 : AND(EQUAL(tc_n, tc_ei->d_tc.get()), t2);
+              tc_ei->d_tc_mem_exp.insert(t2[0], exp);
+              sendInferTC(tc_ei, t2[0], exp);
+            }
+          }
+        }
+      }
+
     // Merge two relation eqcs
     } else if(t1.getType().isSet() &&
               t2.getType().isSet() &&
               t1.getType().getSetElementType().isTuple()) {
       mergeTransposeEqcs(t1, t2);
       mergeProductEqcs(t1, t2);
+      mergeTCEqcs(t1, t2);
     }
 
     Trace("rels-std") << "[sets-rels] done with eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl;
   }
 
+  void TheorySetsRels::sendInferTC(EqcInfo* tc_ei, Node mem, Node exp) {
+    Trace("rels-std") << "[sets-rels] sendInferTC member = " << mem << " with explanation = " << exp << std::endl;
+    if(!tc_ei->d_mem.contains(mem)) {
+      Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, exp, MEMBER(mem, tc_ei->d_tc.get()));
+      d_pending_merge.push_back(tc_lemma);
+      d_lemma.insert(tc_lemma);
+      tc_ei->d_mem.insert(mem);
+    }
+    std::hash_set<Node, NodeHashFunction> seen;
+    seen.insert(RelsUtils::nthElementOfTuple(mem, 0));
+    sendInferInTC(tc_ei, RelsUtils::nthElementOfTuple(mem, 0), RelsUtils::nthElementOfTuple(mem, 1), seen, exp);
+    sendInferOutTC(tc_ei, RelsUtils::nthElementOfTuple(mem, 0), RelsUtils::nthElementOfTuple(mem, 1), seen, exp);
+  }
+
+  void TheorySetsRels::sendInferInTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set<Node, NodeHashFunction> seen, Node exp) {
+    for(NodeListMap::iterator nl_it = tc_ei->d_in.begin(); nl_it != tc_ei->d_in.end(); nl_it++) {
+      if(areEqual((*nl_it).first, fst)) {
+        for(NodeList::const_iterator itr = (*nl_it).second->begin(); itr != (*nl_it).second->end(); itr++) {
+          Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, snd);
+          if(!tc_ei->d_mem.contains(pair)) {
+            Node reason = ((*nl_it).first == fst) ?
+                          Rewriter::rewrite(AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst)))):
+                          Rewriter::rewrite(AND(EQUAL((*nl_it).first, fst), AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst)))));
+            Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(pair, tc_ei->d_tc.get()));
+            d_pending_merge.push_back(tc_lemma);
+            d_lemma.insert(tc_lemma);
+            tc_ei->d_mem.insert(pair);
+            tc_ei->d_tc_mem_exp.insert(pair, reason);
+          }
+          if(seen.find(*itr) == seen.end()) {
+            seen.insert(*itr);
+            sendInferInTC(tc_ei, *itr, snd, seen, AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst))));
+          }
+        }
+      }
+    }
+  }
+
+  void TheorySetsRels::sendInferOutTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set<Node, NodeHashFunction> seen, Node exp) {
+    for(NodeListMap::iterator nl_it = tc_ei->d_out.begin(); nl_it != tc_ei->d_out.end(); nl_it++) {
+      if(areEqual((*nl_it).first, snd)) {
+        for(NodeList::const_iterator itr = (*nl_it).second->begin(); itr != (*nl_it).second->end(); itr++) {
+          Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), fst, *itr);
+          if(!tc_ei->d_mem.contains(pair)) {
+            Node reason = ((*nl_it).first == snd) ?
+                          Rewriter::rewrite(AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), snd, *itr)))) :
+                          Rewriter::rewrite(AND(EQUAL((*nl_it).first, snd), AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), snd, *itr)))));
+            Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(pair, tc_ei->d_tc.get()));
+            d_pending_merge.push_back(tc_lemma);
+            d_lemma.insert(tc_lemma);
+            tc_ei->d_mem.insert(pair);
+            tc_ei->d_tc_mem_exp.insert(pair, reason);
+          }
+          if(seen.find(*itr) == seen.end()) {
+            seen.insert(*itr);
+            sendInferOutTC(tc_ei, snd, *itr, seen, AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), snd, *itr))));
+          }
+        }
+      }
+    }
+  }
+
+  Node TheorySetsRels::findTCMemExp(EqcInfo* tc_ei, Node mem) {
+    NodeMap::iterator exp_it = tc_ei->d_tc_mem_exp.find(mem);
+    Assert(exp_it != tc_ei->d_tc_mem_exp.end());
+    return (*exp_it).second;
+  }
+
+  void TheorySetsRels::mergeTCEqcExp(EqcInfo* ei_1, EqcInfo* ei_2) {
+    for(NodeMap::iterator itr = ei_2->d_tc_mem_exp.begin(); itr != ei_2->d_tc_mem_exp.end(); itr++) {
+      NodeMap::iterator exp_it = ei_1->d_tc_mem_exp.find((*itr).first);
+      if(exp_it != ei_1->d_tc_mem_exp.end()) {
+        ei_1->d_tc_mem_exp.insert((*itr).first, OR((*itr).second, (*exp_it).second));
+      } else {
+        ei_1->d_tc_mem_exp.insert((*itr).first, (*itr).second);
+      }
+    }
+  }
+
+  void TheorySetsRels::buildTCAndExp(Node n, EqcInfo* ei) {
+    for(NodeSet::key_iterator mem_it = ei->d_mem.key_begin(); mem_it != ei->d_mem.key_end(); mem_it++) {
+      addTCMem(ei, *mem_it);
+      Node exp = (!ei->d_tc.get().isNull() && n == ei->d_tc.get()) ?
+                 AND(MEMBER(*mem_it, n), explain(EQUAL(n, ei->d_tc.get()))) :
+                 MEMBER(*mem_it, n);
+      ei->d_tc_mem_exp.insert(*mem_it, exp);
+    }
+  }
+
+  void TheorySetsRels::mergeTCEqcs(Node t1, Node t2) {
+    Trace("rels-std") << "[sets-rels] Merge TC eqcs t1 = " << t1 << " and t2 = " << t2 << std::endl;
+    EqcInfo* t1_ei = getOrMakeEqcInfo(t1);
+    EqcInfo* t2_ei = getOrMakeEqcInfo(t2);
+    if(t1_ei != NULL && t2_ei != NULL) {
+      // Apply TC rule on members of t2 and t1->tc
+      if(!t1_ei->d_tc.get().isNull()) {
+        mergeTCEqcExp(t1_ei, t2_ei);
+        for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) {
+          sendInferTC(t1_ei, *itr, findTCMemExp(t1_ei, *itr));
+          if(!t1_ei->d_mem.contains(*itr)) {
+
+          }
+        }
+      } else if(!t2_ei->d_tc.get().isNull()) {
+        t1_ei->d_tc.set(t2_ei->d_tc);
+        buildTCAndExp(t1, t1_ei);
+        mergeTCEqcExp(t1_ei, t2_ei);
+        for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) {
+          sendInferTC(t1_ei, *itr, findTCMemExp(t1_ei, *itr));
+          if(!t1_ei->d_mem.contains(*itr) && !t2_ei->d_mem.contains(*itr)) {
+
+          }
+        }
+      }
+    // t1 was created already and t2 was not
+    } else if(t1_ei != NULL) {
+      if(t1_ei->d_tc.get().isNull() && t2.getKind() == kind::TCLOSURE) {
+        t1_ei->d_tc.set( t2 );
+      }
+    } else if(t2_ei != NULL){
+      t1_ei = getOrMakeEqcInfo(t1, true);
+      for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) {
+        t1_ei->d_mem.insert(*itr);
+      }
+      if(t1_ei->d_tc.get().isNull() && !t2_ei->d_tc.get().isNull()) {
+        t1_ei->d_tc.set(t2_ei->d_tc);
+      }
+    }
+  }
+
   void TheorySetsRels::mergeProductEqcs(Node t1, Node t2) {
+    Trace("rels-std") << "[sets-rels] Merge PRODUCT eqcs t1 = " << t1 << " and t2 = " << t2 << std::endl;
     EqcInfo* t1_ei = getOrMakeEqcInfo(t1);
     EqcInfo* t2_ei = getOrMakeEqcInfo(t2);
     if(t1_ei != NULL && t2_ei != NULL) {
@@ -1305,7 +1525,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
             sendInferProduct( false, *itr, t1_ei->d_pt.get(), AND(explain(EQUAL(t1_ei->d_pt.get(), t2)), explain(MEMBER(*itr, t2).negate())) );
           }
         }
-        // Apply transpose rule on (non)members of t1 and t2->tp
       } else if(!t2_ei->d_pt.get().isNull()) {
         t1_ei->d_pt.set(t2_ei->d_pt);
         for(NodeSet::key_iterator itr = t1_ei->d_mem.key_begin(); itr != t1_ei->d_mem.key_end(); itr++) {
@@ -1326,7 +1545,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       }
     } else if(t2_ei != NULL){
       t1_ei = getOrMakeEqcInfo(t1, true);
-      for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) {
+      for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) {       
         t1_ei->d_mem.insert(*itr);
       }
       for(NodeSet::key_iterator itr = t2_ei->d_not_mem.key_begin(); itr != t2_ei->d_not_mem.key_end(); itr++) {
@@ -1339,6 +1558,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   }
 
   void TheorySetsRels::mergeTransposeEqcs(Node t1, Node t2) {
+    Trace("rels-std") << "[sets-rels] Merge TRANSPOSE eqcs t1 = " << t1 << " and t2 = " << t2 << std::endl;
     EqcInfo* t1_ei = getOrMakeEqcInfo(t1);
     EqcInfo* t2_ei = getOrMakeEqcInfo(t2);
     if(t1_ei != NULL && t2_ei != NULL) {
@@ -1442,9 +1662,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   void TheorySetsRels::sendInferProduct( bool polarity, Node t1, Node t2, Node exp ) {
     Assert(t2.getKind() == kind::PRODUCT);
     if(polarity && isRel(t1) && isRel(t2)) {
+      //PRODUCT(x) = PRODUCT(y) => x = y;
       Assert(t1.getKind() == kind::PRODUCT);
       Node n = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, EQUAL(t1[0], t2[0]) );
-      Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying transpose rule: "
+      Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying product rule: "
                         << n << std::endl;
       d_pending_merge.push_back(n);
       d_lemma.insert(n);
@@ -1483,14 +1704,14 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       n1 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(tuple_1, r1).negate() );
       n2 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(tuple_2, r2).negate() );
     }
-    Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying product rule: "
-                      << n2 << std::endl;
-    d_pending_merge.push_back(n2);
-    d_lemma.insert(n2);
-    Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying product rule: "
+    Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying product-split rule: "
                       << n1 << std::endl;
     d_pending_merge.push_back(n1);
     d_lemma.insert(n1);
+    Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying product-split rule: "
+                      << n2 << std::endl;
+    d_pending_merge.push_back(n2);
+    d_lemma.insert(n2);
 
   }
 
@@ -1509,6 +1730,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
           ei->d_tp = n;
         } else if(n.getKind() == kind::PRODUCT) {
           ei->d_pt = n;
+        } else if(n.getKind() == kind::TCLOSURE) {
+          ei->d_tc = n;
+        } else if(n.getKind() == kind::JOIN) {
+          ei->d_join = n;
         }
         return ei;
       }else{
index ff62b67abbfefcc22d8ad6af2073ec2feb82286c..0d24c65b385f83122dc3f325579f05b86b2c86f5 100644 (file)
@@ -47,6 +47,9 @@ class TheorySetsRels {
   typedef context::CDChunkList<Node> NodeList;
   typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
   typedef context::CDHashMap<Node, bool, NodeHashFunction> NodeBoolMap;
+  typedef context::CDHashMap<Node, NodeList*, NodeHashFunction> NodeListMap;
+  typedef context::CDHashMap<Node, NodeSet*, NodeHashFunction> NodeSetMap;
+  typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeMap;
 
 public:
   TheorySetsRels(context::Context* c,
@@ -58,13 +61,15 @@ public:
   ~TheorySetsRels();
   void check(Theory::Effort);
   void doPendingLemmas();
-  context::Context * d_c;
 
 private:
   /** equivalence class info
    * d_mem tuples that are members of this equivalence class
    * d_not_mem tuples that are not members of this equivalence class
    * d_tp is a node of kind TRANSPOSE (if any) in this equivalence class,
+   * d_pt is a node of kind PRODUCT (if any) in this equivalence class,
+   * d_join is a node of kind JOIN (if any) in this equivalence class,
+   * d_tc is a node of kind TCLOSURE (if any) in this equivalence class,
    */
   class EqcInfo
   {
@@ -73,8 +78,13 @@ private:
     ~EqcInfo(){}
     NodeSet d_mem;
     NodeSet d_not_mem;
+    NodeListMap d_in;
+    NodeListMap d_out;
+    NodeMap d_tc_mem_exp;
     context::CDO< Node > d_tp;
     context::CDO< Node > d_pt;
+    context::CDO< Node > d_join;
+    context::CDO< Node > d_tc;
   };
 
   /** has eqc info */
@@ -101,6 +111,9 @@ private:
   NodeList d_infer_exp;
   NodeSet d_lemma;
   NodeSet d_shared_terms;
+  
+  // tc terms that have been decomposed
+  NodeSet d_tc_saver;
 
   std::hash_set< Node, NodeHashFunction > d_rel_nodes;
   std::map< Node, std::vector<Node> > d_tuple_reps;
@@ -123,13 +136,22 @@ public:
   void eqNotifyPostMerge(Node t1, Node t2);
 
 private:
-  void mergeTransposeEqcs(Node t1, Node t2);
-  void mergeProductEqcs(Node t1, Node t2);
-  std::map< Node, EqcInfo* > d_eqc_info;
+
   void doPendingMerge();
+  std::map< Node, EqcInfo* > d_eqc_info;
   EqcInfo* getOrMakeEqcInfo( Node n, bool doMake = false );
+  void mergeTransposeEqcs(Node t1, Node t2);
+  void mergeProductEqcs(Node t1, Node t2);
+  void mergeTCEqcs(Node t1, Node t2);
   void sendInferTranspose(bool, Node, Node, Node, bool reverseOnly = false);
   void sendInferProduct(bool, Node, Node, Node);
+  void sendInferTC(EqcInfo* tc_ei, Node mem, Node exp);
+  void sendInferInTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set<Node, NodeHashFunction> seen, Node exp);
+  void sendInferOutTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set<Node, NodeHashFunction> seen, Node exp);
+  void addTCMem(EqcInfo* tc_ei, Node mem);
+  Node findTCMemExp(EqcInfo*, Node);
+  void mergeTCEqcExp(EqcInfo*, EqcInfo*);
+  void buildTCAndExp(Node, EqcInfo*);
 
 
   void check();