fixed product rules
authorPaulMeng <baolmeng@gmail.com>
Tue, 1 Mar 2016 20:17:08 +0000 (14:17 -0600)
committerPaulMeng <baolmeng@gmail.com>
Tue, 1 Mar 2016 20:17:08 +0000 (14:17 -0600)
src/theory/sets/theory_sets_rels.cpp
src/theory/sets/theory_sets_rels.h

index ec2d28fa68acba779e70547ddd80f6112c014425..986c4cf005591a24cc5b2b1be0ae9735cfe5c173 100644 (file)
@@ -93,7 +93,9 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     }
   }
 
-
+  /*
+   * Polulate the relational terms data structure
+   */
 
   void TheorySetsRels::collectRelationalInfo() {
     Trace("rels-debug") << "[sets-rels] Start collecting relational terms..." << std::endl;
@@ -164,105 +166,143 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     Trace("rels-debug") << "[sets-rels] Done with collecting relational terms!" << std::endl;
   }
 
-  void TheorySetsRels::doPendingFacts() {
-    std::map<Node, Node>::iterator map_it = d_pending_facts.begin();
-    while( !(*d_conflict) && map_it != d_pending_facts.end()) {
+ /*  join-split rule:    (a, b) IS_IN (X PRODUCT Y)
+  *                  ----------------------------------
+  *                       a IS_IN X  && b IS_IN Y
+  *
+  *  product-compose rule: (a, b) IS_IN X    (c, d) IS_IN Y  NOT (r, s, t, u) IS_IN (X PRODUCT Y)
+  *                        ----------------------------------------------------------------------
+  *                                         (a, b, c, d) IS_IN (X PRODUCT Y)
+  */
 
-      Node fact = map_it->first;
-      Node exp = d_pending_facts[ fact ];
-      if(fact.getKind() == kind::AND) {
-        for(size_t j=0; j<fact.getNumChildren(); j++) {
-          bool polarity = fact[j].getKind() != kind::NOT;
-          Node atom = polarity ? fact[j] : fact[j][0];
-          assertMembership(atom, exp, polarity);
+  void TheorySetsRels::applyProductRule(Node exp, Node rel_rep, Node product_term) {
+    Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT rule on term: " << product_term
+                        << " with explaination: " << exp << std::endl;
+    bool polarity = exp.getKind() != kind::NOT;
+    Node atom = polarity ? exp : exp[0];
+    Node r1_rep = getRepresentative(product_term[0]);
+    Node r2_rep = getRepresentative(product_term[1]);
+
+    if(polarity) {
+      Node t1;
+      Node t2;
+      unsigned int s1_len = 1;
+      std::vector<Node> r1_element;
+      std::vector<Node> r2_element;
+      Node::iterator child_it = atom[0].begin();
+      NodeManager *nm = NodeManager::currentNM();
+
+      if(r1_rep.getType().getSetElementType().isTuple()) {
+        Datatype dt = r1_rep.getType().getSetElementType().getDatatype();
+        s1_len = r1_rep.getType().getSetElementType().getTupleLength();
+        r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
+        for(unsigned int i = 0; i < s1_len; ++child_it) {
+          r1_element.push_back(*child_it);
+          ++i;
         }
+        t1 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element);
       } else {
-        bool polarity = fact.getKind() != kind::NOT;
-        Node atom = polarity ? fact : fact[0];
-        assertMembership(atom, exp, polarity);
+        t1 = *child_it;
+        ++child_it;
       }
-      map_it++;
-    }
-    d_pending_facts.clear();
-    d_membership_cache.clear();
-    d_membership_exp_cache.clear();
-    d_terms_cache.clear();
-  }
 
-  void TheorySetsRels::doPendingSplitFacts() {
-      std::map<Node, Node>::iterator map_it = d_pending_split_facts.begin();
-      while( !(*d_conflict) && map_it != d_pending_split_facts.end()) {
-
-        Node fact = map_it->first;
-        Node exp = d_pending_split_facts[ fact ];
-        if(fact.getKind() == kind::AND) {
-          for(size_t j=0; j<fact.getNumChildren(); j++) {
-            bool polarity = fact[j].getKind() != kind::NOT;
-            Node atom = polarity ? fact[j] : fact[j][0];
-            assertMembership(atom, exp, polarity);
-          }
-        } else {
-          bool polarity = fact.getKind() != kind::NOT;
-          Node atom = polarity ? fact : fact[0];
-          assertMembership(atom, exp, polarity);
+      if(r2_rep.getType().getSetElementType().isTuple()) {
+        Datatype dt = r2_rep.getType().getSetElementType().getDatatype();
+        r2_element.push_back(Node::fromExpr(dt[0].getConstructor()));
+        for(; child_it != atom[0].end(); ++child_it) {
+          r2_element.push_back(*child_it);
         }
-        map_it++;
+        t2 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element);
+      } else {
+        t2 = *child_it;
       }
-      d_pending_split_facts.clear();
-    }
-
-  void TheorySetsRels::applyProductRule(Node exp, Node rel_rep, Node product_term) {
-    Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT rule on term: " << product_term
-                        << " with explaination: " << exp << std::endl;
-    bool polarity = exp.getKind() != kind::NOT;
-    Node atom = polarity ? exp : exp[0];
-    if(!polarity)
+      Node f1 = nm->mkNode(kind::MEMBER, t1, r1_rep);
+      Node f2 = nm->mkNode(kind::MEMBER, t2, r2_rep);
+      Node reason = exp;
+      if(atom[1] != product_term)
+        reason = AND(reason, EQUAL(atom[1], product_term));
+      if(r1_rep != product_term[0])
+        reason = AND(reason, EQUAL(r1_rep, product_term[0]));
+      if(r2_rep != product_term[1])
+        reason = Rewriter::rewrite(AND(reason, EQUAL(r2_rep, product_term[1])));
+
+      sendInfer(f1, reason, "product-split");
+      sendInfer(f2, reason, "product-split");
+    } else {
+      // ONLY need to explicitly compute joins if there are negative literals involving PRODUCT
       computeJoinOrProductRelations(product_term);
+    }
   }
 
+  /* join-split rule:           (a, b) IS_IN (X JOIN Y)
+   *                  --------------------------------------------
+   *                  exists z | (a, z) IS_IN X  && (z, b) IS_IN Y
+   *
+   *
+   * join-compose rule: (a, b) IS_IN X    (b, c) IS_IN Y  NOT (t, u) IS_IN (X JOIN Y)
+   *                    -------------------------------------------------------------
+   *                                      (a, c) IS_IN (X JOIN Y)
+   */
   void TheorySetsRels::applyJoinRule(Node exp, Node rel_rep, Node join_term) {
     Trace("rels-debug") <<  "\n[sets-rels] Apply JOIN rule on term: " << join_term
                         << " with explaination: " << exp << std::endl;
     bool polarity = exp.getKind() != kind::NOT;
     Node atom = polarity ? exp : exp[0];
-    Node r1 = join_term[0];
-    Node r2 = join_term[1];
+    Node r1_rep = getRepresentative(join_term[0]);
+    Node r2_rep = getRepresentative(join_term[1]);
 
-    // join-split rule: (a, b) IS_IN (X JOIN Y)
-    //               => (a, z) IS_IN X  && (z, b) IS_IN Y
     if(polarity) {
-      Debug("rels-join") << "[sets-rels] Join rules (a, b) IS_IN (X JOIN Y) => "
-                            "((a, z) IS_IN X  && (z, b) IS_IN Y)"<< std::endl;
-      Assert((r1.getType().getSetElementType()).isDatatype());
-      Assert((r2.getType().getSetElementType()).isDatatype());
-
-      unsigned int i = 0;
-      std::vector<Node> r1_tuple;
-      std::vector<Node> r2_tuple;
+      Node t1;
+      Node t2;
+      TypeNode shared_type;
+      unsigned int s1_len = 1;
+      std::vector<Node> r1_element;
+      std::vector<Node> r2_element;
       Node::iterator child_it = atom[0].begin();
-      unsigned int s1_len = r1.getType().getSetElementType().getTupleLength();
-      Node shared_x = NodeManager::currentNM()->mkSkolem("sde_", r2.getType().getSetElementType().getTupleTypes()[0]);
-      Datatype dt = r1.getType().getSetElementType().getDatatype();
-
-      r1_tuple.push_back(Node::fromExpr(dt[0].getConstructor()));
-      for(; i < s1_len-1; ++child_it) {
-        r1_tuple.push_back(*child_it);
-        ++i;
+      NodeManager *nm = NodeManager::currentNM();
+
+      if(r2_rep.getType().getSetElementType().isTuple()) {
+        shared_type = r2_rep.getType().getSetElementType().getTupleTypes()[0];
+      } else {
+        shared_type = r2_rep.getType().getSetElementType();
+      }
+
+      Node shared_x = nm->mkSkolem("sde_", shared_type);
+      if(r1_rep.getType().getSetElementType().isTuple()) {
+        Datatype dt = r1_rep.getType().getSetElementType().getDatatype();
+        s1_len = r1_rep.getType().getSetElementType().getTupleLength();
+        r1_element.push_back(Node::fromExpr(dt[0].getConstructor()));
+        for(unsigned int i = 0; i < s1_len-1; ++child_it) {
+          r1_element.push_back(*child_it);
+          ++i;
+        }
+        r1_element.push_back(shared_x);
+        t1 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element);
+      } else {
+        t1 = shared_x;
       }
-      r1_tuple.push_back(shared_x);
-      dt = r2.getType().getSetElementType().getDatatype();
-      r2_tuple.push_back(Node::fromExpr(dt[0].getConstructor()));
-      r2_tuple.push_back(shared_x);
-      for(; child_it != atom[0].end(); ++child_it) {
-        r2_tuple.push_back(*child_it);
+
+      if(r2_rep.getType().getSetElementType().isTuple()) {
+        Datatype dt = r2_rep.getType().getSetElementType().getDatatype();
+        r2_element.push_back(Node::fromExpr(dt[0].getConstructor()));
+        r2_element.push_back(shared_x);
+        for(; child_it != atom[0].end(); ++child_it) {
+          r2_element.push_back(*child_it);
+        }
+        t2 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element);
+      } else {
+        t2 = shared_x;
       }
-      Node t1 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_tuple);
-      Node t2 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r2_tuple);
-      Node f1 = NodeManager::currentNM()->mkNode(kind::MEMBER, t1, r1);
-      Node f2 = NodeManager::currentNM()->mkNode(kind::MEMBER, t2, r2);
+
+      Node f1 = nm->mkNode(kind::MEMBER, t1, r1_rep);
+      Node f2 = nm->mkNode(kind::MEMBER, t2, r2_rep);
       Node reason = exp;
       if(atom[1] != join_term)
         reason = AND(reason, EQUAL(atom[1], join_term));
+      if(r1_rep != join_term[0])
+        reason = AND(reason, EQUAL(r1_rep, join_term[0]));
+      if(r2_rep != join_term[1])
+        reason = Rewriter::rewrite(AND(reason, EQUAL(r2_rep, join_term[1])));
       sendInfer(f1, reason, "join-split");
       sendInfer(f2, reason, "join-split");
     } else {
@@ -271,6 +311,10 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     }
   }
 
+  /* transpose rule: (a, b) IS_IN X   NOT (t, u) IS_IN (TRANSPOSE X)
+   *                ------------------------------------------------
+   *                         (b, a) IS_IN (TRANSPOSE X)
+   */
   void TheorySetsRels::applyTransposeRule(Node exp, Node rel_rep, Node tp_term) {
     Trace("rels-debug") << "\n[sets-rels] Apply transpose rule on term: " << tp_term
                         << " with explaination: " << exp << std::endl;
@@ -283,8 +327,6 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
       reason = AND(reason, EQUAL(rel_rep, tp_term));
     Node fact = MEMBER(reversedTuple, tp_term[0]);
 
-    // when the term is nested like (not tup is_in tp(x join/product y)),
-    // we need to compute what is inside x join/product y
     if(!polarity) {
       if(d_terms_cache[getRepresentative(fact[1])].find(kind::JOIN)
          != d_terms_cache[getRepresentative(fact[1])].end()) {
@@ -334,7 +376,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     if(d_membership_cache.find(getRepresentative(n[0])) == d_membership_cache.end() ||
        d_membership_cache.find(getRepresentative(n[1])) == d_membership_cache.end())
           return;
-    composeRelations(n);
+    composeTuplesForRels(n);
   }
 
   void TheorySetsRels::computeTransposeRelations(Node n) {
@@ -379,18 +421,19 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     d_membership_exp_cache[n_rep] = rev_exps;
   }
 
-  // Join all explicitly specified tuples in r1, r2
-  // e.g. If (a, b) in X and (b, c) in Y, (a, c) in (X JOIN Y)
-  void TheorySetsRels::composeRelations(Node n) {
+  /*
+   * Explicitly compose the join or product relations of r1 and r2
+   * e.g. If (a, b) in X and (b, c) in Y, (a, c) in (X JOIN Y)
+   *
+   */
+  void TheorySetsRels::composeTuplesForRels( Node n ) {
     Node r1 = n[0];
     Node r2 = n[1];
     Node r1_rep = getRepresentative(r1);
     Node r2_rep = getRepresentative(r2);
     NodeManager* nm = NodeManager::currentNM();
     Trace("rels-debug") << "[sets-rels] start composing tuples in relations "
-                       << r1 << " and " << r2
-                       << "\n r1_rep: " << r1_rep
-                       << "\n r2_rep: " << r2_rep << std::endl;
+                        << r1 << " and " << r2 << std::endl;
 
     if(d_membership_cache.find(r1_rep) == d_membership_cache.end() ||
        d_membership_cache.find(r2_rep) == d_membership_cache.end())
@@ -434,7 +477,7 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
           }
           if(kind::PRODUCT == n.getKind()) {
             composedTuple.push_back(r1_elements[i][k]);
-            composedTuple.push_back(r1_elements[j][0]);
+            composedTuple.push_back(r2_elements[j][0]);
           }
           for(; l < t2_len; ++l) {
             composedTuple.push_back(r2_elements[j][l]);
@@ -534,6 +577,53 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
     d_infer_exp.push_back( exp );
   }
 
+  void TheorySetsRels::doPendingFacts() {
+    std::map<Node, Node>::iterator map_it = d_pending_facts.begin();
+    while( !(*d_conflict) && map_it != d_pending_facts.end()) {
+
+      Node fact = map_it->first;
+      Node exp = d_pending_facts[ fact ];
+      if(fact.getKind() == kind::AND) {
+        for(size_t j=0; j<fact.getNumChildren(); j++) {
+          bool polarity = fact[j].getKind() != kind::NOT;
+          Node atom = polarity ? fact[j] : fact[j][0];
+          assertMembership(atom, exp, polarity);
+        }
+      } else {
+        bool polarity = fact.getKind() != kind::NOT;
+        Node atom = polarity ? fact : fact[0];
+        assertMembership(atom, exp, polarity);
+      }
+      map_it++;
+    }
+    d_pending_facts.clear();
+    d_membership_cache.clear();
+    d_membership_exp_cache.clear();
+    d_terms_cache.clear();
+  }
+
+  void TheorySetsRels::doPendingSplitFacts() {
+    std::map<Node, Node>::iterator map_it = d_pending_split_facts.begin();
+    while( !(*d_conflict) && map_it != d_pending_split_facts.end()) {
+
+      Node fact = map_it->first;
+      Node exp = d_pending_split_facts[ fact ];
+      if(fact.getKind() == kind::AND) {
+        for(size_t j=0; j<fact.getNumChildren(); j++) {
+          bool polarity = fact[j].getKind() != kind::NOT;
+          Node atom = polarity ? fact[j] : fact[j][0];
+          assertMembership(atom, exp, polarity);
+        }
+      } else {
+        bool polarity = fact.getKind() != kind::NOT;
+        Node atom = polarity ? fact : fact[0];
+        assertMembership(atom, exp, polarity);
+      }
+      map_it++;
+    }
+    d_pending_split_facts.clear();
+  }
+
   Node TheorySetsRels::reverseTuple( Node tuple ) {
     Assert( tuple.getType().isTuple() );
     std::vector<Node> elements;
@@ -600,7 +690,6 @@ typedef std::map<Node, std::vector<Node> >::iterator mem_it;
 //              break;
 //          }
 //      }
-      Trace("rels-debug") << "has a and b " << a << " " << b << " are equal? "<<  d_eqEngine->areEqual( a, b ) << std::endl;
       return d_eqEngine->areEqual( a, b );
     }else if( a.isConst() && b.isConst() ){
       return a == b;
index 16329fd43172f73d67266d4f5ff256a7f3d5ab62..827caf69fa858f083cf4fbd438f58a3536e099a8 100644 (file)
@@ -73,7 +73,7 @@ private:
   void assertMembership( Node fact, Node reason, bool polarity );
   void composeProductRelations( Node );
   void composeJoinRelations( Node );
-  void composeRelations( Node );
+  void composeTuplesForRels( Node );
   void applyTransposeRule( Node, Node, Node );
   void applyJoinRule( Node, Node, Node );
   void applyProductRule( Node, Node, Node );