added support for expansion of transitive closure
authorPaul Meng <baolmeng@gmail.com>
Tue, 12 Jul 2016 01:40:25 +0000 (21:40 -0400)
committerPaul Meng <baolmeng@gmail.com>
Tue, 12 Jul 2016 01:40:25 +0000 (21:40 -0400)
src/theory/sets/theory_sets_rels.cpp
src/theory/sets/theory_sets_rels.h

index 3f7d079bd766cea1c26af034285c40ec525ee53e..24aa44f3b388787f98378ffccbed9136eb5b2624 100644 (file)
@@ -234,12 +234,9 @@ int TheorySetsRels::EqcInfo::counter        = 0;
   void TheorySetsRels::applyTCRule(Node exp, Node tc_term) {
     Trace("rels-debug") << "\n[sets-rels] *********** Applying TRANSITIVE CLOSURE rule on  "
                         << tc_term << " with explanation " << exp << std::endl;
-    bool polarity       = exp.getKind() != kind::NOT;
-    Node atom           = polarity ? exp : exp[0];
-    Node tup_rep        = getRepresentative(atom[0]);
+
     Node tc_rep         = getRepresentative(tc_term);
     Node tc_r_rep       = getRepresentative(tc_term[0]);
-
     // build the TC graph for tc_rep if it was not created before
     if( d_rel_nodes.find(tc_rep) == d_rel_nodes.end() ) {
       Trace("rels-debug") << "[sets-rels]  Start building the TC graph!" << std::endl;
@@ -247,50 +244,77 @@ int TheorySetsRels::EqcInfo::counter        = 0;
       d_rel_nodes.insert(tc_rep);
     }
 
-    // insert tup_rep in the tc_graph if it is not in the graph already
-    TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep);
+    bool polarity = exp.getKind() != kind::NOT;
 
     if(polarity) {
-      Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 0));
-      Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 1));
-
-      if(tc_graph_it != d_membership_tc_cache.end()) {
-        TC_PAIR_IT pair_set_it = tc_graph_it->second.find(fst_rep);
+      std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator mem_it  = d_tc_membership_db.find(tc_term);
 
-        if(pair_set_it != tc_graph_it->second.end()) {
-          pair_set_it->second.insert(snd_rep);
-        } else {
-          std::hash_set< Node, NodeHashFunction > pair_set;
-          pair_set.insert(snd_rep);
-          tc_graph_it->second[fst_rep] = pair_set;
-        }
-
-        Node                                    reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
-        std::map< Node, Node >::iterator        exp_it = d_membership_tc_exp_cache.find(tc_rep);
-
-        if(!reason.isNull() && exp_it->second != reason) {
-          d_membership_tc_exp_cache[tc_rep] = Rewriter::rewrite(AND(exp_it->second, reason));
-        }
+      if( mem_it == d_tc_membership_db.end() ) {
+        std::hash_set<Node, NodeHashFunction> members;
+        members.insert(exp[0]);
+        d_tc_membership_db[tc_term] = members;
       } else {
-        std::map< Node, std::hash_set< Node, NodeHashFunction > >       pair_set;
-        std::hash_set< Node, NodeHashFunction >                         snd_set;
-
-        snd_set.insert(snd_rep);
-        pair_set[fst_rep]               = snd_set;
-        d_membership_tc_cache[tc_rep]   = pair_set;
-        Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
-
-        if(!reason.isNull()) {
-          d_membership_tc_exp_cache[tc_rep] = reason;
-        }
-      }
-    // check if tup_rep already exists in TC graph for conflict
-    } else {
-      if(tc_graph_it != d_membership_tc_cache.end()) {
-        checkTCGraphForConflict(atom, tc_rep, d_trueNode, RelsUtils::nthElementOfTuple(tup_rep, 0),
-                                RelsUtils::nthElementOfTuple(tup_rep, 1), tc_graph_it->second);
+        mem_it->second.insert(exp[0]);
       }
     }
