fixed explanation for transitive closure inferences
authorPaulMeng <baolmeng@gmail.com>
Tue, 12 Apr 2016 15:02:26 +0000 (10:02 -0500)
committerPaulMeng <baolmeng@gmail.com>
Tue, 12 Apr 2016 15:02:26 +0000 (10:02 -0500)
src/theory/sets/theory_sets_rels.cpp
src/theory/sets/theory_sets_rels.h

index 5df44d9f8271549bea3cb42e798308d6738e302f..0e20b9bfa7165fa98928083b6edb03196fce9619 100644 (file)
@@ -59,7 +59,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     MEM_IT m_it = d_membership_constraints_cache.begin();
     while(m_it != d_membership_constraints_cache.end()) {
       Node rel_rep = m_it->first;
-      Trace("rels-debug") << "[sets-rels] Processing rel_rep = " << rel_rep << std::endl;
 
       // No relational terms found with rel_rep as its representative
       // But TRANSPOSE(rel_rep) may occur in the context
@@ -201,7 +200,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
    *                            -----------------------------------------------------------
    *                              x <= TRANSCLOSURE(x) && (x JOIN x) <= TRANSCLOSURE(x) ....
    *
-   *                              TC(x) = TC(y) => x = y
+   *                              TC(x) = TC(y) => x = y ?
    *
    */
 
@@ -237,10 +236,12 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     Node tc_r_rep = getRepresentative(tc_term[0]);
 
     // build the TC graph for tc_rep if it was not created before
-    if( d_membership_tc_cache.find(tc_rep) == d_membership_tc_cache.end() ) {
+    if( d_tc_nodes.find(tc_rep) == d_tc_nodes.end() ) {
+      Trace("rels-debug") << "[sets-rels]  Start building the TC graph!" << std::endl;
       buildTCGraph(tc_r_rep, tc_rep, tc_term);
+      d_tc_nodes.insert(tc_rep);
     }
-    // insert atom[0] in the tc_graph
+    // insert atom[0] in the tc_graph if it is not in the graph already
     TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep);
     if(polarity) {
       if(tc_graph_it != d_membership_tc_cache.end()) {
@@ -268,7 +269,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
           d_membership_tc_exp_cache[tc_rep] = reason;
         }
       }
-    // check if atom[0] exists in TC graph for conflict
+    // check if atom[0] already exists in TC graph for conflict
     } else {
       if(tc_graph_it != d_membership_tc_cache.end()) {
         checkTCGraphForConflict(atom, tc_rep, d_trueNode, nthElementOfTuple(atom[0], 0),
@@ -284,11 +285,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       if(pair_set_it->second.find(b) != pair_set_it->second.end()) {
         Node reason = AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, b)));
         if(atom[1] != tc_rep) {
-          reason = AND(exp, EQUAL(atom[1], tc_rep));
+          reason = AND(exp, explain(EQUAL(atom[1], tc_rep)));
         }
         Trace("rels-debug") << "[sets-rels] found a conflict and send out lemma : "
-                            <<  NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, atom) << std::endl;
-        d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, atom));
+                            <<  NodeManager::currentNM()->mkNode(kind::IMPLIES, Rewriter::rewrite(reason), atom) << std::endl;
+        d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, Rewriter::rewrite(reason), atom));
 //        Trace("rels-debug") << "[sets-rels] found a conflict and send out lemma : "
 //                            << AND(reason.negate(), atom) << std::endl;
 //        d_sets_theory.d_out->conflict(AND(reason.negate(), atom));
@@ -319,53 +320,67 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     Node atom = polarity ? exp : exp[0];
     Node r1_rep = getRepresentative(product_term[0]);
     Node r2_rep = getRepresentative(product_term[1]);
+    Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-SPLIT rule on term: " << product_term
+                        << " with explanation: " << exp << std::endl;
+    std::vector<Node> r1_element;
+    std::vector<Node> r2_element;
 
