implemented a basic solving procedure for finite relations (only for
authorPaulMeng <baolmeng@gmail.com>
Sun, 28 Feb 2016 22:22:43 +0000 (16:22 -0600)
committerPaulMeng <baolmeng@gmail.com>
Sun, 28 Feb 2016 22:22:43 +0000 (16:22 -0600)
join, product, transpose operators)

src/theory/sets/theory_sets.h
src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_private.h
src/theory/sets/theory_sets_rels.cpp
src/theory/sets/theory_sets_rels.h
src/theory/sets/theory_sets_rewriter.cpp
test/regress/regress0/sets/rels/rel_join_0.cvc
test/regress/regress0/sets/rels/rel_transpose_0.cvc

index bc39fcbbd3b5c45aad07e593447f50586eca1b78..9e08b597d0b548cb9f781bbf8a8dcc5c2e934e2d 100644 (file)
@@ -33,6 +33,7 @@ class TheorySets : public Theory {
 private:
   friend class TheorySetsPrivate;
   friend class TheorySetsScrutinize;
+  friend class TheorySetsRels;
   TheorySetsPrivate* d_internal;
 public:
 
index 5e328b4fdf9ec4f5fddeede2f8a72f07a76e55bf..4cb82b66d79b6c39b06a7bbd2de3a5a659a8d4f7 100644 (file)
@@ -94,9 +94,11 @@ void TheorySetsPrivate::check(Theory::Effort level) {
     if(d_conflict) { return; }
     Debug("sets") << "[sets]  is complete = " << isComplete() << std::endl;
 
-    d_rels->check(level);
   }
-
+  d_rels->check(level);
+//  if( level == Theory::EFFORT_FULL ) {
+//    d_rels->doPendingLemmas();
+//  }
   if( (level == Theory::EFFORT_FULL || options::setsEagerLemmas() ) && !isComplete()) {
     d_external.d_out->lemma(getLemma());
     return;
@@ -1111,7 +1113,7 @@ TheorySetsPrivate::TheorySetsPrivate(TheorySets& external,
   d_rels(NULL)
 {
   d_termInfoManager = new TermInfoManager(*this, c, &d_equalityEngine);
-  d_rels = new TheorySetsRels(c, u, &d_equalityEngine, &d_conflict);
+  d_rels = new TheorySetsRels(c, u, &d_equalityEngine, &d_conflict, external);
 
   d_equalityEngine.addFunctionKind(kind::UNION);
   d_equalityEngine.addFunctionKind(kind::INTERSECTION);
index 8cbc17ae3af11ede9ff1b304d4a71cfa068b4870..ad04ff2736a396995652fceb70aa71dab4eb856d 100644 (file)
@@ -71,7 +71,6 @@ public:
 
 private:
   TheorySets& d_external;
-  TheorySetsRels* d_rels;
 
   class Statistics {
   public:
@@ -200,6 +199,7 @@ private:
   // more debugging stuff
   friend class TheorySetsScrutinize;
   TheorySetsScrutinize* d_scrutinize;
+  TheorySetsRels* d_rels;
   void dumpAssertionsHumanified() const;  /** do some formatting to make them more readable */
 };/* class TheorySetsPrivate */
 
index fcab5b5ca461f0e5687247f87caca93016480388..de70e6a52b663517bb809e9af79630695a5e4a78 100644 (file)
@@ -17,6 +17,9 @@
 #include "theory/sets/theory_sets_rels.h"
 
 #include "expr/datatype.h"
+#include "theory/sets/expr_patterns.h"
+#include "theory/sets/theory_sets_private.h"
+#include "theory/sets/theory_sets.h"
 //#include "options/sets_options.h"
 //#include "smt/smt_statistics_registry.h"
 //#include "theory/sets/expr_patterns.h" // ONLY included here
 
 
 using namespace std;
-using namespace CVC4::kind;
+using namespace CVC4::expr::pattern;
 
 namespace CVC4 {
 namespace theory {
 namespace sets {
 
-  TheorySetsRels::TheorySetsRels(context::Context* c,
-                                 context::UserContext* u,
-                                 eq::EqualityEngine* eq,
-                                 context::CDO<bool>* conflict):
-    d_trueNode(NodeManager::currentNM()->mkConst<bool>(true)),
-    d_falseNode(NodeManager::currentNM()->mkConst<bool>(false)),
-    d_eqEngine(eq),
-    d_conflict(conflict),
-    d_relsSaver(c)
-  {
-    d_eqEngine->addFunctionKind(kind::PRODUCT);
-    d_eqEngine->addFunctionKind(kind::JOIN);
-    d_eqEngine->addFunctionKind(kind::TRANSPOSE);
-    d_eqEngine->addFunctionKind(kind::TRANSCLOSURE);
+typedef std::map<Node, std::map<kind::Kind_t, std::vector<Node> > >::iterator term_it;
+typedef std::map<Node, std::vector<Node> >::iterator mem_it;
+
+  void TheorySetsRels::check(Theory::Effort level) {
+    Trace("rels-debug") << "[sets-rels] Start the relational solver..." << std::endl;
+    collectRelationalInfo();
+    check();
+//    doPendingFacts();
+    doPendingLemmas();
+    Assert(d_lemma_cache.empty());
+    Assert(d_pending_facts.empty());
+    Trace("rels-debug") << "[sets-rels] Done with the relational solver..." << std::endl;
   }
 
-  TheorySetsRels::~TheorySetsRels() {}
+  void TheorySetsRels::check() {
+    mem_it m_it = d_membership_cache.begin();
+    while(m_it != d_membership_cache.end()) {
+      std::vector<Node> tuples = m_it->second;
+      Node rel_rep = m_it->first;
+      // No relational terms found with rel_rep as its representative
+      if(d_terms_cache.find(rel_rep) == d_terms_cache.end()) {
+        m_it++;
+        continue;
+      }
+      for(unsigned int i = 0; i < tuples.size(); i++) {
+        Node tup_rep = tuples[i];
+        Node exp = d_membership_exp_cache[rel_rep][i];
+        std::map<kind::Kind_t, std::vector<Node> > kind_terms = d_terms_cache[rel_rep];
+
+        if(kind_terms.find(kind::TRANSPOSE) != kind_terms.end()) {
+          std::vector<Node> tp_terms = kind_terms.at(kind::TRANSPOSE);
+          // exp is a membership term and tp_terms contains all
+          // transposed terms that are equal to the right hand side of exp
+          for(unsigned int j = 0; j < tp_terms.size(); j++) {
+            applyTransposeRule(exp, rel_rep, tp_terms[j]);
+          }
+        }
+        if(kind_terms.find(kind::JOIN) != kind_terms.end()) {
+          std::vector<Node> conj;
+          std::vector<Node> join_terms = kind_terms.at(kind::JOIN);
+          // exp is a membership term and join_terms contains all
+          // joined terms that are in the same equivalence class with the right hand side of exp
+          for(unsigned int j = 0; j < join_terms.size(); j++) {
+            applyJoinRule(exp, rel_rep, join_terms[j]);
+          }
+        }
+        if(kind_terms.find(kind::PRODUCT) != kind_terms.end()) {
+          std::vector<Node> product_terms = kind_terms.at(kind::PRODUCT);
+          for(unsigned int j = 0; j < product_terms.size(); j++) {
+            applyProductRule(exp, rel_rep, product_terms[j]);
+          }
+        }
+      }
+      m_it++;
+    }
+  }
 
-  void TheorySetsRels::check(Theory::Effort level) {
 
-    Debug("rels-eqc") <<  "\nStart iterating equivalence classes......\n" << std::endl;
 
-    if (!d_eqEngine->consistent())
-      return;
+  void TheorySetsRels::collectRelationalInfo() {
+    Trace("rels-debug") << "[sets-rels] Start collecting relational terms..." << std::endl;
     eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( d_eqEngine );
-
     while( !eqcs_i.isFinished() ){
-      TNode r = (*eqcs_i);
+      Node r = (*eqcs_i);
       eq::EqClassIterator eqc_i = eq::EqClassIterator( r, d_eqEngine );
-
+      Trace("rels-ee") << "[sets-rels] term representative: " << r << std::endl;
       while( !eqc_i.isFinished() ){
-        TNode n = (*eqc_i);
-
-        // only consider membership constraints that involving relatioinal operators
-        if((d_eqEngine->getRepresentative(r) == d_eqEngine->getRepresentative(d_trueNode)
-              || d_eqEngine->getRepresentative(r) == d_eqEngine->getRepresentative(d_falseNode))
-            && !d_relsSaver.contains(n)) {
-
-          // case: [NOT] (b, a) IS_IN (TRANSPOSE X)
-          //    => [NOT] (a, b) IS_IN X
-          if(n.getKind() == kind::MEMBER) {
-            d_relsSaver.insert(n);
-            if(kind::TRANSPOSE == n[1].getKind()) {
-              Node reversedTuple = reverseTuple(n[0]);
-              Node fact = NodeManager::currentNM()->mkNode(kind::MEMBER, reversedTuple, n[1][0]);
+        Node n = (*eqc_i);
+        Trace("rels-ee") << "  term : " << n << std::endl;
+        if(getRepresentative(r) == getRepresentative(d_trueNode) ||
+           getRepresentative(r) == getRepresentative(d_falseNode)) {
+          // collect membership info
+          if(n.getKind() == kind::MEMBER && n[0].getType().isTuple()) {
+            Node tup_rep = getRepresentative(n[0]);
+            Node rel_rep = getRepresentative(n[1]);
+            // No rel_rep is found
+            if(d_membership_cache.find(rel_rep) == d_membership_cache.end()) {
+              std::vector<Node> tups;
+              tups.push_back(tup_rep);
+              d_membership_cache[rel_rep] = tups;
               Node exp = n;
-              if(d_eqEngine->getRepresentative(r) == d_eqEngine->getRepresentative(d_falseNode)) {
-                fact = fact.negate();
+              if(getRepresentative(r) == getRepresentative(d_falseNode))
                 exp = n.negate();
+              tups.clear();
+              tups.push_back(exp);
+              d_membership_exp_cache[rel_rep] = tups;
+            } else if(std::find(d_membership_cache.at(rel_rep).begin(),
+                                d_membership_cache.at(rel_rep).end(), tup_rep)
+                      == d_membership_cache.at(rel_rep).end()) {
+              d_membership_cache[rel_rep].push_back(tup_rep);
+              Node exp = n;
+              if(getRepresentative(r) == getRepresentative(d_falseNode))
+                exp = n.negate();
+              d_membership_exp_cache.at(rel_rep).push_back(exp);
+            }
+          }
+        // collect term info
+        } else if(r.getType().isSet() && r.getType().getSetElementType().isTuple()) {
+          if(n.getKind() == kind::TRANSPOSE ||
+             n.getKind() == kind::JOIN ||
+             n.getKind() == kind::PRODUCT ||
+             n.getKind() == kind::TRANSCLOSURE) {
+            std::map<kind::Kind_t, std::vector<Node> > rel_terms;
+            std::vector<Node> terms;
+            // No r record is found
+            if(d_terms_cache.find(r) == d_terms_cache.end()) {
+              terms.push_back(n);
+              rel_terms[n.getKind()] = terms;
+              d_terms_cache[r] = rel_terms;
+            } else {
+              rel_terms = d_terms_cache[r];
+              // No n's kind record is found
+              if(rel_terms.find(n.getKind()) == rel_terms.end()) {
+                terms.push_back(n);
+                rel_terms[n.getKind()] = terms;
+              } else {
+                rel_terms.at(n.getKind()).push_back(n);
               }
-              d_pending_facts[fact] = exp;
-            } else if(kind::JOIN == n[1].getKind()) {
-              TNode r1 = n[1][0];
-              TNode r2 = n[1][1];
-              // Need to do this efficiently... Join relations after collecting all of them
-              // So that we would just need to iterate over EE once
-              joinRelations(r1, r2, n[1].getType().getSetElementType());
-
-              // case: (a, b) IS_IN (X JOIN Y)
-              //      => (a, z) IS_IN X  && (z, b) IS_IN Y
-              if(d_eqEngine->getRepresentative(r) == d_eqEngine->getRepresentative(d_trueNode)) {
-                Debug("rels-join") << "Join rules (a, b) IS_IN (X JOIN Y) => ((a, z) IS_IN X  && (z, b) IS_IN Y)"<< std::endl;
-                Assert((r1.getType().getSetElementType()).isDatatype());
-                Assert((r2.getType().getSetElementType()).isDatatype());
-
-                unsigned int i = 0;
-                std::vector<Node> r1_tuple;
-                std::vector<Node> r2_tuple;
-                Node::iterator child_it = n[0].begin();
-                unsigned int s1_len = r1.getType().getSetElementType().getTupleLength();
-                Node shared_x = NodeManager::currentNM()->mkSkolem("sde_", r2.getType().getSetElementType().getTupleTypes()[0]);
-                Datatype dt = r1.getType().getSetElementType().getDatatype();
-
-                r1_tuple.push_back(Node::fromExpr(dt[0].getConstructor()));
-                for(; i < s1_len-1; ++child_it) {
-                  r1_tuple.push_back(*child_it);
-                  ++i;
-                }
-                r1_tuple.push_back(shared_x);
-                dt = r2.getType().getSetElementType().getDatatype();
-                r2_tuple.push_back(Node::fromExpr(dt[0].getConstructor()));
-                r2_tuple.push_back(shared_x);
-                for(; child_it != n[0].end(); ++child_it) {
-                  r2_tuple.push_back(*child_it);
-                }
-                Node t1 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_tuple);
-                Node t2 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r2_tuple);
-                Node f1 = NodeManager::currentNM()->mkNode(kind::MEMBER, t1, r1);
-                Node f2 = NodeManager::currentNM()->mkNode(kind::MEMBER, t2, r2);
-                d_pending_facts[f1] = n;
-                d_pending_facts[f2] = n;
-              }
-            }else if(kind::PRODUCT == n[1].getKind()) {
-
             }
           }
         }
@@ -134,135 +161,412 @@ namespace sets {
       }
       ++eqcs_i;
     }
-    doPendingFacts();
+    Trace("rels-debug") << "[sets-rels] Done with collecting relational terms!" << std::endl;
   }
 
-  // Join all explicitly specified tuples in r1, r2
-  // e.g. If (a, b) in X and (b, c) in Y, (a, c) in (X JOIN Y)
-  void TheorySetsRels::joinRelations(TNode r1, TNode r2, TypeNode tn) {
-    if (!d_eqEngine->consistent())
-          return;
-    Debug("rels-join") << "start joining tuples in "
-                       << r1 << " and " << r2 << std::endl;
-
-    std::vector<Node> r1_elements;
-    std::vector<Node> r2_elements;
-    eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( d_eqEngine );
+  void TheorySetsRels::doPendingFacts() {
+    std::map<Node, Node>::iterator map_it = d_pending_facts.begin();
+    while( !(*d_conflict) && map_it != d_pending_facts.end()) {
 
-    // collect all tuples that are in r1, r2
-    while( !eqcs_i.isFinished() ){
-      TNode r = (*eqcs_i);
-      eq::EqClassIterator eqc_i = eq::EqClassIterator( r, d_eqEngine );
-      while( !eqc_i.isFinished() ){
-        TNode n = (*eqc_i);
-        if(d_eqEngine->getRepresentative(r) == d_eqEngine->getRepresentative(d_trueNode)
-            && n.getKind() == kind::MEMBER && n[0].getType().isTuple()) {
-          if(n[1] == r1) {
-            Debug("rels-join") << "r1 tuple: " << n[0] << std::endl;
-            r1_elements.push_back(n[0]);
-          } else if (n[1] == r2) {
-            Debug("rels-join") << "r2 tuple: " << n[0] << std::endl;
-            r2_elements.push_back(n[0]);
-          }
+      Node fact = map_it->first;
+      Node exp = d_pending_facts[ fact ];
+      if(fact.getKind() == kind::AND) {
+        for(size_t j=0; j<fact.getNumChildren(); j++) {
+          bool polarity = fact[j].getKind() != kind::NOT;
+          Node atom = polarity ? fact[j] : fact[j][0];
+          assertMembership(atom, exp, polarity);
         }
-        ++eqc_i;
+      } else {
+        bool polarity = fact.getKind() != kind::NOT;
+        Node atom = polarity ? fact : fact[0];
+        assertMembership(atom, exp, polarity);
       }
-      ++eqcs_i;
+      map_it++;
     }
-    if(r1_elements.size() == 0 || r2_elements.size() == 0)
-      return;
+    d_pending_facts.clear();
+    d_membership_cache.clear();
+    d_membership_exp_cache.clear();
+    d_terms_cache.clear();
+  }
+
+  void TheorySetsRels::applyProductRule(Node exp, Node rel_rep, Node product_term) {
+    Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT rule on term: " << product_term
+                        << " with explaination: " << exp << std::endl;
+    bool polarity = exp.getKind() != kind::NOT;
+    Node atom = polarity ? exp : exp[0];
+    if(!polarity)
+      computeJoinOrProductRelations(product_term);
+  }
 
-    // Join r1 and r2
-    joinTuples(r1, r2, r1_elements, r2_elements, tn);
+  void TheorySetsRels::applyJoinRule(Node exp, Node rel_rep, Node join_term) {
+    Trace("rels-debug") <<  "\n[sets-rels] Apply JOIN rule on term: " << join_term
+                        << " with explaination: " << exp << std::endl;
+    bool polarity = exp.getKind() != kind::NOT;
+    Node atom = polarity ? exp : exp[0];
+    Node r1 = join_term[0];
+    Node r2 = join_term[1];
+
+    // case: (a, b) IS_IN (X JOIN Y)
+    //      => (a, z) IS_IN X  && (z, b) IS_IN Y
+    if(polarity) {
+      Debug("rels-join") << "[sets-rels] Join rules (a, b) IS_IN (X JOIN Y) => "
+                            "((a, z) IS_IN X  && (z, b) IS_IN Y)"<< std::endl;
+      Assert((r1.getType().getSetElementType()).isDatatype());
+      Assert((r2.getType().getSetElementType()).isDatatype());
+
+      unsigned int i = 0;
+      std::vector<Node> r1_tuple;
+      std::vector<Node> r2_tuple;
+      Node::iterator child_it = atom[0].begin();
+      unsigned int s1_len = r1.getType().getSetElementType().getTupleLength();
+      Node shared_x = NodeManager::currentNM()->mkSkolem("sde_", r2.getType().getSetElementType().getTupleTypes()[0]);
+      Datatype dt = r1.getType().getSetElementType().getDatatype();
+
+      r1_tuple.push_back(Node::fromExpr(dt[0].getConstructor()));
+      for(; i < s1_len-1; ++child_it) {
+        r1_tuple.push_back(*child_it);
+        ++i;
+      }
+      r1_tuple.push_back(shared_x);
+      dt = r2.getType().getSetElementType().getDatatype();
+      r2_tuple.push_back(Node::fromExpr(dt[0].getConstructor()));
+      r2_tuple.push_back(shared_x);
+      for(; child_it != atom[0].end(); ++child_it) {
+        r2_tuple.push_back(*child_it);
+      }
+      Node t1 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_tuple);
+      Node t2 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r2_tuple);
+      Node f1 = NodeManager::currentNM()->mkNode(kind::MEMBER, t1, r1);
+      Node f2 = NodeManager::currentNM()->mkNode(kind::MEMBER, t2, r2);
+      Node reason = exp;
+      if(atom[1] != join_term)
+        reason = AND(reason, EQUAL(atom[1], join_term));
+      sendInfer(f1, reason, "join-split");
+      sendInfer(f2, reason, "join-split");
+    } else {
+      // ONLY need to explicitly compute joins if there are negative literals involving JOIN
+      computeJoinOrProductRelations(join_term);
+    }
+  }
 
+  void TheorySetsRels::applyTransposeRule(Node exp, Node rel_rep, Node tp_term) {
+    Trace("rels-debug") << "\n[sets-rels] Apply transpose rule on term: " << tp_term
+                        << " with explaination: " << exp << std::endl;
+    bool polarity = exp.getKind() != kind::NOT;
+    Node atom = polarity ? exp : exp[0];
+    Node reversedTuple = reverseTuple(atom[0]);
+    Node reason = exp;
+
+    if(atom[1] != tp_term)
+      reason = AND(reason, EQUAL(rel_rep, tp_term));
+    Node fact = MEMBER(reversedTuple, tp_term[0]);
+
+    // when the term is nested like (not tup is_in tp(x join/product y)),
+    // we need to compute what is inside x join/product y
+    if(!polarity) {
+      if(d_terms_cache[getRepresentative(fact[1])].find(kind::JOIN)
+         != d_terms_cache[getRepresentative(fact[1])].end()) {
+        computeJoinOrProductRelations(fact[1]);
+      }
+      if(d_terms_cache[getRepresentative(fact[1])].find(kind::PRODUCT)
+         != d_terms_cache[getRepresentative(fact[1])].end()) {
+        computeJoinOrProductRelations(fact[1]);
+      }
+      fact = fact.negate();
+    }
+    sendInfer(fact, exp, "transpose-rule");
   }
 
-  void TheorySetsRels::joinTuples(TNode r1, TNode r2, std::vector<Node>& r1_elements, std::vector<Node>& r2_elements, TypeNode tn) {
+  void TheorySetsRels::computeJoinOrProductRelations(Node n) {
+    switch(n[0].getKind()) {
+    case kind::JOIN:
+      computeJoinOrProductRelations(n[0]);
+      break;
+    case kind::TRANSPOSE:
+      computeTransposeRelations(n[0]);
+      break;
+    case kind::PRODUCT:
+      computeJoinOrProductRelations(n[0]);
+      break;
+    default:
+      break;
+    }
+
+    switch(n[1].getKind()) {
+    case kind::JOIN:
+      computeJoinOrProductRelations(n[1]);
+      break;
+    case kind::TRANSPOSE:
+      computeTransposeRelations(n[1]);
+      break;
+    case kind::PRODUCT:
+      computeJoinOrProductRelations(n[1]);
+      break;
+    default:
+      break;
+    }
+
+    if(d_membership_cache.find(getRepresentative(n[0])) == d_membership_cache.end() ||
+       d_membership_cache.find(getRepresentative(n[1])) == d_membership_cache.end())
+          return;
+    composeRelations(n);
+  }
+
+  void TheorySetsRels::computeTransposeRelations(Node n) {
+    switch(n[0].getKind()) {
+    case kind::JOIN:
+      computeJoinOrProductRelations(n[0]);
+      break;
+    case kind::TRANSPOSE:
+      computeTransposeRelations(n[0]);
+      break;
+    case kind::PRODUCT:
+      computeJoinOrProductRelations(n[0]);
+      break;
+    default:
+      break;
+    }
+
+    if(d_membership_cache.find(getRepresentative(n[0])) == d_membership_cache.end())
+      return;
+    std::vector<Node> rev_tuples;
+    std::vector<Node> rev_exps;
+    Node n_rep = getRepresentative(n);
+    Node n0_rep = getRepresentative(n[0]);
+
+    if(d_membership_cache.find(n_rep) != d_membership_cache.end()) {
+      rev_tuples = d_membership_cache[n_rep];
+      rev_exps = d_membership_exp_cache[n_rep];
+    }
+    std::vector<Node> tuples = d_membership_cache[n0_rep];
+    std::vector<Node> exps = d_membership_exp_cache[n0_rep];
+    for(unsigned int i = 0; i < tuples.size(); i++) {
+      // Todo: Need to consider duplicates
+      Node reason = exps[i];
+      Node rev_tup = reverseTuple(tuples[i]);
+      if(exps[i][1] != n0_rep)
+        reason = AND(reason, EQUAL(exps[i][1], n0_rep));
+      rev_tuples.push_back(rev_tup);
+      rev_exps.push_back(Rewriter::rewrite(reason));
+      sendInfer(MEMBER(rev_tup, n_rep), Rewriter::rewrite(reason), "transpose-rule");
+//      if(std::find(rev_tuples.begin(), rev_tuples.end(), reverseTuple(tuples[i])) == rev_tuples.end()) {
+//
+//      }
+    }
+    d_membership_cache[n_rep] = rev_tuples;
+    d_membership_exp_cache[n_rep] = rev_exps;
+  }
+
+  // Join all explicitly specified tuples in r1, r2
+  // e.g. If (a, b) in X and (b, c) in Y, (a, c) in (X JOIN Y)
+  void TheorySetsRels::composeRelations(Node n) {
+    Node r1 = n[0];
+    Node r2 = n[1];
+    Node r1_rep = getRepresentative(r1);
+    Node r2_rep = getRepresentative(r2);
+    Trace("rels-debug") << "[sets-rels] start joining tuples in "
+                       << r1 << " and " << r2
+                       << "\n r1_rep: " << r1_rep
+                       << "\n r2_rep: " << r2_rep << std::endl;
+
+    if(d_membership_cache.find(r1_rep) == d_membership_cache.end() ||
+       d_membership_cache.find(r2_rep) == d_membership_cache.end())
+    return;
+
+    TypeNode tn = n.getType().getSetElementType();
     Datatype dt = tn.getDatatype();
+    std::vector<Node> new_tups;
+    std::vector<Node> new_exps;
+    std::vector<Node> r1_elements = d_membership_cache[r1_rep];
+    std::vector<Node> r2_elements = d_membership_cache[r2_rep];
+    std::vector<Node> r1_exps = d_membership_exp_cache[r1_rep];
+    std::vector<Node> r2_exps = d_membership_exp_cache[r2_rep];
+    Node new_rel = n.getKind() == kind::JOIN ? NodeManager::currentNM()->mkNode(kind::JOIN, r1_rep, r2_rep)
+                                             : NodeManager::currentNM()->mkNode(kind::PRODUCT, r1_rep, r2_rep);
     unsigned int t1_len = r1_elements.front().getType().getTupleLength();
     unsigned int t2_len = r2_elements.front().getType().getTupleLength();
 
     for(unsigned int i = 0; i < r1_elements.size(); i++) {
       for(unsigned int j = 0; j < r2_elements.size(); j++) {
-        if(r1_elements[i][t1_len-1] == r2_elements[j][0]) {
-          std::vector<Node> joinedTuple;
-          joinedTuple.push_back(Node::fromExpr(dt[0].getConstructor()));
-          for(unsigned int k = 0; k < t1_len - 1; ++k) {
+        std::vector<Node> joinedTuple;
+        joinedTuple.push_back(Node::fromExpr(dt[0].getConstructor()));
+        Debug("rels-debug") << "areEqual(r1_elements[i][t1_len-1], r2_elements[j][0]):\n"
+                            << "   r1_elements[i][t1_len-1] = " << r1_elements[i][t1_len-1]
+                            << "   r2_elements[j][0]) = " << r2_elements[j][0]
+                            << "   are equal? " << areEqual(r1_elements[i][t1_len-1], r2_elements[j][0]) << std::endl;
+        if((areEqual(r1_elements[i][t1_len-1], r2_elements[j][0]) && n.getKind() == kind::JOIN) ||
+            n.getKind() == kind::PRODUCT) {
+          unsigned int k = 0;
+          unsigned int l = 1;
+          for(; k < t1_len - 1; ++k) {
             joinedTuple.push_back(r1_elements[i][k]);
           }
-          for(unsigned int l = 1; l < t2_len; ++l) {
+          if(kind::PRODUCT == n.getKind()) {
+            joinedTuple.push_back(r1_elements[i][k]);
+            joinedTuple.push_back(r1_elements[j][0]);
+          }
+          for(; l < t2_len; ++l) {
             joinedTuple.push_back(r2_elements[j][l]);
           }
           Node fact = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, joinedTuple);
-          fact = NodeManager::currentNM()->mkNode(kind::MEMBER, fact, NodeManager::currentNM()->mkNode(kind::JOIN, r1, r2));
-          Node reason = NodeManager::currentNM()->mkNode(kind::AND,
-                                                         NodeManager::currentNM()->mkNode(kind::MEMBER, r1_elements[i], r1),
-                                                         NodeManager::currentNM()->mkNode(kind::MEMBER, r2_elements[j], r2));
-          Debug("rels-join") << "join tuples: " << r1_elements[i]
+          new_tups.push_back(fact);
+          fact = MEMBER(fact, new_rel);
+          std::vector<Node> reasons;
+          reasons.push_back(r1_exps[i]);
+          reasons.push_back(r2_exps[j]);
+
+          //Todo: think about how to deal with shared terms(?)
+          if(n.getKind() == kind::JOIN)
+            reasons.push_back(EQUAL(r1_elements[i][t1_len-1], r2_elements[j][0]));
+
+          if(r1 != r1_rep) {
+            reasons.push_back(EQUAL(r1, r1_rep));
+          }
+          if(r2 != r2_rep) {
+            reasons.push_back(EQUAL(r2, r2_rep));
+          }
+          Node reason = theory::Rewriter::rewrite(NodeManager::currentNM()->mkNode(kind::AND, reasons));
+          new_exps.push_back(reason);
+          Trace("rels-debug") << "[sets-rels] compose tuples: " << r1_elements[i]
                              << " and " << r2_elements[j]
-                             << "\nnew fact: " << fact
-                             << "\nreason: " << reason<< std::endl;
-          d_pending_facts[fact] = reason;
+                             << "\n new fact: " << fact
+                             << "\n reason: " << reason<< std::endl;
+          if(kind::JOIN == n.getKind())
+            sendInfer(fact, reason, "join-compose");
+          else if(kind::PRODUCT == n.getKind())
+            sendInfer(fact, reason, "product-compose");
         }
       }
     }
-  }
-
-
-  void TheorySetsRels::sendLemma(TNode fact, TNode reason, bool polarity) {
 
+    Node new_rel_rep = getRepresentative( new_rel );
+    if(d_membership_cache.find( new_rel_rep ) != d_membership_cache.end()) {
+      std::vector<Node> tups = d_membership_cache[new_rel_rep];
+      std::vector<Node> exps = d_membership_exp_cache[new_rel_rep];
+      // Todo: Need to take care of duplicate tuples
+      tups.insert( tups.end(), new_tups.begin(), new_tups.end() );
+      exps.insert( exps.end(), new_exps.begin(), new_exps.end() );
+    } else {
+      d_membership_cache[new_rel_rep] = new_tups;
+      d_membership_exp_cache[new_rel_rep] = new_exps;
+    }
+    Trace("rels-debug") << "[sets-rels] Done with joining tuples !" << std::endl;
   }
 
-  void TheorySetsRels::doPendingFacts() {
-    std::map<Node, Node>::iterator map_it = d_pending_facts.begin();
-    while( !(*d_conflict) && map_it != d_pending_facts.end()) {
-
-      Node fact = map_it->first;
-      Node exp = d_pending_facts[ fact ];
-      Debug("rels") << "sending out pending fact: " << fact
-                    << "  reason: " << exp
-                    << std::endl;
-      if(fact.getKind() == kind::AND) {
-        for(size_t j=0; j<fact.getNumChildren(); j++) {
-          bool polarity = fact[j].getKind() != kind::NOT;
-          TNode atom = polarity ? fact[j] : fact[j][0];
-          assertMembership(atom, exp, polarity);
-        }
-      } else {
-        bool polarity = fact.getKind() != kind::NOT;
-        TNode atom = polarity ? fact : fact[0];
-        assertMembership(atom, exp, polarity);
+  void TheorySetsRels::doPendingLemmas() {
+    if( !(*d_conflict) && !d_lemma_cache.empty() ){
+      for( unsigned i=0; i < d_lemma_cache.size(); i++ ){
+        Trace("rels-debug") << "[sets-rels] Process pending lemma : " << d_lemma_cache[i] << std::endl;
+        d_sets.d_out->lemma( d_lemma_cache[i] );
+      }
+      for( std::map<Node, Node>::iterator child_it = d_pending_facts.begin();
+           child_it != d_pending_facts.end(); child_it++ ) {
+        Trace("rels-debug") << "[sets-rels] Process pending fact as lemma : " << child_it->first << std::endl;
+        d_sets.d_out->lemma(child_it->first);
       }
-      map_it++;
     }
     d_pending_facts.clear();
+    d_lemma_cache.clear();
   }
 
+  void TheorySetsRels::sendSplit(Node a, Node b, const char * c) {
+    Node eq = a.eqNode( b );
+    Node neq = NOT( eq );
+    Node lemma_or = OR( eq, neq );
+    Trace("rels-lemma") << "[sets-rels] Lemma " << c << " SPLIT : " << lemma_or << std::endl;
+    d_lemma_cache.push_back( lemma_or );
+  }
 
+  void TheorySetsRels::sendLemma(Node fact, Node reason, bool polarity) {
 
-  Node TheorySetsRels::reverseTuple(TNode tuple) {
-    Assert(tuple.getType().isTuple());
+  }
 
+  void TheorySetsRels::sendInfer( Node fact, Node exp, const char * c ) {
+    Trace("rels-lemma") << "[sets-rels] Infer " << fact << " from " << exp << " by " << c << std::endl;
+    d_pending_facts[fact] = exp;
+    d_infer.push_back( fact );
+    d_infer_exp.push_back( exp );
+  }
+
+  Node TheorySetsRels::reverseTuple( Node tuple ) {
+    Assert( tuple.getType().isTuple() );
     std::vector<Node> elements;
     std::vector<TypeNode> tuple_types = tuple.getType().getTupleTypes();
-    std::reverse(tuple_types.begin(), tuple_types.end());
-    TypeNode tn = NodeManager::currentNM()->mkTupleType(tuple_types);
+    std::reverse( tuple_types.begin(), tuple_types.end() );
+    TypeNode tn = NodeManager::currentNM()->mkTupleType( tuple_types );
     Datatype dt = tn.getDatatype();
 
-    elements.push_back(Node::fromExpr(dt[0].getConstructor()));
+    elements.push_back( Node::fromExpr(dt[0].getConstructor() ) );
     for(Node::iterator child_it = tuple.end()-1;
-              child_it != tuple.begin()-1; --child_it) {
-      elements.push_back(*child_it);
+        child_it != tuple.begin()-1; --child_it) {
+      elements.push_back( *child_it );
+    }
+    return NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, elements );
+  }
+
+  void TheorySetsRels::assertMembership( Node fact, Node reason, bool polarity ) {
+    d_eqEngine->assertPredicate( fact, polarity, reason );
+  }
+
+  Node TheorySetsRels::getRepresentative( Node t ) {
+    if( d_eqEngine->hasTerm( t ) ){
+      return d_eqEngine->getRepresentative( t );
+    }else{
+      return t;
     }
-    return NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, elements);
   }
 
-  void TheorySetsRels::assertMembership(TNode fact, TNode reason, bool polarity) {
-    Debug("rels") << "fact: " << fact
-                  << "\npolarity : " << polarity
-                  << "\nreason: " << reason << std::endl;
-    d_eqEngine->assertPredicate(fact, polarity, reason);
+  bool TheorySetsRels::hasTerm( Node a ){
+    return d_eqEngine->hasTerm( a );
+  }
+
+  bool TheorySetsRels::areEqual( Node a, Node b ){
+    if( hasTerm( a ) && hasTerm( b ) ){
+//      Trace("rels-debug") << "has a and b " << a << " " << b << " are equal? "<<  d_eqEngine->areEqual( a, b ) << std::endl;
+      return d_eqEngine->areEqual( a, b );
+    }else if( a.isConst() && b.isConst() ){
+      return a == b;
+    }else {
+//      Trace("rels-debug") << "to split a and b " << a << " " << b << std::endl;
+      addSharedTerm(a);
+      addSharedTerm(b);
+      sendSplit(a, b, "tuple-element-equality");
+      return false;
+    }
+  }
+
+  void TheorySetsRels::addSharedTerm(TNode n) {
+    Trace("rels-debug") << "[sets-rels] Add a shared term:  " << n << std::endl;
+    d_sets.addSharedTerm(n);
+    d_eqEngine->addTriggerTerm(n, THEORY_SETS);
+  }
+
+  bool TheorySetsRels::exists( std::vector<Node>& v, Node n ){
+    return std::find(v.begin(), v.end(), n) != v.end();
   }
+
+  TheorySetsRels::TheorySetsRels(context::Context* c,
+                                 context::UserContext* u,
+                                 eq::EqualityEngine* eq,
+                                 context::CDO<bool>* conflict,
+                                 TheorySets& d_set):
+    d_sets(d_set),
+    d_trueNode(NodeManager::currentNM()->mkConst<bool>(true)),
+    d_falseNode(NodeManager::currentNM()->mkConst<bool>(false)),
+    d_infer(c),
+    d_infer_exp(c),
+    d_eqEngine(eq),
+    d_conflict(conflict)
+  {
+    d_eqEngine->addFunctionKind(kind::PRODUCT);
+    d_eqEngine->addFunctionKind(kind::JOIN);
+    d_eqEngine->addFunctionKind(kind::TRANSPOSE);
+    d_eqEngine->addFunctionKind(kind::TRANSCLOSURE);
+  }
+
+  TheorySetsRels::~TheorySetsRels() {}
+
+
 }
 }
 }
index 537fc2d430d6fd6addc237e761bfb5706731c5a9..4eb30ab12886f220519791aa703caeec187dff8e 100644 (file)
 #include "theory/theory.h"
 #include "theory/uf/equality_engine.h"
 #include "context/cdhashset.h"
+#include "context/cdchunk_list.h"
 
 namespace CVC4 {
 namespace theory {
 namespace sets {
 
+class TheorySets;
+
 class TheorySetsRels {
 
+  typedef context::CDChunkList<Node> NodeList;
+
 public:
   TheorySetsRels(context::Context* c,
                  context::UserContext* u,
                  eq::EqualityEngine*,
-                 context::CDO<bool>* );
+                 context::CDO<bool>*,
+                 TheorySets&);
 
   ~TheorySetsRels();
-
   void check(Theory::Effort);
 
+  void doPendingLemmas();
+
 private:
 
+  TheorySets& d_sets;
+
   /** True and false constant nodes */
   Node d_trueNode;
   Node d_falseNode;
 
   // Facts and lemmas to be sent to EE
-  std::map< Node, Node> d_pending_facts;
+  std::map< Node, Node > d_pending_facts;
   std::vector< Node > d_lemma_cache;
 
-  // Relation pairs to be joined
-//  std::map<TNode, TNode> d_rel_pairs;
-//  std::hash_set<TNode> d_rels;
+  /** inferences: maintained to ensure ref count for internally introduced nodes */
+  NodeList d_infer;
+  NodeList d_infer_exp;
+
+  std::map< Node, std::vector<Node> > d_membership_cache;
+  std::map< Node, std::vector<Node> > d_membership_exp_cache;
+  std::map< Node, std::map<kind::Kind_t, std::vector<Node> > > d_terms_cache;
 
   eq::EqualityEngine *d_eqEngine;
   context::CDO<bool> *d_conflict;
 
-  // save all the relational terms seen so far
-  context::CDHashSet <Node, NodeHashFunction> d_relsSaver;
-
-  void assertMembership(TNode fact, TNode reason, bool polarity);
-
-  void joinRelations(TNode, TNode, TypeNode);
-  void joinTuples(TNode, TNode, std::vector<Node>&, std::vector<Node>&, TypeNode tn);
-
-  Node reverseTuple(TNode);
-
-  void sendLemma(TNode fact, TNode reason, bool polarity);
-  void doPendingLemmas();
+  void check();
+  void collectRelationalInfo();
+  void assertMembership( Node fact, Node reason, bool polarity );
+  void composeProductRelations( Node );
+  void composeJoinRelations( Node );
+  void composeRelations( Node );
+  void applyTransposeRule( Node, Node, Node );
+  void applyJoinRule( Node, Node, Node );
+  void applyProductRule( Node, Node, Node );
+  void computeJoinOrProductRelations( Node );
+  void computeTransposeRelations( Node );
+  Node reverseTuple( Node );
+
+  void sendInfer( Node fact, Node exp, const char * c );
+  void sendLemma( Node fact, Node reason, bool polarity );
+  void sendSplit( Node a, Node b, const char * c );
   void doPendingFacts();
+  void addSharedTerm( TNode n );
+
+  // Helper functions
+  Node getRepresentative( Node t );
+  bool hasTerm( Node a );
+  bool areEqual( Node a, Node b );
+  bool exists( std::vector<Node>&, Node );
 
 };
 
index 635f9856a748f67cea5509d18d1c6aecb991a2da..dac554d4feb179af9a855b8d0c5519dc633f635d 100644 (file)
@@ -62,6 +62,27 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
       bool isMember = checkConstantMembership(node[0], S);
       return RewriteResponse(REWRITE_DONE, nm->mkConst(isMember));
     }
+    if(node[1].getKind() == kind::TRANSPOSE) {
+      // only work for node[0] is an actual tuple like (a, b), won't work for tuple variables
+      if(node[0].isVar())
+        return RewriteResponse(REWRITE_DONE, node);
+      std::vector<Node> elements;
+      std::vector<TypeNode> tuple_types = node[0].getType().getTupleTypes();
+      std::reverse(tuple_types.begin(), tuple_types.end());
+      TypeNode tn = NodeManager::currentNM()->mkTupleType(tuple_types);
+      Datatype dt = tn.getDatatype();
+      elements.push_back(Node::fromExpr(dt[0].getConstructor()));
+      for(Node::iterator child_it = node[0].end()-1;
+                child_it != node[0].begin()-1; --child_it) {
+        elements.push_back(*child_it);
+      }
+      Node new_node = NodeManager::currentNM()->mkNode(kind::MEMBER,
+                                                       NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, elements),
+                                                       node[1][0]);
+      if(node.getKind() == kind::NOT)
+        new_node = NodeManager::currentNM()->mkNode(kind::NOT, new_node);
+      return RewriteResponse(REWRITE_AGAIN, new_node);
+    }
     break;
   }//kind::MEMBER
 
@@ -176,6 +197,17 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
     break;
   }//kind::UNION
 
+  case kind::TRANSPOSE: {
+    if(node[0].getKind() != kind::TRANSPOSE) {
+      Trace("sets-postrewrite") << "Sets::postRewrite returning " << node << std::endl;
+      return RewriteResponse(REWRITE_DONE, node);
+    }
+    if(node[0].getKind() == kind::TRANSPOSE) {
+      return RewriteResponse(REWRITE_AGAIN, node[0][0]);
+    }
+    break;
+  }
+
   default:
     break;
   }//switch(node.getKind())
index a251218c63105461584392479838b7c296ab5413..406b8d312ab1dd4c2ec17bd9c0f5ac070131f83d 100644 (file)
@@ -1,3 +1,4 @@
+% EXPECT: unsat
 OPTION "logic" "ALL_SUPPORTED";
 IntPair: TYPE = [INT, INT];
 x : SET OF IntPair;
@@ -18,8 +19,6 @@ ASSERT (7, 5) IS_IN y;
 
 ASSERT z IS_IN x;
 ASSERT zt IS_IN y;
-%ASSERT a IS_IN (x JOIN y);
-%ASSERT NOT (v IS_IN (x JOIN y));
 ASSERT NOT (a IS_IN (x JOIN y));
 
 CHECKSAT;
index d06528fd279455e00029ee4e97b0d115f167a4d5..95c27edf0d14b5dde87543e915f8980006a4b338 100644 (file)
@@ -1,3 +1,4 @@
+% EXPECT: unsat
 OPTION "logic" "ALL_SUPPORTED";
 IntPair: TYPE = [INT, INT];
 x : SET OF IntPair;