+    //todo: need to construct a tc_graph if transitive closure is used in the context
+
+//    Node atom           = polarity ? exp : exp[0];
+//    Node tup_rep        = getRepresentative(atom[0]);
+//    Node tc_rep         = getRepresentative(tc_term);
+//    Node tc_r_rep       = getRepresentative(tc_term[0]);
+//
+//    // build the TC graph for tc_rep if it was not created before
+//    if( d_rel_nodes.find(tc_rep) == d_rel_nodes.end() ) {
+//      Trace("rels-debug") << "[sets-rels]  Start building the TC graph!" << std::endl;
+//      buildTCGraph(tc_r_rep, tc_rep, tc_term);
+//      d_rel_nodes.insert(tc_rep);
+//    }
+
+    // insert tup_rep in the tc_graph if it is not in the graph already
+//    TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep);
+//
+//    if( polarity ) {
+//      Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 0));
+//      Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 1));
+//
+//      if( tc_graph_it != d_membership_tc_cache.end() ) {
+//        TC_PAIR_IT pair_set_it = tc_graph_it->second.find(fst_rep);
+//
+//        if( pair_set_it != tc_graph_it->second.end() ) {
+//          pair_set_it->second.insert(snd_rep);
+//        } else {
+//          std::hash_set< Node, NodeHashFunction > pair_set;
+//          pair_set.insert(snd_rep);
+//          tc_graph_it->second[fst_rep] = pair_set;
+//        }
+//
+//        Node                                    reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
+//        std::map< Node, Node >::iterator        exp_it = d_membership_tc_exp_cache.find(tc_rep);
+//
+//        if(!reason.isNull() && exp_it->second != reason) {
+//          d_membership_tc_exp_cache[tc_rep] = Rewriter::rewrite(AND(exp_it->second, reason));
+//        }
+//      } else {
+//        std::map< Node, std::hash_set< Node, NodeHashFunction > >       pair_set;
+//        std::hash_set< Node, NodeHashFunction >                         snd_set;
+//
+//        snd_set.insert(snd_rep);
+//        pair_set[fst_rep]               = snd_set;
+//        d_membership_tc_cache[tc_rep]   = pair_set;
+//        Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]);
+//
+//        if(!reason.isNull()) {
+//          d_membership_tc_exp_cache[tc_rep] = reason;
+//        }
+//      }
+//    // check if tup_rep already exists in TC graph for conflict
+//    } else {
+//      if( tc_graph_it != d_membership_tc_cache.end() ) {
+//        checkTCGraphForConflict(atom, tc_rep, d_trueNode, RelsUtils::nthElementOfTuple(tup_rep, 0),
+//                                RelsUtils::nthElementOfTuple(tup_rep, 1), tc_graph_it->second);
+//      }
+//    }
   }
 
   void TheorySetsRels::checkTCGraphForConflict (Node atom, Node tc_rep, Node exp, Node a, Node b,
@@ -555,6 +579,33 @@ int TheorySetsRels::EqcInfo::counter        = 0;
   }
 
 
+  void TheorySetsRels::finalizeTCInfer() {
+    Trace("rels-debug") << "[sets-rels] Finalizing transitive closure inferences!" << std::endl;
+    for(TC_IT tc_it = d_membership_tc_cache.begin(); tc_it != d_membership_tc_cache.end(); tc_it++) {
+      inferTC(tc_it->first, tc_it->second);
+    }
+  }
+
+  void TheorySetsRels::inferTC(Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph) {
+    Trace("rels-debug") << "[sets-rels] Build TC graph for tc_rep = " << tc_rep << std::endl;
+    for(std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin();
+        pair_set_it != tc_graph.end(); pair_set_it++) {
+      for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
+          set_it != pair_set_it->second.end(); set_it++) {
+        std::hash_set<Node, NodeHashFunction>   elements;
+        Node    pair    = constructPair(tc_rep, pair_set_it->first, *set_it);
+        Node    exp     = findMemExp(tc_rep, pair);
+
+        if(d_membership_tc_exp_cache.find(tc_rep) != d_membership_tc_exp_cache.end()) {
+          exp = AND(d_membership_tc_exp_cache[tc_rep], exp);
+        }
+        Assert(!exp.isNull());
+        elements.insert(pair_set_it->first);
+        inferTC( exp, tc_rep, tc_graph, pair_set_it->first, *set_it, elements );
+      }
+    }
+  }
+
   void TheorySetsRels::inferTC( Node exp, Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph,
                                 Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& traversed ) {
     Node                                                pair    = constructPair(tc_rep, start_node, cur_node);
@@ -584,33 +635,6 @@ int TheorySetsRels::EqcInfo::counter        = 0;
     }
   }
 
