- Implement constant rewriter for relational operators for model generation
authorPaulMeng <baolmeng@gmail.com>
Fri, 15 Apr 2016 03:01:32 +0000 (22:01 -0500)
committerPaulMeng <baolmeng@gmail.com>
Fri, 15 Apr 2016 03:01:32 +0000 (22:01 -0500)
- fixed a few bugs

src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_rels.cpp
src/theory/sets/theory_sets_rels.h
src/theory/sets/theory_sets_rewriter.cpp

index 10fc9f1953150674684775cc5969208b15467042..db4f4bf26d3943c3862aebf887e9d0603c6e0bf8 100644 (file)
@@ -672,6 +672,7 @@ bool TheorySetsPrivate::checkModel(const SettermElementsMap& settermElementsMap,
                       << std::endl;
 
   Assert(S.getType().isSet());
+  std::set<Node> temp_nodes;
 
   const Elements emptySetOfElements;
   const Elements& saved =
@@ -715,6 +716,74 @@ bool TheorySetsPrivate::checkModel(const SettermElementsMap& settermElementsMap,
       std::set_difference(left.begin(), left.end(), right.begin(), right.end(),
                           std::inserter(cur, cur.begin()) );
       break;
+    case kind::PRODUCT: {
+      std::set<Node> new_tuple_set;
+      Elements::const_iterator left_it = left.begin();
+      int left_len = (*left_it).getType().getTupleLength();
+      TypeNode tn = S.getType().getSetElementType();
+      while(left_it != left.end()) {
+        Trace("rels-debug") << "Sets::postRewrite processing left_it = " <<  *left_it << std::endl;
+        std::vector<Node> left_tuple;
+        left_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor()));
+        for(int i = 0; i < left_len; i++) {
+          left_tuple.push_back(TheorySetsRels::nthElementOfTuple(*left_it,i));
+        }
+        Elements::const_iterator right_it = right.begin();
+        int right_len = (*right_it).getType().getTupleLength();
+        while(right_it != right.end()) {
+          std::vector<Node> right_tuple;
+          for(int j = 0; j < right_len; j++) {
+            right_tuple.push_back(TheorySetsRels::nthElementOfTuple(*right_it,j));
+          }
+          std::vector<Node> new_tuple;
+          new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end());
+          new_tuple.insert(new_tuple.end(), right_tuple.begin(), right_tuple.end());
+          Node composed_tuple = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, new_tuple);
+          temp_nodes.insert(composed_tuple);
+          new_tuple_set.insert(composed_tuple);
+          right_it++;
+        }
+        left_it++;
+      }
+      cur.insert(new_tuple_set.begin(), new_tuple_set.end());
+      Trace("rels-debug") << " ***** Done with check model for product operator" << std::endl;
+      break;
+    }
+    case kind::JOIN: {
+      std::set<Node> new_tuple_set;
+      Elements::const_iterator left_it = left.begin();
+      int left_len = (*left_it).getType().getTupleLength();
+      TypeNode tn = S.getType().getSetElementType();
+      while(left_it != left.end()) {
+        std::vector<Node> left_tuple;
+        left_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor()));
+        for(int i = 0; i < left_len - 1; i++) {
+          left_tuple.push_back(TheorySetsRels::nthElementOfTuple(*left_it,i));
+        }
+        Elements::const_iterator right_it = right.begin();
+        int right_len = (*right_it).getType().getTupleLength();
+        while(right_it != right.end()) {
+          if(TheorySetsRels::nthElementOfTuple(*left_it,left_len-1) == TheorySetsRels::nthElementOfTuple(*right_it,0)) {
+            std::vector<Node> right_tuple;
+            for(int j = 1; j < right_len; j++) {
+              right_tuple.push_back(TheorySetsRels::nthElementOfTuple(*right_it,j));
+            }
+            std::vector<Node> new_tuple;
+            new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end());
+            new_tuple.insert(new_tuple.end(), right_tuple.begin(), right_tuple.end());
+            Node composed_tuple = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, new_tuple);
+            new_tuple_set.insert(composed_tuple);
+          }
+          right_it++;
+        }
+        left_it++;
+      }
+      cur.insert(new_tuple_set.begin(), new_tuple_set.end());
+      Trace("rels-debug") << " ***** Done with check model for JOIN operator" << std::endl;
+      break;
+    }
+    case kind::TRANSCLOSURE:
+      break;
     default:
       Unhandled();
     }
