From 11b89d6b6cd49e4a318511012128c9cb93ad689a Mon Sep 17 00:00:00 2001 From: PaulMeng Date: Thu, 5 May 2016 14:58:10 -0500 Subject: [PATCH] change to use tuple element representatives to build TC graph for full effort --- src/theory/sets/theory_sets_rels.cpp | 176 +++++++++++++-------------- src/theory/sets/theory_sets_rels.h | 11 +- 2 files changed, 92 insertions(+), 95 deletions(-) diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index 428027acc..ccb917d5f 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -51,6 +51,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Assert(d_pending_facts.empty()); } else { doPendingMerge(); + doPendingLemmas(); } Trace("rels") << "\n[sets-rels] ******************************* Done with the relational solver *******************************\n" << std::endl; } @@ -208,13 +209,15 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(mem_it != d_membership_db.end()) { for(std::vector::iterator pair_it = mem_it->second.begin(); pair_it != mem_it->second.end(); pair_it++) { - TC_PAIR_IT pair_set_it = tc_graph.find(RelsUtils::nthElementOfTuple(*pair_it, 0)); + Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(*pair_it, 0)); + Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(*pair_it, 1)); + TC_PAIR_IT pair_set_it = tc_graph.find(fst_rep); if( pair_set_it != tc_graph.end() ) { - pair_set_it->second.insert(RelsUtils::nthElementOfTuple(*pair_it, 1)); + pair_set_it->second.insert(snd_rep); } else { - std::hash_set< Node, NodeHashFunction > snd_pair_set; - snd_pair_set.insert(RelsUtils::nthElementOfTuple(*pair_it, 1)); - tc_graph[RelsUtils::nthElementOfTuple(*pair_it, 0)] = snd_pair_set; + std::hash_set< Node, NodeHashFunction > snd_set; + snd_set.insert(snd_rep); + tc_graph[fst_rep] = snd_set; } } } @@ -243,14 +246,16 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P // insert tup_rep in the tc_graph if it is not in the graph already TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep); if(polarity) { + Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 0)); + Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 1)); if(tc_graph_it != d_membership_tc_cache.end()) { - TC_PAIR_IT pair_set_it = tc_graph_it->second.find(RelsUtils::nthElementOfTuple(tup_rep, 0)); + TC_PAIR_IT pair_set_it = tc_graph_it->second.find(fst_rep); if(pair_set_it != tc_graph_it->second.end()) { - pair_set_it->second.insert(RelsUtils::nthElementOfTuple(tup_rep, 1)); + pair_set_it->second.insert(snd_rep); } else { std::hash_set< Node, NodeHashFunction > pair_set; - pair_set.insert(RelsUtils::nthElementOfTuple(tup_rep, 1)); - tc_graph_it->second[RelsUtils::nthElementOfTuple(tup_rep, 0)] = pair_set; + pair_set.insert(snd_rep); + tc_graph_it->second[fst_rep] = pair_set; } Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); std::map< Node, Node >::iterator exp_it = d_membership_tc_exp_cache.find(tc_rep); @@ -259,9 +264,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } } else { std::map< Node, std::hash_set< Node, NodeHashFunction > > pair_set; - std::hash_set< Node, NodeHashFunction > set; - set.insert(RelsUtils::nthElementOfTuple(tup_rep, 1)); - pair_set[RelsUtils::nthElementOfTuple(tup_rep, 0)] = set; + std::hash_set< Node, NodeHashFunction > snd_set; + snd_set.insert(snd_rep); + pair_set[fst_rep] = snd_set; d_membership_tc_cache[tc_rep] = pair_set; Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); if(!reason.isNull()) { @@ -300,7 +305,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P // check if tup_rep already exists in TC graph for conflict } else { if(tc_graph_it != d_membership_tc_cache.end()) { - Trace("rels-debug") << "********** tc reach here 0" << std::endl; checkTCGraphForConflict(atom, tc_rep, d_trueNode, RelsUtils::nthElementOfTuple(tup_rep, 0), RelsUtils::nthElementOfTuple(tup_rep, 1), tc_graph_it->second); } @@ -310,11 +314,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::checkTCGraphForConflict (Node atom, Node tc_rep, Node exp, Node a, Node b, std::map< Node, std::hash_set< Node, NodeHashFunction > >& pair_set) { TC_PAIR_IT pair_set_it = pair_set.find(a); - Trace("rels-debug") << "********** tc reach here 1" << " a = " << a << " b = " << b << std::endl; if(pair_set_it != pair_set.end()) { - Trace("rels-debug") << "********** tc reach here 2" << std::endl; if(pair_set_it->second.find(b) != pair_set_it->second.end()) { - Trace("rels-debug") << "********** tc reach here 3" << std::endl; Node reason = AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, b))); if(atom[1] != tc_rep) { reason = AND(exp, explain(EQUAL(atom[1], tc_rep))); @@ -326,20 +327,17 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P // << AND(reason.negate(), atom) << std::endl; // d_sets_theory.d_out->conflict(AND(reason.negate(), atom)); } else { - Trace("rels-debug") << "********** tc reach here 4" << std::endl; std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin(); while(set_it != pair_set_it->second.end()) { // need to check if *set_it has been looked already if(!areEqual(*set_it, a)) { checkTCGraphForConflict(atom, tc_rep, AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, *set_it))), *set_it, b, pair_set); - } - Trace("rels-debug") << "********** looping here 6 *set_it = " << *set_it << std::endl; + } set_it++; } } } - Trace("rels-debug") << "********** tc reach here 5" << std::endl; } @@ -576,43 +574,38 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::inferTC( Node exp, Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph, - Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& elements, bool first_round ) { + Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& traversed ) { Node pair = constructPair(tc_rep, start_node, cur_node); - if(safeAddToMap(d_membership_db, tc_rep, pair)) { - addToMap(d_membership_exp_cache, tc_rep, Rewriter::rewrite(exp)); - sendLemma( MEMBER(pair, tc_rep), Rewriter::rewrite(exp), "Transitivity" ); + std::map >::iterator mem_it = d_membership_db.find(tc_rep); + if(mem_it != d_membership_db.end()) { + if(std::find(mem_it->second.begin(), mem_it->second.end(), pair) == mem_it->second.end()) { + sendLemma( MEMBER(pair, tc_rep), Rewriter::rewrite(exp), "Transitivity" ); + } } - // check if cur_node has been traversed or not - if(!first_round) { - std::hash_set< Node, NodeHashFunction >::iterator ele_it = elements.begin(); - while(ele_it != elements.end()) { - if(areEqual(cur_node, *ele_it)) { - return; - } - ele_it++; - } + if(traversed.find(cur_node) != traversed.end()) { + return; } - std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin(); + traversed.insert(cur_node); Node reason = exp; - while(pair_set_it != tc_graph.end()) { - if(areEqual(pair_set_it->first, cur_node)) { - reason = AND(exp, EQUAL(pair_set_it->first, cur_node)); - break; - } - pair_set_it++; - } - if(pair_set_it != tc_graph.end()) { - for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin(); - set_it != pair_set_it->second.end(); set_it++) { - Node p = constructPair( tc_rep, cur_node, *set_it ); + std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator cur_set = tc_graph.find(cur_node); + if(cur_set != tc_graph.end()) { + for(std::hash_set< Node, NodeHashFunction >::iterator set_it = cur_set->second.begin(); + set_it != cur_set->second.end(); set_it++) { + Node new_pair = constructPair( tc_rep, cur_node, *set_it ); Assert(!reason.isNull()); - elements.insert(*set_it); - inferTC( AND( findMemExp(tc_rep, p), reason ), tc_rep, tc_graph, start_node, *set_it, elements, false ); + inferTC( AND( findMemExp(tc_rep, new_pair), reason ), tc_rep, tc_graph, start_node, *set_it, traversed ); } } } + void TheorySetsRels::finalizeTCInfer() { + Trace("rels-debug") << "[sets-rels] Finalizing transitive closure inferences!" << std::endl; + for(TC_IT tc_it = d_membership_tc_cache.begin(); tc_it != d_membership_tc_cache.end(); tc_it++) { + inferTC(tc_it->first, tc_it->second); + } + } + void TheorySetsRels::inferTC(Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph) { Trace("rels-debug") << "[sets-rels] Build TC graph for tc_rep = " << tc_rep << std::endl; for(std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin(); @@ -627,19 +620,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } Assert(!exp.isNull()); elements.insert(pair_set_it->first); - elements.insert(*set_it); - inferTC( exp, tc_rep, tc_graph, pair_set_it->first, *set_it, elements, true ); + inferTC( exp, tc_rep, tc_graph, pair_set_it->first, *set_it, elements ); } } } - void TheorySetsRels::finalizeTCInfer() { - Trace("rels-debug") << "[sets-rels] Finalizing transitive closure inferences!" << std::endl; - for(TC_IT tc_it = d_membership_tc_cache.begin(); tc_it != d_membership_tc_cache.end(); tc_it++) { - inferTC(tc_it->first, tc_it->second); - } - } - // Bottom-up fashion to compute relations void TheorySetsRels::computeRels(Node n) { Trace("rels-debug") << "\n[sets-rels] computeJoinOrProductRelations for relation " << n << std::endl; @@ -770,11 +755,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P reasons.push_back(explain(r1_exps[i])); reasons.push_back(explain(r2_exps[j])); if(r1_exps[i].getKind() == kind::MEMBER && r1_exps[i][0] != r1_elements[i]) { - Trace("rels-debug") << "************* $ r1 ele = " << r1_elements[i] << " r1 exp ele = " << r1_exps[i][0] << std::endl; reasons.push_back(explain(EQUAL(r1_elements[i], r1_exps[i][0]))); } if(r2_exps[j].getKind() == kind::MEMBER && r2_exps[j][0] != r2_elements[j]) { - Trace("rels-debug") << "************* $ r2 ele = " << r2_elements[j] << " r2 exp ele = " << r2_exps[j][0] << std::endl; reasons.push_back(explain(EQUAL(r2_elements[j], r2_exps[j][0]))); } if(!isProduct) { @@ -935,6 +918,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } bool TheorySetsRels::areEqual( Node a, Node b ){ + Assert(a.getType() == b.getType()); + Trace("rels-eq") << "[sets-rels]**** checking equality between " << a << " and " << b << std::endl; if(a == b) { return true; } else if( hasTerm( a ) && hasTerm( b ) ){ @@ -1001,19 +986,23 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P // tuple might be a member of tc_rep; or it might be a member of rels or tc_terms such that // tc_terms are transitive closure of rels and are modulo equal to tc_rep - Node TheorySetsRels::findMemExp(Node tc_rep, Node tuple) { - Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_rep = " << tc_rep << ", tuple = " << tuple << ")" << std::endl; + Node TheorySetsRels::findMemExp(Node tc_rep, Node pair) { + Node fst = RelsUtils::nthElementOfTuple(pair, 0); + Node snd = RelsUtils::nthElementOfTuple(pair, 1); + Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_rep = " << tc_rep << ", pair = " << pair << ")" << std::endl; std::vector tc_terms = d_terms_cache.find(tc_rep)->second[kind::TCLOSURE]; Assert(tc_terms.size() > 0); for(unsigned int i = 0; i < tc_terms.size(); i++) { Node tc_term = tc_terms[i]; Node tc_r_rep = getRepresentative(tc_term[0]); - Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << tc_r_rep << ", tuple = " << tuple << ")" << std::endl; + Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << tc_r_rep << ", pair = " << pair << ")" << std::endl; std::map< Node, std::vector< Node > >::iterator tc_r_mems = d_membership_db.find(tc_r_rep); if(tc_r_mems != d_membership_db.end()) { for(unsigned int i = 0; i < tc_r_mems->second.size(); i++) { - if(areEqual(tc_r_mems->second[i], tuple)) { + Node fst_mem = RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 0); + Node snd_mem = RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 1); + if(areEqual(fst_mem, fst) && areEqual(snd_mem, snd)) { Node exp = MEMBER(tc_r_mems->second[i], tc_r_mems->first); if(tc_r_rep != tc_term[0]) { exp = explain(EQUAL(tc_r_rep, tc_term[0])); @@ -1021,14 +1010,14 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(tc_rep != tc_term) { exp = AND(exp, explain(EQUAL(tc_rep, tc_term))); } - if(tc_r_mems->second[i] != tuple) { - if(RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 0) != RelsUtils::nthElementOfTuple(tuple, 0)) { - exp = AND(exp, explain(EQUAL(RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 0), RelsUtils::nthElementOfTuple(tuple, 0)))); + if(tc_r_mems->second[i] != pair) { + if(fst_mem != fst) { + exp = AND(exp, explain(EQUAL(fst_mem, fst))); } - if(RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 1) != RelsUtils::nthElementOfTuple(tuple, 1)) { - exp = AND(exp, explain(EQUAL(RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 1), RelsUtils::nthElementOfTuple(tuple, 1)))); + if(snd_mem != snd) { + exp = AND(exp, explain(EQUAL(snd_mem, snd))); } - exp = AND(exp, EQUAL(tc_r_mems->second[i], tuple)); + exp = AND(exp, EQUAL(tc_r_mems->second[i], pair)); } return Rewriter::rewrite(AND(exp, explain(d_membership_exp_db[tc_r_rep][i]))); } @@ -1037,10 +1026,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Node tc_term_rep = getRepresentative(tc_terms[i]); std::map< Node, std::vector< Node > >::iterator tc_t_mems = d_membership_db.find(tc_term_rep); - Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_t_rep = " << tc_term_rep << ", tuple = " << tuple << ")" << std::endl; if(tc_t_mems != d_membership_db.end()) { for(unsigned int j = 0; j < tc_t_mems->second.size(); j++) { - if(areEqual(tc_t_mems->second[j], tuple)) { + Node fst_mem = RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 0); + Node snd_mem = RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 1); + if(areEqual(fst_mem, fst) && areEqual(snd_mem, snd)) { Node exp = MEMBER(tc_t_mems->second[j], tc_t_mems->first); if(tc_rep != tc_terms[i]) { exp = AND(exp, explain(EQUAL(tc_rep, tc_terms[i]))); @@ -1048,14 +1038,14 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(tc_term_rep != tc_terms[i]) { exp = AND(exp, explain(EQUAL(tc_term_rep, tc_terms[i]))); } - if(tc_t_mems->second[j] != tuple) { - if(RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 0) != RelsUtils::nthElementOfTuple(tuple, 0)) { - exp = AND(exp, explain(EQUAL(RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 0), RelsUtils::nthElementOfTuple(tuple, 0)))); + if(tc_t_mems->second[j] != pair) { + if(fst_mem != fst) { + exp = AND(exp, explain(EQUAL(fst_mem, fst))); } - if(RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 1) != RelsUtils::nthElementOfTuple(tuple, 1)) { - exp = AND(exp, explain(EQUAL(RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 1), RelsUtils::nthElementOfTuple(tuple, 1)))); + if(snd_mem != snd) { + exp = AND(exp, explain(EQUAL(snd_mem, snd))); } - exp = AND(exp, EQUAL(tc_t_mems->second[j], tuple)); + exp = AND(exp, EQUAL(tc_t_mems->second[j], pair)); } return Rewriter::rewrite(AND(exp, explain(d_membership_exp_db[tc_term_rep][j]))); } @@ -1087,7 +1077,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } void TheorySetsRels::makeSharedTerm( Node n ) { - if(d_shared_terms.find(n) == d_shared_terms.end() && !n.getType().isBoolean()) { + Trace("rels-share") << " [sets-rels] making shared term " << n << std::endl; + if(d_shared_terms.find(n) == d_shared_terms.end()) { Node skolem = NodeManager::currentNM()->mkSkolem( "sde", n.getType() ); sendLemma(MEMBER(skolem, SINGLETON(n)), d_trueNode, "share-term"); d_shared_terms.insert(n); @@ -1262,7 +1253,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } TheorySetsRels::EqcInfo::EqcInfo( context::Context* c ) : - d_mem(c), d_not_mem(c), d_in(c), d_out(c), d_tc_mem_exp(c), d_tp(c), d_pt(c), d_join(c), d_tc(c) {} + counter(0), d_mem(c), d_not_mem(c), d_in(c), d_out(c), d_tc_mem_exp(c), + d_tp(c), d_pt(c), d_join(c), d_tc(c), d_id_in(c), d_id_out(c) {} void TheorySetsRels::eqNotifyNewClass( Node n ) { Trace("rels-std") << "[sets-rels] eqNotifyNewClass:" << " t = " << n << std::endl; @@ -1272,7 +1264,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P n.getKind() == kind::TCLOSURE)) { getOrMakeEqcInfo( n, true ); } - Trace("rels-std") << "[sets-rels] eqNotifyNewClass*****:" << " t = " << n << std::endl; } void TheorySetsRels::addTCMem(EqcInfo* tc_ei, Node mem) { Node fst = RelsUtils::nthElementOfTuple(mem, 0); @@ -1321,8 +1312,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(ei == NULL) { ei = getOrMakeEqcInfo( t2_1rep, true ); } - // might not need to store the membership info - // if we don't need to consider the eqc merge? if(polarity) { ei->d_mem.insert(t2[0]); } else { @@ -1345,8 +1334,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(polarity) { if(!ei->d_tc.get().isNull()) { addTCMem(ei, t2[0]); - ei->d_tc_mem_exp.insert(t2[0], t2); - sendInferTC(ei, t2[0], t2); + ei->d_tc_mem_exp.insert(t2[0], explain(t2)); + sendInferTC(ei, t2[0], explain(t2)); } else { std::vector tup_types = t2[1].getType().getSetElementType().getTupleTypes(); if( tup_types.size() == 2 && tup_types[0] == tup_types[1] ) { @@ -1354,7 +1343,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P EqcInfo* tc_ei = getOrMakeEqcInfo( tc_n ); if(tc_ei != NULL) { addTCMem(tc_ei, t2[0]); - Node exp = (tc_n == tc_ei->d_tc.get()) ? t2 : AND(EQUAL(tc_n, tc_ei->d_tc.get()), t2); + Node exp = (tc_n == tc_ei->d_tc.get()) ? explain(t2) : AND(EQUAL(tc_n, tc_ei->d_tc.get()), explain(t2)); tc_ei->d_tc_mem_exp.insert(t2[0], exp); sendInferTC(tc_ei, t2[0], exp); } @@ -1386,26 +1375,27 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P seen.insert(RelsUtils::nthElementOfTuple(mem, 0)); sendInferInTC(tc_ei, RelsUtils::nthElementOfTuple(mem, 0), RelsUtils::nthElementOfTuple(mem, 1), seen, exp); sendInferOutTC(tc_ei, RelsUtils::nthElementOfTuple(mem, 0), RelsUtils::nthElementOfTuple(mem, 1), seen, exp); + Trace("rels-std") << "[sets-rels] *** done with sendInferTC member = " << mem << " with explanation = " << exp << std::endl; } void TheorySetsRels::sendInferInTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set seen, Node exp) { for(NodeListMap::iterator nl_it = tc_ei->d_in.begin(); nl_it != tc_ei->d_in.end(); nl_it++) { - if(areEqual((*nl_it).first, fst)) { - for(NodeList::const_iterator itr = (*nl_it).second->begin(); itr != (*nl_it).second->end(); itr++) { - Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, snd); + if((*nl_it).first == fst) { + for(NodeList::const_iterator in_itr = (*nl_it).second->begin(); in_itr != (*nl_it).second->end(); in_itr++) { + Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, snd); if(!tc_ei->d_mem.contains(pair)) { Node reason = ((*nl_it).first == fst) ? - Rewriter::rewrite(AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst)))): - Rewriter::rewrite(AND(EQUAL((*nl_it).first, fst), AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst))))); + Rewriter::rewrite(AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, fst)))): + Rewriter::rewrite(AND(EQUAL((*nl_it).first, fst), AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, fst))))); Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(pair, tc_ei->d_tc.get())); d_pending_merge.push_back(tc_lemma); d_lemma.insert(tc_lemma); tc_ei->d_mem.insert(pair); tc_ei->d_tc_mem_exp.insert(pair, reason); } - if(seen.find(*itr) == seen.end()) { - seen.insert(*itr); - sendInferInTC(tc_ei, *itr, snd, seen, AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst)))); + if(seen.find(*in_itr) == seen.end()) { + seen.insert(*in_itr); + sendInferInTC(tc_ei, *in_itr, snd, seen, AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, fst)))); } } } @@ -1414,7 +1404,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::sendInferOutTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set seen, Node exp) { for(NodeListMap::iterator nl_it = tc_ei->d_out.begin(); nl_it != tc_ei->d_out.end(); nl_it++) { - if(areEqual((*nl_it).first, snd)) { + if((*nl_it).first == snd) { for(NodeList::const_iterator itr = (*nl_it).second->begin(); itr != (*nl_it).second->end(); itr++) { Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), fst, *itr); if(!tc_ei->d_mem.contains(pair)) { diff --git a/src/theory/sets/theory_sets_rels.h b/src/theory/sets/theory_sets_rels.h index 0d24c65b3..faee651b7 100644 --- a/src/theory/sets/theory_sets_rels.h +++ b/src/theory/sets/theory_sets_rels.h @@ -45,9 +45,11 @@ public: class TheorySetsRels { typedef context::CDChunkList NodeList; + typedef context::CDChunkList IdList; typedef context::CDHashSet NodeSet; typedef context::CDHashMap NodeBoolMap; typedef context::CDHashMap NodeListMap; + typedef context::CDHashMap IdListMap; typedef context::CDHashMap NodeSetMap; typedef context::CDHashMap NodeMap; @@ -76,6 +78,7 @@ private: public: EqcInfo( context::Context* c ); ~EqcInfo(){} + int counter; NodeSet d_mem; NodeSet d_not_mem; NodeListMap d_in; @@ -85,6 +88,10 @@ private: context::CDO< Node > d_pt; context::CDO< Node > d_join; context::CDO< Node > d_tc; + IdListMap d_id_in; + IdListMap d_id_out; + std::hash_map d_id_node; + std::hash_map d_node_id; }; /** has eqc info */ @@ -168,7 +175,7 @@ private: void finalizeTCInfer(); void inferTC( Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >& ); void inferTC( Node, Node, std::map< Node, std::hash_set< Node, NodeHashFunction > >&, - Node, Node, std::hash_set< Node, NodeHashFunction >&, bool first_round = false); + Node, Node, std::hash_set< Node, NodeHashFunction >&); Node explain(Node); @@ -184,7 +191,7 @@ private: // Helper functions inline Node getReason(Node tc_rep, Node tc_term, Node tc_r_rep, Node tc_r); inline Node constructPair(Node tc_rep, Node a, Node b); - Node findMemExp(Node r, Node tuple); + Node findMemExp(Node r, Node pair); bool safeAddToMap( std::map< Node, std::vector >&, Node, Node ); void addToMap( std::map< Node, std::vector >&, Node, Node ); bool hasMember( Node, Node ); -- 2.30.2