-  void TheorySetsRels::finalizeTCInfer() {
-    Trace("rels-debug") << "[sets-rels] Finalizing transitive closure inferences!" << std::endl;
-    for(TC_IT tc_it = d_membership_tc_cache.begin(); tc_it != d_membership_tc_cache.end(); tc_it++) {
-      inferTC(tc_it->first, tc_it->second);
-    }
-  }
-
-  void TheorySetsRels::inferTC(Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph) {
-    Trace("rels-debug") << "[sets-rels] Build TC graph for tc_rep = " << tc_rep << std::endl;
-    for(std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin();
-        pair_set_it != tc_graph.end(); pair_set_it++) {
-      for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
-          set_it != pair_set_it->second.end(); set_it++) {
-        std::hash_set<Node, NodeHashFunction>   elements;
-        Node                                    pair    = constructPair(tc_rep, pair_set_it->first, *set_it);
-        Node                                    exp     = findMemExp(tc_rep, pair);
-
-        if(d_membership_tc_exp_cache.find(tc_rep) != d_membership_tc_exp_cache.end()) {
-          exp = AND(d_membership_tc_exp_cache[tc_rep], exp);
-        }
-        Assert(!exp.isNull());
-        elements.insert(pair_set_it->first);
-        inferTC( exp, tc_rep, tc_graph, pair_set_it->first, *set_it, elements );
-      }
-    }
-  }
-
   // Bottom-up fashion to compute relations
   void TheorySetsRels::computeRels(Node n) {
     Trace("rels-debug") << "\n[sets-rels] computeJoinOrProductRelations for relation  " << n << std::endl;
@@ -793,7 +817,6 @@ int TheorySetsRels::EqcInfo::counter        = 0;
         Trace("rels-lemma") << "[sets-rels-lemma] Process pending lemma : "
                             << d_lemma_cache[i] << std::endl;
         d_sets_theory.d_out->lemma( d_lemma_cache[i] );
-//        d_sets_theory.d_out->conflict()
       }
       for( std::map<Node, Node>::iterator child_it = d_pending_facts.begin();
            child_it != d_pending_facts.end(); child_it++ ) {
@@ -806,7 +829,11 @@ int TheorySetsRels::EqcInfo::counter        = 0;
                             << child_it->first << " with reason " << child_it->second << std::endl;
         d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, child_it->second, child_it->first));
       }
+      doTCLemmas();
     }
+
+
+    d_tc_membership_db.clear();
     d_rel_nodes.clear();
     d_pending_facts.clear();
     d_membership_constraints_cache.clear();
@@ -823,6 +850,94 @@ int TheorySetsRels::EqcInfo::counter        = 0;
     d_node_id.clear();
   }
 