-    if(polarity) {
-      Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-SPLIT rule on term: " << product_term
-                          << " with explanation: " << exp << std::endl;
-      std::vector<Node> r1_element;
-      std::vector<Node> r2_element;
-
-      NodeManager *nm = NodeManager::currentNM();
-      Datatype dt = r1_rep.getType().getSetElementType().getDatatype();
-      unsigned int i = 0;
-      unsigned int s1_len = r1_rep.getType().getSetElementType().getTupleLength();
-      unsigned int tup_len = product_term.getType().getSetElementType().getTupleLength();
+    NodeManager *nm = NodeManager::currentNM();
+    Datatype dt = r1_rep.getType().getSetElementType().getDatatype();
+    unsigned int i = 0;
+    unsigned int s1_len = r1_rep.getType().getSetElementType().getTupleLength();
+    unsigned int tup_len = product_term.getType().getSetElementType().getTupleLength();
 
-      r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
-      for(; i < s1_len; ++i) {
-        r1_element.push_back(nthElementOfTuple(atom[0], i));
-      }
+    r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
+    for(; i < s1_len; ++i) {
+      r1_element.push_back(nthElementOfTuple(atom[0], i));
+    }
 
-      dt = r2_rep.getType().getSetElementType().getDatatype();
-      r2_element.push_back(Node::fromExpr(dt[0].getConstructor()));
-      for(; i < tup_len; ++i) {
-        r2_element.push_back(nthElementOfTuple(atom[0], i));
-      }
+    dt = r2_rep.getType().getSetElementType().getDatatype();
+    r2_element.push_back(Node::fromExpr(dt[0].getConstructor()));
+    for(; i < tup_len; ++i) {
+      r2_element.push_back(nthElementOfTuple(atom[0], i));
+    }
 
-      Node fact;
-      Node reason = exp;
-      Node t1 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element));
-      Node t2 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element));
-
-      if(!hasMember(r1_rep, t1)) {
-        fact = MEMBER( t1, r1_rep );
-        if(r1_rep != product_term[0])
-          reason = Rewriter::rewrite(AND(reason, EQUAL(r1_rep, product_term[0])));
-        addToMap(d_membership_db, r1_rep, t1);
-        addToMap(d_membership_exp_db, r1_rep, reason);
-        sendInfer(fact, reason, "product-split");
+    Node fact_1;
+    Node fact_2;
+    Node reason_1 = exp;
+    Node reason_2 = exp;
+    Node t1 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element);
+    Node t1_rep = getRepresentative(t1);
+    Node t2 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element);
+    Node t2_rep = getRepresentative(t2);
+
+    fact_1 = MEMBER( t1, r1_rep );
+    fact_2 = MEMBER( t2, r2_rep );
+    if(r1_rep != product_term[0]) {
+      reason_1 = AND(reason_1, explain(EQUAL(r1_rep, product_term[0])));
+    }
+    if(t1 != t1_rep) {
+      reason_1 = Rewriter::rewrite(AND(reason_1, explain(EQUAL(t1, t1_rep))));
+    }
+    if(r2_rep != product_term[1]) {
+      reason_2 = AND(reason_2, explain(EQUAL(r2_rep, product_term[1])));
+    }
+    if(t2 != t2_rep) {
+      reason_2 = Rewriter::rewrite(AND(reason_2, explain(EQUAL(t2, t2_rep))));
+    }
+    if(polarity) {
+      if(!hasMember(r1_rep, t1_rep)) {
+        addToMap(d_membership_db, r1_rep, t1_rep);
+        addToMap(d_membership_exp_db, r1_rep, reason_1);
+        sendInfer(fact_1, reason_1, "product-split");
       }
-
       if(!hasMember(r2_rep, t2)) {
-        fact = MEMBER( t2, r2_rep );
-        if(r2_rep != product_term[1])
-          reason = Rewriter::rewrite(AND(reason, EQUAL(r2_rep, product_term[1])));
         addToMap(d_membership_db, r2_rep, t2);
-        addToMap(d_membership_exp_db, r2_rep, reason);
-        sendInfer(fact, reason, "product-split");
+        addToMap(d_membership_exp_db, r2_rep, reason_2);
+        sendInfer(fact_2, reason_2, "product-split");
       }
+
     } else {
+//      sendInfer(fact_1.negate(), reason_1, "product-split");
+//      sendInfer(fact_2.negate(), reason_2, "product-split");
+
       // ONLY need to explicitly compute joins if there are negative literals involving PRODUCT
       Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-COMPOSE rule on term: " << product_term
                           << " with explanation: " << exp << std::endl;
@@ -528,15 +543,16 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     }
   }
 
-  // Todo: need to add equality between two pair's left and right elements as explanation
+
   void TheorySetsRels::inferTC( Node exp, Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph,
                                 Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& elements, bool first_round ) {
     Node pair = constructPair(tc_rep, start_node, cur_node);
     if(safeAddToMap(d_membership_db, tc_rep, pair)) {
-      addToMap(d_membership_exp_db, tc_rep, exp);
-      sendLemma( MEMBER(pair, tc_rep), exp, "Transitivity" );
+      addToMap(d_membership_exp_cache, tc_rep, Rewriter::rewrite(exp));
+      sendLemma( MEMBER(pair, tc_rep), Rewriter::rewrite(exp), "Transitivity" );
     }
 
+    // check if cur_node has been traversed or not
     if(!first_round) {
       std::hash_set< Node, NodeHashFunction >::iterator ele_it = elements.begin();
       while(ele_it != elements.end()) {
@@ -547,8 +563,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       }
     }
     std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin();
+    Node reason = exp;
     while(pair_set_it != tc_graph.end()) {
       if(areEqual(pair_set_it->first, cur_node)) {
+        reason = AND(exp, EQUAL(pair_set_it->first, cur_node));
         break;
       }
       pair_set_it++;
@@ -557,10 +575,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin();
           set_it != pair_set_it->second.end(); set_it++) {
         Node p = constructPair( tc_rep, cur_node, *set_it );
-        Node reason = AND( findMemExp(tc_rep, p), exp );
         Assert(!reason.isNull());
         elements.insert(*set_it);
-        inferTC( reason, tc_rep, tc_graph, start_node, *set_it, elements, false );
+        inferTC( AND( findMemExp(tc_rep, p), reason ), tc_rep, tc_graph, start_node, *set_it, elements, false );
       }
     }
   }
