From 7a030ed02d5a8fcabe327c24cecaa9d69919e863 Mon Sep 17 00:00:00 2001 From: Paul Meng Date: Mon, 11 Jul 2016 21:40:25 -0400 Subject: [PATCH] added support for expansion of transitive closure --- src/theory/sets/theory_sets_rels.cpp | 279 ++++++++++++++++++--------- src/theory/sets/theory_sets_rels.h | 6 +- 2 files changed, 193 insertions(+), 92 deletions(-) diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index 3f7d079bd..24aa44f3b 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -234,12 +234,9 @@ int TheorySetsRels::EqcInfo::counter = 0; void TheorySetsRels::applyTCRule(Node exp, Node tc_term) { Trace("rels-debug") << "\n[sets-rels] *********** Applying TRANSITIVE CLOSURE rule on " << tc_term << " with explanation " << exp << std::endl; - bool polarity = exp.getKind() != kind::NOT; - Node atom = polarity ? exp : exp[0]; - Node tup_rep = getRepresentative(atom[0]); + Node tc_rep = getRepresentative(tc_term); Node tc_r_rep = getRepresentative(tc_term[0]); - // build the TC graph for tc_rep if it was not created before if( d_rel_nodes.find(tc_rep) == d_rel_nodes.end() ) { Trace("rels-debug") << "[sets-rels] Start building the TC graph!" << std::endl; @@ -247,50 +244,77 @@ int TheorySetsRels::EqcInfo::counter = 0; d_rel_nodes.insert(tc_rep); } - // insert tup_rep in the tc_graph if it is not in the graph already - TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep); + bool polarity = exp.getKind() != kind::NOT; if(polarity) { - Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 0)); - Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 1)); - - if(tc_graph_it != d_membership_tc_cache.end()) { - TC_PAIR_IT pair_set_it = tc_graph_it->second.find(fst_rep); + std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator mem_it = d_tc_membership_db.find(tc_term); - if(pair_set_it != tc_graph_it->second.end()) { - pair_set_it->second.insert(snd_rep); - } else { - std::hash_set< Node, NodeHashFunction > pair_set; - pair_set.insert(snd_rep); - tc_graph_it->second[fst_rep] = pair_set; - } - - Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); - std::map< Node, Node >::iterator exp_it = d_membership_tc_exp_cache.find(tc_rep); - - if(!reason.isNull() && exp_it->second != reason) { - d_membership_tc_exp_cache[tc_rep] = Rewriter::rewrite(AND(exp_it->second, reason)); - } + if( mem_it == d_tc_membership_db.end() ) { + std::hash_set members; + members.insert(exp[0]); + d_tc_membership_db[tc_term] = members; } else { - std::map< Node, std::hash_set< Node, NodeHashFunction > > pair_set; - std::hash_set< Node, NodeHashFunction > snd_set; - - snd_set.insert(snd_rep); - pair_set[fst_rep] = snd_set; - d_membership_tc_cache[tc_rep] = pair_set; - Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); - - if(!reason.isNull()) { - d_membership_tc_exp_cache[tc_rep] = reason; - } - } - // check if tup_rep already exists in TC graph for conflict - } else { - if(tc_graph_it != d_membership_tc_cache.end()) { - checkTCGraphForConflict(atom, tc_rep, d_trueNode, RelsUtils::nthElementOfTuple(tup_rep, 0), - RelsUtils::nthElementOfTuple(tup_rep, 1), tc_graph_it->second); + mem_it->second.insert(exp[0]); } } + //todo: need to construct a tc_graph if transitive closure is used in the context + +// Node atom = polarity ? exp : exp[0]; +// Node tup_rep = getRepresentative(atom[0]); +// Node tc_rep = getRepresentative(tc_term); +// Node tc_r_rep = getRepresentative(tc_term[0]); +// +// // build the TC graph for tc_rep if it was not created before +// if( d_rel_nodes.find(tc_rep) == d_rel_nodes.end() ) { +// Trace("rels-debug") << "[sets-rels] Start building the TC graph!" << std::endl; +// buildTCGraph(tc_r_rep, tc_rep, tc_term); +// d_rel_nodes.insert(tc_rep); +// } + + // insert tup_rep in the tc_graph if it is not in the graph already +// TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep); +// +// if( polarity ) { +// Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 0)); +// Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 1)); +// +// if( tc_graph_it != d_membership_tc_cache.end() ) { +// TC_PAIR_IT pair_set_it = tc_graph_it->second.find(fst_rep); +// +// if( pair_set_it != tc_graph_it->second.end() ) { +// pair_set_it->second.insert(snd_rep); +// } else { +// std::hash_set< Node, NodeHashFunction > pair_set; +// pair_set.insert(snd_rep); +// tc_graph_it->second[fst_rep] = pair_set; +// } +// +// Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); +// std::map< Node, Node >::iterator exp_it = d_membership_tc_exp_cache.find(tc_rep); +// +// if(!reason.isNull() && exp_it->second != reason) { +// d_membership_tc_exp_cache[tc_rep] = Rewriter::rewrite(AND(exp_it->second, reason)); +// } +// } else { +// std::map< Node, std::hash_set< Node, NodeHashFunction > > pair_set; +// std::hash_set< Node, NodeHashFunction > snd_set; +// +// snd_set.insert(snd_rep); +// pair_set[fst_rep] = snd_set; +// d_membership_tc_cache[tc_rep] = pair_set; +// Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); +// +// if(!reason.isNull()) { +// d_membership_tc_exp_cache[tc_rep] = reason; +// } +// } +// // check if tup_rep already exists in TC graph for conflict +// } else { +// if( tc_graph_it != d_membership_tc_cache.end() ) { +// checkTCGraphForConflict(atom, tc_rep, d_trueNode, RelsUtils::nthElementOfTuple(tup_rep, 0), +// RelsUtils::nthElementOfTuple(tup_rep, 1), tc_graph_it->second); +// } +// } } void TheorySetsRels::checkTCGraphForConflict (Node atom, Node tc_rep, Node exp, Node a, Node b, @@ -555,6 +579,33 @@ int TheorySetsRels::EqcInfo::counter = 0; } + void TheorySetsRels::finalizeTCInfer() { + Trace("rels-debug") << "[sets-rels] Finalizing transitive closure inferences!" << std::endl; + for(TC_IT tc_it = d_membership_tc_cache.begin(); tc_it != d_membership_tc_cache.end(); tc_it++) { + inferTC(tc_it->first, tc_it->second); + } + } + + void TheorySetsRels::inferTC(Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph) { + Trace("rels-debug") << "[sets-rels] Build TC graph for tc_rep = " << tc_rep << std::endl; + for(std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin(); + pair_set_it != tc_graph.end(); pair_set_it++) { + for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin(); + set_it != pair_set_it->second.end(); set_it++) { + std::hash_set elements; + Node pair = constructPair(tc_rep, pair_set_it->first, *set_it); + Node exp = findMemExp(tc_rep, pair); + + if(d_membership_tc_exp_cache.find(tc_rep) != d_membership_tc_exp_cache.end()) { + exp = AND(d_membership_tc_exp_cache[tc_rep], exp); + } + Assert(!exp.isNull()); + elements.insert(pair_set_it->first); + inferTC( exp, tc_rep, tc_graph, pair_set_it->first, *set_it, elements ); + } + } + } + void TheorySetsRels::inferTC( Node exp, Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph, Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& traversed ) { Node pair = constructPair(tc_rep, start_node, cur_node); @@ -584,33 +635,6 @@ int TheorySetsRels::EqcInfo::counter = 0; } } - void TheorySetsRels::finalizeTCInfer() { - Trace("rels-debug") << "[sets-rels] Finalizing transitive closure inferences!" << std::endl; - for(TC_IT tc_it = d_membership_tc_cache.begin(); tc_it != d_membership_tc_cache.end(); tc_it++) { - inferTC(tc_it->first, tc_it->second); - } - } - - void TheorySetsRels::inferTC(Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph) { - Trace("rels-debug") << "[sets-rels] Build TC graph for tc_rep = " << tc_rep << std::endl; - for(std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin(); - pair_set_it != tc_graph.end(); pair_set_it++) { - for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin(); - set_it != pair_set_it->second.end(); set_it++) { - std::hash_set elements; - Node pair = constructPair(tc_rep, pair_set_it->first, *set_it); - Node exp = findMemExp(tc_rep, pair); - - if(d_membership_tc_exp_cache.find(tc_rep) != d_membership_tc_exp_cache.end()) { - exp = AND(d_membership_tc_exp_cache[tc_rep], exp); - } - Assert(!exp.isNull()); - elements.insert(pair_set_it->first); - inferTC( exp, tc_rep, tc_graph, pair_set_it->first, *set_it, elements ); - } - } - } - // Bottom-up fashion to compute relations void TheorySetsRels::computeRels(Node n) { Trace("rels-debug") << "\n[sets-rels] computeJoinOrProductRelations for relation " << n << std::endl; @@ -793,7 +817,6 @@ int TheorySetsRels::EqcInfo::counter = 0; Trace("rels-lemma") << "[sets-rels-lemma] Process pending lemma : " << d_lemma_cache[i] << std::endl; d_sets_theory.d_out->lemma( d_lemma_cache[i] ); -// d_sets_theory.d_out->conflict() } for( std::map::iterator child_it = d_pending_facts.begin(); child_it != d_pending_facts.end(); child_it++ ) { @@ -806,7 +829,11 @@ int TheorySetsRels::EqcInfo::counter = 0; << child_it->first << " with reason " << child_it->second << std::endl; d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, child_it->second, child_it->first)); } + doTCLemmas(); } + + + d_tc_membership_db.clear(); d_rel_nodes.clear(); d_pending_facts.clear(); d_membership_constraints_cache.clear(); @@ -823,6 +850,94 @@ int TheorySetsRels::EqcInfo::counter = 0; d_node_id.clear(); } + void TheorySetsRels::doTCLemmas() { + std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator mem_it = d_tc_membership_db.begin(); + + while(mem_it != d_tc_membership_db.end()) { + Node tc_rep = getRepresentative(mem_it->first); + Node tc_r_rep = getRepresentative(mem_it->first[0]); + + // build the TC graph for tc_rep if it was not created before + if( d_rel_nodes.find(tc_rep) == d_rel_nodes.end() ) { + Trace("rels-debug") << "[sets-rels] Start building the TC graph for relation " << mem_it->first << std::endl; + buildTCGraph(tc_r_rep, tc_rep, mem_it->first); + d_rel_nodes.insert(tc_rep); + } + + std::hash_set< Node, NodeHashFunction >::iterator set_it = mem_it->second.begin(); + + while(set_it != mem_it->second.end()) { + std::hash_set hasSeen; + Node fst = RelsUtils::nthElementOfTuple(*set_it, 0); + Node snd = RelsUtils::nthElementOfTuple(*set_it, 1); + Node fst_rep = getRepresentative(fst); + Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(*set_it, 1)); + TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep); + + if((tc_graph_it != d_membership_tc_cache.end() && !isTCReachable(fst_rep, snd_rep, hasSeen, tc_graph_it->second)) || + (tc_graph_it == d_membership_tc_cache.end())) { + Node reason = explain(MEMBER(*set_it, mem_it->first)); + Node sk_1 = NodeManager::currentNM()->mkSkolem("sde", fst_rep.getType()); + Node sk_2 = NodeManager::currentNM()->mkSkolem("sde", snd_rep.getType()); + Node mem_of_r = MEMBER(RelsUtils::constructPair(tc_r_rep, fst_rep, snd_rep), tc_r_rep); + Node sk_eq = EQUAL(sk_1, sk_2); + + if(fst_rep != fst) { + reason = AND(reason, explain(EQUAL(fst_rep, fst))); + } + if(snd_rep != snd) { + reason = AND(reason, explain(EQUAL(snd_rep, snd))); + } + if(tc_r_rep != mem_it->first[0]) { + reason = AND(reason, explain(EQUAL(tc_r_rep, mem_it->first[0]))); + } + if(tc_rep != mem_it->first) { + reason = AND(reason, explain(EQUAL(tc_r_rep, mem_it->first))); + } + + Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, + OR(mem_of_r, + (AND(MEMBER(RelsUtils::constructPair(tc_r_rep, fst_rep, sk_1), tc_r_rep), + (AND(MEMBER(RelsUtils::constructPair(tc_r_rep, sk_2, snd_rep), tc_r_rep), + (OR(sk_eq, MEMBER(RelsUtils::constructPair(tc_rep, sk_1, sk_2), tc_rep))))))))); + Trace("rels-lemma") << "[sets-rels-lemma] Process a TC lemma : " + << tc_lemma << std::endl; + d_sets_theory.d_out->lemma(tc_lemma); + d_sets_theory.d_out->requirePhase(Rewriter::rewrite(mem_of_r), true); + d_sets_theory.d_out->requirePhase(Rewriter::rewrite(sk_eq), true); + } + set_it++; + } + mem_it++; + } + } + + bool TheorySetsRels::isTCReachable(Node start, Node dest, std::hash_set& hasSeen, + std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph) { + if(hasSeen.find(start) == hasSeen.end()) { + hasSeen.insert(start); + } + + TC_PAIR_IT pair_set_it = tc_graph.find(start); + + if(pair_set_it != tc_graph.end()) { + if(pair_set_it->second.find(dest) != pair_set_it->second.end()) { + return true; + } else { + std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin(); + + while(set_it != pair_set_it->second.end()) { + // need to check if *set_it has been looked already + if(hasSeen.find(*set_it) == hasSeen.end()) { + isTCReachable(*set_it, dest, hasSeen, tc_graph); + } + set_it++; + } + } + } + return false; + } + void TheorySetsRels::sendSplit(Node a, Node b, const char * c) { Node eq = a.eqNode( b ); Node neq = NOT( eq ); @@ -931,10 +1046,6 @@ int TheorySetsRels::EqcInfo::counter = 0; return false; } - bool TheorySetsRels::checkCycles(Node join_term) { - return false; - } - bool TheorySetsRels::safeAddToMap(std::map< Node, std::vector >& map, Node rel_rep, Node member) { std::map< Node, std::vector< Node > >::iterator mem_it = map.find(rel_rep); if(mem_it == map.end()) { @@ -1394,30 +1505,19 @@ int TheorySetsRels::EqcInfo::counter = 0; Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() << std::endl; NodeMap::iterator map_it = tc_ei->d_mem_exp.begin(); - while(map_it != tc_ei->d_mem_exp.end()) { - Trace("rels-debug") << " mem = "<< (*map_it).first << " exp = " << (*map_it).second<< std::endl; - map_it++; - } Node exp = explainTCMem(tc_ei, mem_rep, fst_rep, snd_rep); Assert(!exp.isNull()); Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, exp, MEMBER(mem_rep, tc_ei->d_tc.get())); d_pending_merge.push_back(tc_lemma); d_lemma.insert(tc_lemma); - Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() - << " in_reachable size = " << in_reachable.size() - << " out_reachable size = " << out_reachable.size() - << " ***** 2" << std::endl; - std::hash_set::iterator in_reachable_it = in_reachable.begin(); while(in_reachable_it != in_reachable.end()) { Node in_node = d_id_node[*in_reachable_it]; Node in_pair = RelsUtils::constructPair(tc_ei->d_tc.get(), in_node, fst_rep); Node new_pair = RelsUtils::constructPair(tc_ei->d_tc.get(), in_node, snd_rep); - - Trace("rels-std") << "Reason for " << in_pair << " " << explainTCMem(tc_ei, in_pair, in_node, fst_rep) << std::endl; - Node reason = AND(explainTCMem(tc_ei, in_pair, in_node, fst_rep), exp); + tc_ei->d_mem_exp[new_pair] = reason; tc_ei->d_mem.insert(new_pair); Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(new_pair, tc_ei->d_tc.get())); @@ -1425,7 +1525,7 @@ int TheorySetsRels::EqcInfo::counter = 0; d_lemma.insert(tc_lemma); in_reachable_it++; } - Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() << " ***** 3" << std::endl; + std::hash_set::iterator out_reachable_it = out_reachable.begin(); while(out_reachable_it != out_reachable.end()) { Node out_node = d_id_node[*out_reachable_it]; @@ -1441,9 +1541,7 @@ int TheorySetsRels::EqcInfo::counter = 0; Node in_pair_exp = explainTCMem(tc_ei, in_pair, in_node, snd_rep); Assert(in_pair_exp != Node::null()); - reason = AND(reason, in_pair_exp); - - Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() << " ***** 3 9" << std::endl; + reason = AND(reason, in_pair_exp); tc_ei->d_mem_exp[new_pair] = reason; tc_ei->d_mem.insert(new_pair); Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(new_pair, tc_ei->d_tc.get())); @@ -1453,7 +1551,6 @@ int TheorySetsRels::EqcInfo::counter = 0; } out_reachable_it++; } - Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() << " ***** 4" << std::endl; } void TheorySetsRels::collectInReachableNodes(EqcInfo* tc_ei, int start_id, std::hash_set& in_reachable, bool firstRound) { diff --git a/src/theory/sets/theory_sets_rels.h b/src/theory/sets/theory_sets_rels.h index 381ccddd9..5a1985d4e 100644 --- a/src/theory/sets/theory_sets_rels.h +++ b/src/theory/sets/theory_sets_rels.h @@ -129,6 +129,8 @@ private: std::map< Node, std::vector > d_membership_db; std::map< Node, std::vector > d_membership_exp_db; std::map< Node, Node > d_membership_tc_exp_cache; + + std::map< Node, std::hash_set > d_tc_membership_db; std::map< Node, std::map > > d_terms_cache; std::map< Node, std::map< Node, std::hash_set > > d_membership_tc_cache; @@ -173,9 +175,12 @@ private: void inferTC( Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >& ); void inferTC( Node, Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >&, Node, Node, std::hash_set< Node, NodeHashFunction >&); + bool isTCReachable(Node fst, Node snd, std::hash_set& hasSeen, + std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph); Node explain(Node); + void doTCLemmas(); void sendInfer( Node fact, Node exp, const char * c ); void sendLemma( Node fact, Node reason, const char * c ); void sendSplit( Node a, Node b, const char * c ); @@ -183,7 +188,6 @@ private: void doPendingSplitFacts(); void addSharedTerm( TNode n ); void checkTCGraphForConflict( Node, Node, Node, Node, Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >& ); - bool checkCycles( Node ); // Helper functions bool insertIntoIdList(IdList&, int); -- 2.30.2