+  void TheorySetsRels::doTCLemmas() {
+    std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator mem_it = d_tc_membership_db.begin();
+
+    while(mem_it != d_tc_membership_db.end()) {
+      Node tc_rep       = getRepresentative(mem_it->first);
+      Node tc_r_rep     = getRepresentative(mem_it->first[0]);
+
+      // build the TC graph for tc_rep if it was not created before
+      if( d_rel_nodes.find(tc_rep) == d_rel_nodes.end() ) {
+        Trace("rels-debug") << "[sets-rels]  Start building the TC graph for relation " << mem_it->first << std::endl;
+        buildTCGraph(tc_r_rep, tc_rep, mem_it->first);
+        d_rel_nodes.insert(tc_rep);
+      }
+
+      std::hash_set< Node, NodeHashFunction >::iterator set_it = mem_it->second.begin();
+
+      while(set_it != mem_it->second.end()) {
+        std::hash_set<Node, NodeHashFunction> hasSeen;
+        Node    fst             = RelsUtils::nthElementOfTuple(*set_it, 0);
+        Node    snd             = RelsUtils::nthElementOfTuple(*set_it, 1);
+        Node    fst_rep         = getRepresentative(fst);
+        Node    snd_rep         = getRepresentative(RelsUtils::nthElementOfTuple(*set_it, 1));
+        TC_IT   tc_graph_it     = d_membership_tc_cache.find(tc_rep);
+
+        if((tc_graph_it != d_membership_tc_cache.end() && !isTCReachable(fst_rep, snd_rep, hasSeen, tc_graph_it->second)) ||
+           (tc_graph_it == d_membership_tc_cache.end())) {
+          Node reason   = explain(MEMBER(*set_it, mem_it->first));
+          Node sk_1     = NodeManager::currentNM()->mkSkolem("sde", fst_rep.getType());
+          Node sk_2     = NodeManager::currentNM()->mkSkolem("sde", snd_rep.getType());
+          Node mem_of_r = MEMBER(RelsUtils::constructPair(tc_r_rep, fst_rep, snd_rep), tc_r_rep);
+          Node sk_eq    = EQUAL(sk_1, sk_2);
+
+          if(fst_rep != fst) {
+            reason = AND(reason, explain(EQUAL(fst_rep, fst)));
+          }
+          if(snd_rep != snd) {
+            reason = AND(reason, explain(EQUAL(snd_rep, snd)));
+          }
+          if(tc_r_rep != mem_it->first[0]) {
+            reason = AND(reason, explain(EQUAL(tc_r_rep, mem_it->first[0])));
+          }
+          if(tc_rep != mem_it->first) {
+            reason = AND(reason, explain(EQUAL(tc_r_rep, mem_it->first)));
+          }
+
+          Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason,
+                                                           OR(mem_of_r,
+                                                              (AND(MEMBER(RelsUtils::constructPair(tc_r_rep, fst_rep, sk_1), tc_r_rep),
+                                                                   (AND(MEMBER(RelsUtils::constructPair(tc_r_rep, sk_2, snd_rep), tc_r_rep),
+                                                                        (OR(sk_eq, MEMBER(RelsUtils::constructPair(tc_rep, sk_1, sk_2), tc_rep)))))))));
+          Trace("rels-lemma") << "[sets-rels-lemma] Process a TC lemma : "
+                              << tc_lemma << std::endl;
+          d_sets_theory.d_out->lemma(tc_lemma);
+          d_sets_theory.d_out->requirePhase(Rewriter::rewrite(mem_of_r), true);
+          d_sets_theory.d_out->requirePhase(Rewriter::rewrite(sk_eq), true);
+        }
+        set_it++;
+      }
+      mem_it++;
+    }
+  }
+
+  bool TheorySetsRels::isTCReachable(Node start, Node dest, std::hash_set<Node, NodeHashFunction>& hasSeen,
+                                      std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph) {
+    if(hasSeen.find(start) == hasSeen.end()) {
+      hasSeen.insert(start);
+    }
+
+    TC_PAIR_IT pair_set_it = tc_graph.find(start);
+
+    if(pair_set_it != tc_graph.end()) {
+      if(pair_set_it->second.find(dest) != pair_set_it->second.end()) {
+        return true;
+      } else {
+        std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
+
+        while(set_it != pair_set_it->second.end()) {
+          // need to check if *set_it has been looked already
+          if(hasSeen.find(*set_it) == hasSeen.end()) {
+            isTCReachable(*set_it, dest, hasSeen, tc_graph);
+          }
+          set_it++;
+        }
+      }
+    }
+    return false;
+  }
+
   void TheorySetsRels::sendSplit(Node a, Node b, const char * c) {
     Node eq             = a.eqNode( b );
     Node neq            = NOT( eq );
@@ -931,10 +1046,6 @@ int TheorySetsRels::EqcInfo::counter        = 0;
     return false;
   }
 