@@ -574,7 +591,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
         std::hash_set<Node, NodeHashFunction> elements;
         Node pair = constructPair(tc_rep, pair_set_it->first, *set_it);
         Node exp = findMemExp(tc_rep, pair);
-        Trace("rels-debug") << "[sets-rels] pair = " << pair << std::endl;
         if(d_membership_tc_exp_cache.find(tc_rep) != d_membership_tc_exp_cache.end()) {
           exp = AND(d_membership_tc_exp_cache[tc_rep], exp);
         }
@@ -753,7 +769,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   void TheorySetsRels::doPendingLemmas() {
     if( !(*d_conflict) && (!d_lemma_cache.empty() || !d_pending_facts.empty())){
       for( unsigned i=0; i < d_lemma_cache.size(); i++ ){
-        if(holds( d_lemma_cache[i] )) {
+        Assert(d_lemma_cache[i].getKind() == kind::IMPLIES);
+        if(holds( d_lemma_cache[i][1] )) {
           Trace("rels-lemma") << "[sets-rels-lemma-skip] Skip the already held lemma: "
                               << d_lemma_cache[i]<< std::endl;
           continue;
@@ -775,6 +792,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
         d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, child_it->second, child_it->first));
       }
     }
+    d_tc_nodes.clear();
     d_pending_facts.clear();
     d_membership_constraints_cache.clear();
     d_membership_tc_cache.clear();
@@ -890,7 +908,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   }
 
   bool TheorySetsRels::areEqual( Node a, Node b ){
-    Trace("rels-debug") << "[sets-rels] areEqual( a = " << a << ", b = " << b << ")" << std::endl;
     if(a == b) {
       return true;
     } else if( hasTerm( a ) && hasTerm( b ) ){
@@ -936,28 +953,49 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   }
 
   inline Node TheorySetsRels::getReason(Node tc_rep, Node tc_term, Node tc_r_rep, Node tc_r) {
+    Trace("rels-reason") << "[sets-rels] getReason(" << tc_rep << ", " << tc_term << ", " << tc_r_rep << ", " << tc_r << std::endl;
     if(tc_term != tc_rep) {
       Node reason = explain(EQUAL(tc_term, tc_rep));
       if(tc_term[0] != tc_r_rep) {
         return AND(reason, explain(EQUAL(tc_term[0], tc_r_rep)));
       }
     }
+    Trace("rels-reason") << "[sets-rels] done getReason(" << tc_rep << ", " << tc_term << ", " << tc_r_rep << ", " << tc_r << std::endl;
     return Node::null();
   }
 
-  // tuple might be a member of tc_rep; or it might be a member of tc_terms
+  // tuple might be a member of tc_rep; or it might be a member of rels or tc_terms such that
+  // tc_terms are transitive closure of rels and are modulo equal to tc_rep
   Node TheorySetsRels::findMemExp(Node tc_rep, Node tuple) {
     Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_rep = " << tc_rep << ", tuple = " << tuple << ")" << std::endl;
     std::vector<Node> tc_terms = d_terms_cache.find(tc_rep)->second[kind::TRANSCLOSURE];
     Assert(tc_terms.size() > 0);
     for(unsigned int i = 0; i < tc_terms.size(); i++) {
-      Node r_rep = getRepresentative(tc_terms[i][0]);
-      Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << r_rep << ", tuple = " << tuple << ")" << std::endl;
-      std::map< Node, std::vector< Node > >::iterator tc_r_mems = d_membership_db.find(r_rep);
+      Node tc_term = tc_terms[i];
+      Node tc_r_rep = getRepresentative(tc_term[0]);
+
+      Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << tc_r_rep << ", tuple = " << tuple << ")" << std::endl;
+      std::map< Node, std::vector< Node > >::iterator tc_r_mems = d_membership_db.find(tc_r_rep);
       if(tc_r_mems != d_membership_db.end()) {
         for(unsigned int i = 0; i < tc_r_mems->second.size(); i++) {
           if(areEqual(tc_r_mems->second[i], tuple)) {
-            return explain(d_membership_exp_db[r_rep][i]);
+            Node exp = d_trueNode;
+            if(tc_r_rep != tc_term[0]) {
+              exp = explain(EQUAL(tc_r_rep, tc_term[0]));
+            }
+            if(tc_rep != tc_term) {
+              exp = AND(exp, explain(EQUAL(tc_rep, tc_term)));
+            }
+            if(tc_r_mems->second[i] != tuple) {
+              if(nthElementOfTuple(tc_r_mems->second[i], 0) != nthElementOfTuple(tuple, 0)) {
+                exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_r_mems->second[i], 0), nthElementOfTuple(tuple, 0))));
+              }
+              if(nthElementOfTuple(tc_r_mems->second[i], 1) != nthElementOfTuple(tuple, 1)) {
+                exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_r_mems->second[i], 1), nthElementOfTuple(tuple, 1))));
+              }
+              exp = AND(exp, EQUAL(tc_r_mems->second[i], tuple));
+            }
+            return Rewriter::rewrite(AND(exp, explain(d_membership_exp_db[tc_r_rep][i])));
           }
         }
       }