index 0e20b9bfa7165fa98928083b6edb03196fce9619..eae9a4e8fa18f5ec982452b4de1e00b99dd1f811 100644 (file)
@@ -236,10 +236,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     Node tc_r_rep = getRepresentative(tc_term[0]);
 
     // build the TC graph for tc_rep if it was not created before
-    if( d_tc_nodes.find(tc_rep) == d_tc_nodes.end() ) {
+    if( d_rel_nodes.find(tc_rep) == d_rel_nodes.end() ) {
       Trace("rels-debug") << "[sets-rels]  Start building the TC graph!" << std::endl;
       buildTCGraph(tc_r_rep, tc_rep, tc_term);
-      d_tc_nodes.insert(tc_rep);
+      d_rel_nodes.insert(tc_rep);
     }
     // insert atom[0] in the tc_graph if it is not in the graph already
     TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep);
@@ -316,6 +316,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
 
   void TheorySetsRels::applyProductRule(Node exp, Node product_term) {
     Trace("rels-debug") << "\n[sets-rels] *********** Applying PRODUCT rule  " << std::endl;
+    if(d_rel_nodes.find(product_term) == d_rel_nodes.end()) {
+      computeRels(product_term);
+      d_rel_nodes.insert(product_term);
+    }
     bool polarity = exp.getKind() != kind::NOT;
     Node atom = polarity ? exp : exp[0];
     Node r1_rep = getRepresentative(product_term[0]);
@@ -366,25 +370,22 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       reason_2 = Rewriter::rewrite(AND(reason_2, explain(EQUAL(t2, t2_rep))));
     }
     if(polarity) {
-      if(!hasMember(r1_rep, t1_rep)) {
-        addToMap(d_membership_db, r1_rep, t1_rep);
+      if(safeAddToMap(d_membership_db, r1_rep, t1_rep)) {
         addToMap(d_membership_exp_db, r1_rep, reason_1);
         sendInfer(fact_1, reason_1, "product-split");
       }
-      if(!hasMember(r2_rep, t2)) {
-        addToMap(d_membership_db, r2_rep, t2);
+      if(safeAddToMap(d_membership_db, r2_rep, t2_rep)) {
         addToMap(d_membership_exp_db, r2_rep, reason_2);
         sendInfer(fact_2, reason_2, "product-split");
       }
 
     } else {
-//      sendInfer(fact_1.negate(), reason_1, "product-split");
-//      sendInfer(fact_2.negate(), reason_2, "product-split");
+      sendInfer(fact_1.negate(), reason_1, "product-split");
+      sendInfer(fact_2.negate(), reason_2, "product-split");
 
       // ONLY need to explicitly compute joins if there are negative literals involving PRODUCT
       Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-COMPOSE rule on term: " << product_term
                           << " with explanation: " << exp << std::endl;
-      computeRels(product_term);
     }
   }
 
@@ -399,6 +400,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
    */
   void TheorySetsRels::applyJoinRule(Node exp, Node join_term) {
     Trace("rels-debug") << "\n[sets-rels] *********** Applying JOIN rule  " << std::endl;
+    if(d_rel_nodes.find(join_term) == d_rel_nodes.end()) {
+      computeRels(join_term);
+      d_rel_nodes.insert(join_term);
+    }
     bool polarity = exp.getKind() != kind::NOT;
     Node atom = polarity ? exp : exp[0];
     Node r1_rep = getRepresentative(join_term[0]);
@@ -472,7 +477,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       // ONLY need to explicitly compute joins if there are negative literals involving JOIN
       Trace("rels-debug") <<  "\n[sets-rels] Apply JOIN-COMPOSE rule on term: " << join_term
                           << " with explanation: " << exp << std::endl;
-      computeRels(join_term);
     }
   }
 