-  bool TheorySetsRels::checkCycles(Node join_term) {
-    return false;
-  }
-
   bool TheorySetsRels::safeAddToMap(std::map< Node, std::vector<Node> >& map, Node rel_rep, Node member) {
     std::map< Node, std::vector< Node > >::iterator mem_it = map.find(rel_rep);
     if(mem_it == map.end()) {
@@ -1394,30 +1505,19 @@ int TheorySetsRels::EqcInfo::counter        = 0;
     Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() << std::endl;
 
     NodeMap::iterator map_it    = tc_ei->d_mem_exp.begin();
-    while(map_it != tc_ei->d_mem_exp.end()) {
-      Trace("rels-debug") << " mem =  "<< (*map_it).first << " exp = " << (*map_it).second<< std::endl;
-      map_it++;
-    }
     Node exp            = explainTCMem(tc_ei, mem_rep, fst_rep, snd_rep);
     Assert(!exp.isNull());
     Node tc_lemma       = NodeManager::currentNM()->mkNode(kind::IMPLIES, exp, MEMBER(mem_rep, tc_ei->d_tc.get()));
     d_pending_merge.push_back(tc_lemma);
     d_lemma.insert(tc_lemma);
 
-    Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get()
-                      << " in_reachable size = " << in_reachable.size()
-                      << " out_reachable size = " << out_reachable.size()
-                      <<  " ***** 2" << std::endl;
-
     std::hash_set<int>::iterator        in_reachable_it = in_reachable.begin();
     while(in_reachable_it != in_reachable.end()) {
       Node    in_node         = d_id_node[*in_reachable_it];
       Node    in_pair         = RelsUtils::constructPair(tc_ei->d_tc.get(), in_node, fst_rep);
       Node    new_pair        = RelsUtils::constructPair(tc_ei->d_tc.get(), in_node, snd_rep);
-
-      Trace("rels-std") << "Reason for " << in_pair << "   " << explainTCMem(tc_ei, in_pair, in_node, fst_rep) << std::endl;
-
       Node    reason          = AND(explainTCMem(tc_ei, in_pair, in_node, fst_rep), exp);
+
       tc_ei->d_mem_exp[new_pair] = reason;
       tc_ei->d_mem.insert(new_pair);
       Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(new_pair, tc_ei->d_tc.get()));
@@ -1425,7 +1525,7 @@ int TheorySetsRels::EqcInfo::counter        = 0;
       d_lemma.insert(tc_lemma);
       in_reachable_it++;
     }
-    Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() <<  " ***** 3" << std::endl;
+
     std::hash_set<int>::iterator        out_reachable_it = out_reachable.begin();
     while(out_reachable_it != out_reachable.end()) {
       Node    out_node        = d_id_node[*out_reachable_it];
@@ -1441,9 +1541,7 @@ int TheorySetsRels::EqcInfo::counter        = 0;
         Node    in_pair_exp     = explainTCMem(tc_ei, in_pair, in_node, snd_rep);
 
         Assert(in_pair_exp != Node::null());
-        reason                = AND(reason, in_pair_exp);
-
-        Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() <<  " ***** 3 9" << std::endl;
+        reason  = AND(reason, in_pair_exp);
         tc_ei->d_mem_exp[new_pair] = reason;
         tc_ei->d_mem.insert(new_pair);
         Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(new_pair, tc_ei->d_tc.get()));
@@ -1453,7 +1551,6 @@ int TheorySetsRels::EqcInfo::counter        = 0;
       }
       out_reachable_it++;
     }
-    Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() <<  " ***** 4" << std::endl;
   }
 
   void TheorySetsRels::collectInReachableNodes(EqcInfo* tc_ei, int start_id, std::hash_set<int>& in_reachable, bool firstRound) {
index 381ccddd9a874a0c7f0d8d137bbe9d9d2c8eb9da..5a1985d4e84dbd6701a5214a438ce88f8b463a2d 100644 (file)
@@ -129,6 +129,8 @@ private:
   std::map< Node, std::vector<Node> >           d_membership_db;
   std::map< Node, std::vector<Node> >           d_membership_exp_db;
   std::map< Node, Node >                        d_membership_tc_exp_cache;
+
+  std::map< Node, std::hash_set<Node, NodeHashFunction> >                       d_tc_membership_db;
   std::map< Node, std::map<kind::Kind_t, std::vector<Node> > >                  d_terms_cache;
   std::map< Node, std::map< Node, std::hash_set<Node, NodeHashFunction> > >     d_membership_tc_cache;
 
@@ -173,9 +175,12 @@ private:
   void inferTC( Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >& );
   void inferTC( Node, Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >&,
                 Node, Node, std::hash_set< Node, NodeHashFunction >&);
+  bool isTCReachable(Node fst, Node snd, std::hash_set<Node, NodeHashFunction>& hasSeen,
+                      std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph);
 
   Node explain(Node);
 
+  void doTCLemmas();
   void sendInfer( Node fact, Node exp, const char * c );
   void sendLemma( Node fact, Node reason, const char * c );
   void sendSplit( Node a, Node b, const char * c );
@@ -183,7 +188,6 @@ private:
   void doPendingSplitFacts();
   void addSharedTerm( TNode n );
   void checkTCGraphForConflict( Node, Node, Node, Node, Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >& );
-  bool checkCycles( Node );
 
   // Helper functions
   bool insertIntoIdList(IdList&, int);