@@ -966,9 +1004,25 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       std::map< Node, std::vector< Node > >::iterator tc_t_mems = d_membership_db.find(tc_term_rep);
       Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_t_rep = " << tc_term_rep << ", tuple = " << tuple << ")" << std::endl;
       if(tc_t_mems != d_membership_db.end()) {
-        for(unsigned int i = 0; i < tc_t_mems->second.size(); i++) {
-          if(areEqual(tc_t_mems->second[i], tuple)) {
-            return explain(d_membership_exp_db[tc_term_rep][i]);
+        for(unsigned int j = 0; j < tc_t_mems->second.size(); j++) {
+          if(areEqual(tc_t_mems->second[j], tuple)) {
+            Node exp = d_trueNode;
+            if(tc_rep != tc_terms[i]) {
+              exp = AND(exp, explain(EQUAL(tc_rep, tc_terms[i])));
+            }
+            if(tc_term_rep != tc_terms[i]) {
+              exp = AND(exp, explain(EQUAL(tc_term_rep, tc_terms[i])));
+            }
+            if(tc_t_mems->second[j] != tuple) {
+              if(nthElementOfTuple(tc_t_mems->second[j], 0) != nthElementOfTuple(tuple, 0)) {
+                exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_t_mems->second[j], 0), nthElementOfTuple(tuple, 0))));
+              }
+              if(nthElementOfTuple(tc_t_mems->second[j], 1) != nthElementOfTuple(tuple, 1)) {
+                exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_t_mems->second[j], 1), nthElementOfTuple(tuple, 1))));
+              }
+              exp = AND(exp, EQUAL(tc_t_mems->second[j], tuple));
+            }
+            return Rewriter::rewrite(AND(exp, explain(d_membership_exp_db[tc_term_rep][j])));
           }
         }
       }
@@ -1155,7 +1209,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
 
   Node TheorySetsRels::explain(Node literal)
   {
-    Trace("rels-debug") << "[sets-rels] TheorySetsRels::explain(" << literal << ")"<< std::endl;
+    Trace("rels-exp") << "[sets-rels] TheorySetsRels::explain(" << literal << ")"<< std::endl;
 
     bool polarity = literal.getKind() != kind::NOT;
     TNode atom = polarity ? literal : literal[0];
@@ -1169,11 +1223,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       }
       d_eqEngine->explainPredicate(atom, polarity, assumptions);
     } else {
-      Trace("rels-debug") << "unhandled: " << literal << "; (" << atom << ", "
+      Trace("rels-exp") << "unhandled: " << literal << "; (" << atom << ", "
                     << polarity << "); kind" << atom.getKind() << std::endl;
       Unhandled();
     }
-    Trace("rels-debug") << "[sets-rels] ****** done with TheorySetsRels::explain(" << literal << ")"<< std::endl;
+    Trace("rels-exp") << "[sets-rels] ****** done with TheorySetsRels::explain(" << literal << ")"<< std::endl;
     return mkAnd(assumptions);
   }
 
index 8fc107a826a22c64c0fabc9d79e120bed9c0b20d..0876cc5b3016bbde83e9742981eeb0dac2e519c7 100644 (file)
@@ -100,6 +100,7 @@ private:
   NodeSet d_lemma;
   NodeSet d_shared_terms;
 
+  std::hash_set< Node, NodeHashFunction > d_tc_nodes;
   std::map< Node, std::vector<Node> > d_tuple_reps;
   std::map< Node, TupleTrie > d_membership_trie;
   std::hash_set< Node, NodeHashFunction > d_symbolic_tuples;