@@ -521,14 +525,18 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
          != d_terms_cache[tp_t0_rep].end()) {
         std::vector<Node> join_terms = d_terms_cache[tp_t0_rep][kind::JOIN];
         for(unsigned int i = 0; i < join_terms.size(); i++) {
-          computeRels(join_terms[i]);
+          if(d_rel_nodes.find(join_terms[i]) == d_rel_nodes.end()) {
+            computeRels(join_terms[i]);
+          }
         }
       }
       if(d_terms_cache[tp_t0_rep].find(kind::PRODUCT)
          != d_terms_cache[tp_t0_rep].end()) {
         std::vector<Node> product_terms = d_terms_cache[tp_t0_rep][kind::PRODUCT];
         for(unsigned int i = 0; i < product_terms.size(); i++) {
-          computeRels(product_terms[i]);
+          if(d_rel_nodes.find(product_terms[i]) == d_rel_nodes.end()) {
+            computeRels(product_terms[i]);
+          }
         }
       }
       fact = fact.negate();
@@ -792,7 +800,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
         d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, child_it->second, child_it->first));
       }
     }
-    d_tc_nodes.clear();
+    d_rel_nodes.clear();
     d_pending_facts.clear();
     d_membership_constraints_cache.clear();
     d_membership_tc_cache.clear();
