From 4a17519a49f49633fa0145a55b1b45346f2b86fc Mon Sep 17 00:00:00 2001 From: PaulMeng Date: Thu, 14 Apr 2016 22:01:32 -0500 Subject: [PATCH] - Implement constant rewriter for relational operators for model generation - fixed a few bugs --- src/theory/sets/theory_sets_private.cpp | 69 ++++++++++ src/theory/sets/theory_sets_rels.cpp | 60 ++++++--- src/theory/sets/theory_sets_rels.h | 9 +- src/theory/sets/theory_sets_rewriter.cpp | 155 +++++++++++++++++------ 4 files changed, 232 insertions(+), 61 deletions(-) diff --git a/src/theory/sets/theory_sets_private.cpp b/src/theory/sets/theory_sets_private.cpp index 10fc9f195..db4f4bf26 100644 --- a/src/theory/sets/theory_sets_private.cpp +++ b/src/theory/sets/theory_sets_private.cpp @@ -672,6 +672,7 @@ bool TheorySetsPrivate::checkModel(const SettermElementsMap& settermElementsMap, << std::endl; Assert(S.getType().isSet()); + std::set temp_nodes; const Elements emptySetOfElements; const Elements& saved = @@ -715,6 +716,74 @@ bool TheorySetsPrivate::checkModel(const SettermElementsMap& settermElementsMap, std::set_difference(left.begin(), left.end(), right.begin(), right.end(), std::inserter(cur, cur.begin()) ); break; + case kind::PRODUCT: { + std::set new_tuple_set; + Elements::const_iterator left_it = left.begin(); + int left_len = (*left_it).getType().getTupleLength(); + TypeNode tn = S.getType().getSetElementType(); + while(left_it != left.end()) { + Trace("rels-debug") << "Sets::postRewrite processing left_it = " << *left_it << std::endl; + std::vector left_tuple; + left_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor())); + for(int i = 0; i < left_len; i++) { + left_tuple.push_back(TheorySetsRels::nthElementOfTuple(*left_it,i)); + } + Elements::const_iterator right_it = right.begin(); + int right_len = (*right_it).getType().getTupleLength(); + while(right_it != right.end()) { + std::vector right_tuple; + for(int j = 0; j < right_len; j++) { + right_tuple.push_back(TheorySetsRels::nthElementOfTuple(*right_it,j)); + } + std::vector new_tuple; + new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end()); + new_tuple.insert(new_tuple.end(), right_tuple.begin(), right_tuple.end()); + Node composed_tuple = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, new_tuple); + temp_nodes.insert(composed_tuple); + new_tuple_set.insert(composed_tuple); + right_it++; + } + left_it++; + } + cur.insert(new_tuple_set.begin(), new_tuple_set.end()); + Trace("rels-debug") << " ***** Done with check model for product operator" << std::endl; + break; + } + case kind::JOIN: { + std::set new_tuple_set; + Elements::const_iterator left_it = left.begin(); + int left_len = (*left_it).getType().getTupleLength(); + TypeNode tn = S.getType().getSetElementType(); + while(left_it != left.end()) { + std::vector left_tuple; + left_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor())); + for(int i = 0; i < left_len - 1; i++) { + left_tuple.push_back(TheorySetsRels::nthElementOfTuple(*left_it,i)); + } + Elements::const_iterator right_it = right.begin(); + int right_len = (*right_it).getType().getTupleLength(); + while(right_it != right.end()) { + if(TheorySetsRels::nthElementOfTuple(*left_it,left_len-1) == TheorySetsRels::nthElementOfTuple(*right_it,0)) { + std::vector right_tuple; + for(int j = 1; j < right_len; j++) { + right_tuple.push_back(TheorySetsRels::nthElementOfTuple(*right_it,j)); + } + std::vector new_tuple; + new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end()); + new_tuple.insert(new_tuple.end(), right_tuple.begin(), right_tuple.end()); + Node composed_tuple = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, new_tuple); + new_tuple_set.insert(composed_tuple); + } + right_it++; + } + left_it++; + } + cur.insert(new_tuple_set.begin(), new_tuple_set.end()); + Trace("rels-debug") << " ***** Done with check model for JOIN operator" << std::endl; + break; + } + case kind::TRANSCLOSURE: + break; default: Unhandled(); } diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index 0e20b9bfa..eae9a4e8f 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -236,10 +236,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Node tc_r_rep = getRepresentative(tc_term[0]); // build the TC graph for tc_rep if it was not created before - if( d_tc_nodes.find(tc_rep) == d_tc_nodes.end() ) { + if( d_rel_nodes.find(tc_rep) == d_rel_nodes.end() ) { Trace("rels-debug") << "[sets-rels] Start building the TC graph!" << std::endl; buildTCGraph(tc_r_rep, tc_rep, tc_term); - d_tc_nodes.insert(tc_rep); + d_rel_nodes.insert(tc_rep); } // insert atom[0] in the tc_graph if it is not in the graph already TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep); @@ -316,6 +316,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::applyProductRule(Node exp, Node product_term) { Trace("rels-debug") << "\n[sets-rels] *********** Applying PRODUCT rule " << std::endl; + if(d_rel_nodes.find(product_term) == d_rel_nodes.end()) { + computeRels(product_term); + d_rel_nodes.insert(product_term); + } bool polarity = exp.getKind() != kind::NOT; Node atom = polarity ? exp : exp[0]; Node r1_rep = getRepresentative(product_term[0]); @@ -366,25 +370,22 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P reason_2 = Rewriter::rewrite(AND(reason_2, explain(EQUAL(t2, t2_rep)))); } if(polarity) { - if(!hasMember(r1_rep, t1_rep)) { - addToMap(d_membership_db, r1_rep, t1_rep); + if(safeAddToMap(d_membership_db, r1_rep, t1_rep)) { addToMap(d_membership_exp_db, r1_rep, reason_1); sendInfer(fact_1, reason_1, "product-split"); } - if(!hasMember(r2_rep, t2)) { - addToMap(d_membership_db, r2_rep, t2); + if(safeAddToMap(d_membership_db, r2_rep, t2_rep)) { addToMap(d_membership_exp_db, r2_rep, reason_2); sendInfer(fact_2, reason_2, "product-split"); } } else { -// sendInfer(fact_1.negate(), reason_1, "product-split"); -// sendInfer(fact_2.negate(), reason_2, "product-split"); + sendInfer(fact_1.negate(), reason_1, "product-split"); + sendInfer(fact_2.negate(), reason_2, "product-split"); // ONLY need to explicitly compute joins if there are negative literals involving PRODUCT Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-COMPOSE rule on term: " << product_term << " with explanation: " << exp << std::endl; - computeRels(product_term); } } @@ -399,6 +400,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P */ void TheorySetsRels::applyJoinRule(Node exp, Node join_term) { Trace("rels-debug") << "\n[sets-rels] *********** Applying JOIN rule " << std::endl; + if(d_rel_nodes.find(join_term) == d_rel_nodes.end()) { + computeRels(join_term); + d_rel_nodes.insert(join_term); + } bool polarity = exp.getKind() != kind::NOT; Node atom = polarity ? exp : exp[0]; Node r1_rep = getRepresentative(join_term[0]); @@ -472,7 +477,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P // ONLY need to explicitly compute joins if there are negative literals involving JOIN Trace("rels-debug") << "\n[sets-rels] Apply JOIN-COMPOSE rule on term: " << join_term << " with explanation: " << exp << std::endl; - computeRels(join_term); } } @@ -521,14 +525,18 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P != d_terms_cache[tp_t0_rep].end()) { std::vector join_terms = d_terms_cache[tp_t0_rep][kind::JOIN]; for(unsigned int i = 0; i < join_terms.size(); i++) { - computeRels(join_terms[i]); + if(d_rel_nodes.find(join_terms[i]) == d_rel_nodes.end()) { + computeRels(join_terms[i]); + } } } if(d_terms_cache[tp_t0_rep].find(kind::PRODUCT) != d_terms_cache[tp_t0_rep].end()) { std::vector product_terms = d_terms_cache[tp_t0_rep][kind::PRODUCT]; for(unsigned int i = 0; i < product_terms.size(); i++) { - computeRels(product_terms[i]); + if(d_rel_nodes.find(product_terms[i]) == d_rel_nodes.end()) { + computeRels(product_terms[i]); + } } } fact = fact.negate(); @@ -792,7 +800,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, child_it->second, child_it->first)); } } - d_tc_nodes.clear(); + d_rel_nodes.clear(); d_pending_facts.clear(); d_membership_constraints_cache.clear(); d_membership_tc_cache.clear(); @@ -930,12 +938,20 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } bool TheorySetsRels::safeAddToMap(std::map< Node, std::vector >& map, Node rel_rep, Node member) { - if(map.find(rel_rep) == map.end()) { + std::map< Node, std::vector< Node > >::iterator mem_it = map.find(rel_rep); + if(mem_it == map.end()) { std::vector members; members.push_back(member); map[rel_rep] = members; return true; - } else if(std::find(map[rel_rep].begin(), map[rel_rep].end(), member) == map[rel_rep].end()) { + } else { + std::vector::iterator mems = mem_it->second.begin(); + while(mems != mem_it->second.end()) { + if(areEqual(*mems, member)) { + return false; + } + mems++; + } map[rel_rep].push_back(member); return true; } @@ -979,7 +995,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(tc_r_mems != d_membership_db.end()) { for(unsigned int i = 0; i < tc_r_mems->second.size(); i++) { if(areEqual(tc_r_mems->second[i], tuple)) { - Node exp = d_trueNode; + Node exp = MEMBER(tc_r_mems->second[i], tc_r_mems->first); if(tc_r_rep != tc_term[0]) { exp = explain(EQUAL(tc_r_rep, tc_term[0])); } @@ -1006,7 +1022,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(tc_t_mems != d_membership_db.end()) { for(unsigned int j = 0; j < tc_t_mems->second.size(); j++) { if(areEqual(tc_t_mems->second[j], tuple)) { - Node exp = d_trueNode; + Node exp = MEMBER(tc_t_mems->second[j], tc_t_mems->first); if(tc_rep != tc_terms[i]) { exp = AND(exp, explain(EQUAL(tc_rep, tc_terms[i]))); } @@ -1038,7 +1054,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P return Node::null(); } - inline Node TheorySetsRels::nthElementOfTuple( Node tuple, int n_th ) { + Node TheorySetsRels::nthElementOfTuple( Node tuple, int n_th ) { if(tuple.isConst() || (!tuple.isVar() && !tuple.isConst())) return tuple[n_th]; Datatype dt = tuple.getType().getDatatype(); @@ -1067,17 +1083,19 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } bool TheorySetsRels::holds(Node node) { + Trace("rels-check") << " [sets-rels] Check if node = " << node << " already holds " << std::endl; bool polarity = node.getKind() != kind::NOT; Node atom = polarity ? node : node[0]; Node polarity_atom = polarity ? d_trueNode : d_falseNode; - if(d_eqEngine->hasTerm(node)) { - return areEqual(node, polarity_atom); + if(d_eqEngine->hasTerm(atom)) { + Trace("rels-check") << " [sets-rels] node = " << node << " is in the EE " << std::endl; + return areEqual(atom, polarity_atom); } else { Node atom_mod = NodeManager::currentNM()->mkNode(atom.getKind(), getRepresentative(atom[0]), getRepresentative(atom[1])); if(d_eqEngine->hasTerm(atom_mod)) { - return areEqual(node, polarity_atom); + return areEqual(atom_mod, polarity_atom); } } return false; diff --git a/src/theory/sets/theory_sets_rels.h b/src/theory/sets/theory_sets_rels.h index 0876cc5b3..b5e36603b 100644 --- a/src/theory/sets/theory_sets_rels.h +++ b/src/theory/sets/theory_sets_rels.h @@ -45,6 +45,7 @@ class TheorySetsRels { typedef context::CDChunkList NodeList; typedef context::CDHashSet NodeSet; + typedef context::CDHashMap NodeBoolMap; public: TheorySetsRels(context::Context* c, @@ -100,7 +101,7 @@ private: NodeSet d_lemma; NodeSet d_shared_terms; - std::hash_set< Node, NodeHashFunction > d_tc_nodes; + std::hash_set< Node, NodeHashFunction > d_rel_nodes; std::map< Node, std::vector > d_tuple_reps; std::map< Node, TupleTrie > d_membership_trie; std::hash_set< Node, NodeHashFunction > d_symbolic_tuples; @@ -141,7 +142,6 @@ private: void buildTCGraph( Node, Node, Node ); void computeRels( Node ); void computeTransposeRelations( Node ); - Node reverseTuple( Node ); void finalizeTCInfer(); void inferTC( Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >& ); void inferTC( Node, Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >&, @@ -159,7 +159,6 @@ private: bool checkCycles( Node ); // Helper functions - inline Node nthElementOfTuple( Node, int); inline Node getReason(Node tc_rep, Node tc_term, Node tc_r_rep, Node tc_r); inline Node constructPair(Node tc_rep, Node a, Node b); Node findMemExp(Node r, Node tuple); @@ -178,6 +177,10 @@ private: bool isRel( Node n ) {return n.getType().isSet() && n.getType().getSetElementType().isTuple();} Node mkAnd( std::vector< TNode >& assumptions ); +public: + static Node reverseTuple( Node ); + static Node nthElementOfTuple( Node, int); + }; diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index 8d76748bb..61ec07e9d 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -16,6 +16,7 @@ #include "theory/sets/theory_sets_rewriter.h" #include "theory/sets/normal_form.h" +#include "theory/sets/theory_sets_rels.h" namespace CVC4 { namespace theory { @@ -62,37 +63,6 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { bool isMember = checkConstantMembership(node[0], S); return RewriteResponse(REWRITE_DONE, nm->mkConst(isMember)); } -// if(node[1].getKind() == kind::TRANSPOSE) { -// // only work for node[0] is an actual tuple like (a, b), won't work for tuple variables -// if(node[0].getType().isSet() && !node[0].getType().getSetElementType().isTuple()) { -// Node atom = node; -// bool polarity = node.getKind() != kind::NOT; -// if( !polarity ) -// atom = atom[0]; -// Node new_node = NodeManager::currentNM()->mkNode(kind::MEMBER, atom[0], atom[1][0]); -// if(!polarity) -// new_node = new_node.negate(); -// return RewriteResponse(REWRITE_AGAIN, new_node); -// } -// if(node[0].isVar()) -// return RewriteResponse(REWRITE_DONE, node); -// std::vector elements; -// std::vector tuple_types = node[0].getType().getTupleTypes(); -// std::reverse(tuple_types.begin(), tuple_types.end()); -// TypeNode tn = NodeManager::currentNM()->mkTupleType(tuple_types); -// Datatype dt = tn.getDatatype(); -// elements.push_back(Node::fromExpr(dt[0].getConstructor())); -// for(Node::iterator child_it = node[0].end()-1; -// child_it != node[0].begin()-1; --child_it) { -// elements.push_back(*child_it); -// } -// Node new_node = NodeManager::currentNM()->mkNode(kind::MEMBER, -// NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, elements), -// node[1][0]); -// if(node.getKind() == kind::NOT) -// new_node = NodeManager::currentNM()->mkNode(kind::NOT, new_node); -// return RewriteResponse(REWRITE_AGAIN, new_node); -// } break; }//kind::MEMBER @@ -208,24 +178,135 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { }//kind::UNION case kind::TRANSPOSE: { + if(node[0].getKind() == kind::TRANSPOSE) { + return RewriteResponse(REWRITE_AGAIN, node[0][0]); + } + + if(node[0].getKind() == kind::EMPTYSET) { + return RewriteResponse(REWRITE_DONE, nm->mkConst(EmptySet(nm->toType(node.getType())))); + } else if(node[0].isConst()) { + std::set new_tuple_set; + std::set tuple_set = NormalForm::getElementsFromNormalConstant(node[0]); + std::set::iterator tuple_it = tuple_set.begin(); + + while(tuple_it != tuple_set.end()) { + new_tuple_set.insert(TheorySetsRels::reverseTuple(*tuple_it)); + tuple_it++; + } + Node new_node = NormalForm::elementsToSet(new_tuple_set, node.getType()); + Assert(new_node.isConst()); + Trace("sets-postrewrite") << "Sets::postRewrite returning " << new_node << std::endl; + return RewriteResponse(REWRITE_DONE, new_node); + + } if(node[0].getKind() != kind::TRANSPOSE) { Trace("sets-postrewrite") << "Sets::postRewrite returning " << node << std::endl; return RewriteResponse(REWRITE_DONE, node); } - if(node[0].getKind() == kind::TRANSPOSE) { - return RewriteResponse(REWRITE_AGAIN, node[0][0]); + break; + } + + case kind::PRODUCT: { + Trace("sets-rels-postrewrite") << "Sets::postRewrite processing " << node << std::endl; + if( node[0].getKind() == kind::EMPTYSET || + node[1].getKind() == kind::EMPTYSET) { + return RewriteResponse(REWRITE_DONE, nm->mkConst(EmptySet(nm->toType(node.getType())))); + } else if( node[0].isConst() && node[1].isConst() ) { + Trace("sets-rels-postrewrite") << "Sets::postRewrite processing **** " << node << std::endl; + std::set new_tuple_set; + std::set left = NormalForm::getElementsFromNormalConstant(node[0]); + std::set right = NormalForm::getElementsFromNormalConstant(node[1]); + std::set::iterator left_it = left.begin(); + int left_len = (*left_it).getType().getTupleLength(); + TypeNode tn = node.getType().getSetElementType(); + while(left_it != left.end()) { + Trace("rels-debug") << "Sets::postRewrite processing left_it = " << *left_it << std::endl; + std::vector left_tuple; + left_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor())); + for(int i = 0; i < left_len; i++) { + left_tuple.push_back(TheorySetsRels::nthElementOfTuple(*left_it,i)); + } + std::set::iterator right_it = right.begin(); + int right_len = (*right_it).getType().getTupleLength(); + while(right_it != right.end()) { + Trace("rels-debug") << "Sets::postRewrite processing left_it = " << *right_it << std::endl; + std::vector right_tuple; + for(int j = 0; j < right_len; j++) { + right_tuple.push_back(TheorySetsRels::nthElementOfTuple(*right_it,j)); + } + std::vector new_tuple; + new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end()); + new_tuple.insert(new_tuple.end(), right_tuple.begin(), right_tuple.end()); + Node composed_tuple = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, new_tuple); + new_tuple_set.insert(composed_tuple); + right_it++; + } + left_it++; + } + Node new_node = NormalForm::elementsToSet(new_tuple_set, node.getType()); + Assert(new_node.isConst()); + Trace("sets-postrewrite") << "Sets::postRewrite returning " << new_node << std::endl; + return RewriteResponse(REWRITE_DONE, new_node); + } + break; + } + + case kind::JOIN: { + if( node[0].getKind() == kind::EMPTYSET || + node[1].getKind() == kind::EMPTYSET) { + return RewriteResponse(REWRITE_DONE, nm->mkConst(EmptySet(nm->toType(node.getType())))); + } else if( node[0].isConst() && node[1].isConst() ) { + Trace("sets-rels-postrewrite") << "Sets::postRewrite processing " << node << std::endl; + std::set new_tuple_set; + std::set left = NormalForm::getElementsFromNormalConstant(node[0]); + std::set right = NormalForm::getElementsFromNormalConstant(node[1]); + std::set::iterator left_it = left.begin(); + int left_len = (*left_it).getType().getTupleLength(); + TypeNode tn = node.getType().getSetElementType(); + while(left_it != left.end()) { + std::vector left_tuple; + left_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor())); + for(int i = 0; i < left_len - 1; i++) { + left_tuple.push_back(TheorySetsRels::nthElementOfTuple(*left_it,i)); + } + std::set::iterator right_it = right.begin(); + int right_len = (*right_it).getType().getTupleLength(); + while(right_it != right.end()) { + if(TheorySetsRels::nthElementOfTuple(*left_it,left_len-1) == TheorySetsRels::nthElementOfTuple(*right_it,0)) { + std::vector right_tuple; + for(int j = 1; j < right_len; j++) { + right_tuple.push_back(TheorySetsRels::nthElementOfTuple(*right_it,j)); + } + std::vector new_tuple; + new_tuple.insert(new_tuple.end(), left_tuple.begin(), left_tuple.end()); + new_tuple.insert(new_tuple.end(), right_tuple.begin(), right_tuple.end()); + Node composed_tuple = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, new_tuple); + new_tuple_set.insert(composed_tuple); + } + right_it++; + } + left_it++; + } + Node new_node = NormalForm::elementsToSet(new_tuple_set, node.getType()); + Assert(new_node.isConst()); + Trace("sets-postrewrite") << "Sets::postRewrite returning " << new_node << std::endl; + return RewriteResponse(REWRITE_DONE, new_node); } + break; } case kind::TRANSCLOSURE: { - if(node[0].getKind() != kind::TRANSCLOSURE) { + if(node[0].getKind() == kind::EMPTYSET) { + return RewriteResponse(REWRITE_DONE, nm->mkConst(EmptySet(nm->toType(node.getType())))); + } else if (node[0].isConst()) { + + } else if(node[0].getKind() == kind::TRANSCLOSURE) { + return RewriteResponse(REWRITE_AGAIN, node[0]); + } else if(node[0].getKind() != kind::TRANSCLOSURE) { Trace("sets-postrewrite") << "Sets::postRewrite returning " << node << std::endl; return RewriteResponse(REWRITE_DONE, node); } - if(node[0].getKind() == kind::TRANSCLOSURE) { - return RewriteResponse(REWRITE_AGAIN, node[0]); - } break; } -- 2.30.2