@@ -930,12 +938,20 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   }
 
   bool TheorySetsRels::safeAddToMap(std::map< Node, std::vector<Node> >& map, Node rel_rep, Node member) {
-    if(map.find(rel_rep) == map.end()) {
+    std::map< Node, std::vector< Node > >::iterator mem_it = map.find(rel_rep);
+    if(mem_it == map.end()) {
       std::vector<Node> members;
       members.push_back(member);
       map[rel_rep] = members;
       return true;
-    } else if(std::find(map[rel_rep].begin(), map[rel_rep].end(), member) == map[rel_rep].end()) {
+    } else {
+      std::vector<Node>::iterator mems = mem_it->second.begin();
+      while(mems != mem_it->second.end()) {
+        if(areEqual(*mems, member)) {
+          return false;
+        }
+        mems++;
+      }
       map[rel_rep].push_back(member);
       return true;
     }
@@ -979,7 +995,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       if(tc_r_mems != d_membership_db.end()) {
         for(unsigned int i = 0; i < tc_r_mems->second.size(); i++) {
           if(areEqual(tc_r_mems->second[i], tuple)) {
-            Node exp = d_trueNode;
+            Node exp = MEMBER(tc_r_mems->second[i], tc_r_mems->first);
             if(tc_r_rep != tc_term[0]) {
               exp = explain(EQUAL(tc_r_rep, tc_term[0]));
             }
@@ -1006,7 +1022,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
       if(tc_t_mems != d_membership_db.end()) {
         for(unsigned int j = 0; j < tc_t_mems->second.size(); j++) {
           if(areEqual(tc_t_mems->second[j], tuple)) {
-            Node exp = d_trueNode;
+            Node exp = MEMBER(tc_t_mems->second[j], tc_t_mems->first);
             if(tc_rep != tc_terms[i]) {
               exp = AND(exp, explain(EQUAL(tc_rep, tc_terms[i])));
             }
@@ -1038,7 +1054,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
     return Node::null();
   }
 
-  inline Node TheorySetsRels::nthElementOfTuple( Node tuple, int n_th ) {
+  Node TheorySetsRels::nthElementOfTuple( Node tuple, int n_th ) {
     if(tuple.isConst() || (!tuple.isVar() && !tuple.isConst()))
       return tuple[n_th];
     Datatype dt = tuple.getType().getDatatype();
@@ -1067,17 +1083,19 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P
   }
 
   bool TheorySetsRels::holds(Node node) {
+    Trace("rels-check") << " [sets-rels] Check if node = " << node << " already holds " << std::endl;
     bool polarity = node.getKind() != kind::NOT;
     Node atom = polarity ? node : node[0];
     Node polarity_atom = polarity ? d_trueNode : d_falseNode;
-    if(d_eqEngine->hasTerm(node)) {
-      return areEqual(node, polarity_atom);
+    if(d_eqEngine->hasTerm(atom)) {
+      Trace("rels-check") << " [sets-rels] node = " << node << " is in the EE " << std::endl;
+      return areEqual(atom, polarity_atom);
     } else {
       Node atom_mod = NodeManager::currentNM()->mkNode(atom.getKind(),
                                                        getRepresentative(atom[0]),
                                                        getRepresentative(atom[1]));
       if(d_eqEngine->hasTerm(atom_mod)) {
-        return areEqual(node, polarity_atom);
+        return areEqual(atom_mod, polarity_atom);
       }
     }
     return false;
index 0876cc5b3016bbde83e9742981eeb0dac2e519c7..b5e36603b0a7fb2288e2f20670a1fb390d388dfd 100644 (file)
@@ -45,6 +45,7 @@ class TheorySetsRels {
 
   typedef context::CDChunkList<Node> NodeList;
   typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
+  typedef context::CDHashMap<Node, bool, NodeHashFunction> NodeBoolMap;
 
 public:
   TheorySetsRels(context::Context* c,
@@ -100,7 +101,7 @@ private:
   NodeSet d_lemma;
   NodeSet d_shared_terms;
 
-  std::hash_set< Node, NodeHashFunction > d_tc_nodes;
+  std::hash_set< Node, NodeHashFunction > d_rel_nodes;
   std::map< Node, std::vector<Node> > d_tuple_reps;
   std::map< Node, TupleTrie > d_membership_trie;
   std::hash_set< Node, NodeHashFunction > d_symbolic_tuples;
@@ -141,7 +142,6 @@ private:
   void buildTCGraph( Node, Node, Node );
   void computeRels( Node );
   void computeTransposeRelations( Node );
-  Node reverseTuple( Node );
   void finalizeTCInfer();
   void inferTC( Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >& );
   void inferTC( Node, Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >&,
@@ -159,7 +159,6 @@ private:
   bool checkCycles( Node );
 
   // Helper functions
-  inline Node nthElementOfTuple( Node, int);
   inline Node getReason(Node tc_rep, Node tc_term, Node tc_r_rep, Node tc_r);
   inline Node constructPair(Node tc_rep, Node a, Node b);
   Node findMemExp(Node r, Node tuple);
@@ -178,6 +177,10 @@ private:
   bool isRel( Node n ) {return n.getType().isSet() && n.getType().getSetElementType().isTuple();}
   Node mkAnd( std::vector< TNode >& assumptions );
 
+public:
+  static Node reverseTuple( Node );
+  static Node nthElementOfTuple( Node, int);
+
 };
 
 
index 8d76748bb6077497989f07956f965ea92e14d506..61ec07e9da7edfc8cd2e9eb74174ecf5f42cefc8 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "theory/sets/theory_sets_rewriter.h"
 #include "theory/sets/normal_form.h"
+#include "theory/sets/theory_sets_rels.h"
 
 namespace CVC4 {
 namespace theory {
@@ -62,37 +63,6 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
       bool isMember = checkConstantMembership(node[0], S);
       return RewriteResponse(REWRITE_DONE, nm->mkConst(isMember));
     }
-//    if(node[1].getKind() == kind::TRANSPOSE) {
-//      // only work for node[0] is an actual tuple like (a, b), won't work for tuple variables
-//      if(node[0].getType().isSet() && !node[0].getType().getSetElementType().isTuple()) {
-//        Node atom = node;
-//        bool polarity = node.getKind() != kind::NOT;
-//        if( !polarity )
-//          atom = atom[0];
-//        Node new_node = NodeManager::currentNM()->mkNode(kind::MEMBER, atom[0], atom[1][0]);
-//        if(!polarity)
-//          new_node = new_node.negate();
-//        return RewriteResponse(REWRITE_AGAIN, new_node);
-//      }
-//      if(node[0].isVar())
-//        return RewriteResponse(REWRITE_DONE, node);
-//      std::vector<Node> elements;
-//      std::vector<TypeNode> tuple_types = node[0].getType().getTupleTypes();
-//      std::reverse(tuple_types.begin(), tuple_types.end());
-//      TypeNode tn = NodeManager::currentNM()->mkTupleType(tuple_types);
-//      Datatype dt = tn.getDatatype();
-//      elements.push_back(Node::fromExpr(dt[0].getConstructor()));
-//      for(Node::iterator child_it = node[0].end()-1;
-//                child_it != node[0].begin()-1; --child_it) {
-//        elements.push_back(*child_it);
-//      }
-//      Node new_node = NodeManager::currentNM()->mkNode(kind::MEMBER,
-//                                                       NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, elements),
-//                                                       node[1][0]);
-//      if(node.getKind() == kind::NOT)
-//        new_node = NodeManager::currentNM()->mkNode(kind::NOT, new_node);
-//      return RewriteResponse(REWRITE_AGAIN, new_node);
-//    }
     break;
   }//kind::MEMBER
 
@@ -208,24 +178,135 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) {
   }//kind::UNION
 
   case kind::TRANSPOSE: {
+    if(node[0].getKind() == kind::TRANSPOSE) {
+      return RewriteResponse(REWRITE_AGAIN, node[0][0]);
+    }
+
+    if(node[0].getKind() == kind::EMPTYSET) {
+      return RewriteResponse(REWRITE_DONE, nm->mkConst(EmptySet(nm->toType(node.getType()))));
+    } else if(node[0].isConst()) {
+      std::set<Node> new_tuple_set;
+      std::set<Node> tuple_set = NormalForm::getElementsFromNormalConstant(node[0]);
+      std::set<Node>::iterator tuple_it = tuple_set.begin();
+
+      while(tuple_it != tuple_set.end()) {
+        new_tuple_set.insert(TheorySetsRels::reverseTuple(*tuple_it));
+        tuple_it++;
+      }
+      Node new_node = NormalForm::elementsToSet(new_tuple_set, node.getType());
+      Assert(new_node.isConst());
+      Trace("sets-postrewrite") << "Sets::postRewrite returning " << new_node << std::endl;
+      return RewriteResponse(REWRITE_DONE, new_node);
+
+    }
     if(node[0].getKind() != kind::TRANSPOSE) {
       Trace("sets-postrewrite") << "Sets::postRewrite returning " << node << std::endl;
       return RewriteResponse(REWRITE_DONE, node);
     }
-    if(node[0].getKind() == kind::TRANSPOSE) {
-      return RewriteResponse(REWRITE_AGAIN, node[0][0]);
+    break;
+  }
+
+  case kind::PRODUCT: {
+    Trace("sets-rels-postrewrite") << "Sets::postRewrite processing " <<  node << std::endl;
+    if( node[0].getKind() == kind::EMPTYSET ||
+        node[1].getKind() == kind::EMPTYSET) {
+      return RewriteResponse(REWRITE_DONE, nm->mkConst(EmptySet(nm->toType(node.getType()))));
+    } else if( node[0].isConst() && node[1].isConst() ) {
+      Trace("sets-rels-postrewrite") << "Sets::postRewrite processing **** " <<  node << std::endl;
+      std::set<Node> new_tuple_set;
+      std::set<Node> left = NormalForm::getElementsFromNormalConstant(node[0]);
+      std::set<Node> right = NormalForm::getElementsFromNormalConstant(node[1]);
+      std::set<Node>::iterator left_it = left.begin();
+      int left_len = (*left_it).getType().getTupleLength();
+      TypeNode tn = node.getType().getSetElementType();
+      while(left_it != left.end()) {
+        Trace("rels-debug") << "Sets::postRewrite processing left_it = " <<  *left_it << std::endl;
+        std::vector<Node> left_tuple;
+        left_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor()));
+        for(int i = 0; i < left_len; i++) {
+          left_tuple.push_back(TheorySetsRels::nthElementOfTuple(*left_it,i));
+        }
+        std::set<Node>::iterator right_it = right.begin();
+        int right_len = (*right_it).getType().getTupleLength();
+        while(right_it != right.end()) {
+          Trace("rels-debug") << "Sets::postRewrite processing left_it = " <<  *right_it << std::endl;
+          std::vector<Node> right_tuple;
+          for(int j = 0; j < right_len; j++) {
+            right_tuple.push_back(TheorySetsRels::nthElementOfTuple(*right_it,j));
+          }
+          std::vector<Node> new_tuple;
+          new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end());
+          new_tuple.insert(new_tuple.end(), right_tuple.begin(), right_tuple.end());
+          Node composed_tuple = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, new_tuple);
+          new_tuple_set.insert(composed_tuple);
+          right_it++;
+        }
+        left_it++;
+      }
+      Node new_node = NormalForm::elementsToSet(new_tuple_set, node.getType());
+      Assert(new_node.isConst());
+      Trace("sets-postrewrite") << "Sets::postRewrite returning " << new_node << std::endl;
+      return RewriteResponse(REWRITE_DONE, new_node);
+    }
+    break;
+  }
+
+  case kind::JOIN: {
+    if( node[0].getKind() == kind::EMPTYSET ||
+        node[1].getKind() == kind::EMPTYSET) {
+      return RewriteResponse(REWRITE_DONE, nm->mkConst(EmptySet(nm->toType(node.getType()))));
+    } else if( node[0].isConst() && node[1].isConst() ) {
+      Trace("sets-rels-postrewrite") << "Sets::postRewrite processing " <<  node << std::endl;
+      std::set<Node> new_tuple_set;
+      std::set<Node> left = NormalForm::getElementsFromNormalConstant(node[0]);
+      std::set<Node> right = NormalForm::getElementsFromNormalConstant(node[1]);
+      std::set<Node>::iterator left_it = left.begin();
+      int left_len = (*left_it).getType().getTupleLength();
+      TypeNode tn = node.getType().getSetElementType();
+      while(left_it != left.end()) {
+        std::vector<Node> left_tuple;
+        left_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor()));
+        for(int i = 0; i < left_len - 1; i++) {
+          left_tuple.push_back(TheorySetsRels::nthElementOfTuple(*left_it,i));
+        }
+        std::set<Node>::iterator right_it = right.begin();
+        int right_len = (*right_it).getType().getTupleLength();
+        while(right_it != right.end()) {
+          if(TheorySetsRels::nthElementOfTuple(*left_it,left_len-1) == TheorySetsRels::nthElementOfTuple(*right_it,0)) {
+            std::vector<Node> right_tuple;
+            for(int j = 1; j < right_len; j++) {
+              right_tuple.push_back(TheorySetsRels::nthElementOfTuple(*right_it,j));
+            }
+            std::vector<Node> new_tuple;
+            new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end());
+            new_tuple.insert(new_tuple.end(), right_tuple.begin(), right_tuple.end());
+            Node composed_tuple = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, new_tuple);
+            new_tuple_set.insert(composed_tuple);
+          }
+          right_it++;
+        }
+        left_it++;
+      }
+      Node new_node = NormalForm::elementsToSet(new_tuple_set, node.getType());
+      Assert(new_node.isConst());
+      Trace("sets-postrewrite") << "Sets::postRewrite returning " << new_node << std::endl;
+      return RewriteResponse(REWRITE_DONE, new_node);
     }
+
     break;
   }
 
   case kind::TRANSCLOSURE: {
-    if(node[0].getKind() != kind::TRANSCLOSURE) {
+    if(node[0].getKind() == kind::EMPTYSET) {
+      return RewriteResponse(REWRITE_DONE, nm->mkConst(EmptySet(nm->toType(node.getType()))));
+    } else if (node[0].isConst()) {
+
+    } else if(node[0].getKind() == kind::TRANSCLOSURE) {
+      return RewriteResponse(REWRITE_AGAIN, node[0]);
+    } else if(node[0].getKind() != kind::TRANSCLOSURE) {
       Trace("sets-postrewrite") << "Sets::postRewrite returning " << node << std::endl;
       return RewriteResponse(REWRITE_DONE, node);
     }
-    if(node[0].getKind() == kind::TRANSCLOSURE) {
-      return RewriteResponse(REWRITE_AGAIN, node[0]);
-    }
     break;
   }