From 6d22060b0ee92433bb65bf7e238619039bd7d9ed Mon Sep 17 00:00:00 2001 From: PaulMeng Date: Sat, 25 Jun 2016 17:55:09 -0400 Subject: [PATCH] reimplemented std effort for TC --- src/theory/sets/theory_sets_rels.cpp | 774 +++++++++++++++------------ src/theory/sets/theory_sets_rels.h | 123 +++-- 2 files changed, 497 insertions(+), 400 deletions(-) diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index ccb917d5f..e339740a3 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -36,10 +36,12 @@ namespace CVC4 { namespace theory { namespace sets { -typedef std::map > >::iterator TERM_IT; -typedef std::map > >::iterator TC_IT; -typedef std::map >::iterator MEM_IT; -typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_PAIR_IT; +typedef std::map > >::iterator TERM_IT; +typedef std::map > >::iterator TC_IT; +typedef std::map >::iterator MEM_IT; +typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_PAIR_IT; + +int TheorySetsRels::EqcInfo::counter = 0; void TheorySetsRels::check(Theory::Effort level) { Trace("rels") << "\n[sets-rels] ******************************* Start the relational solver *******************************\n" << std::endl; @@ -64,8 +66,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P // No relational terms found with rel_rep as its representative // But TRANSPOSE(rel_rep) may occur in the context if(d_terms_cache.find(rel_rep) == d_terms_cache.end()) { - Node tp_rel = NodeManager::currentNM()->mkNode(kind::TRANSPOSE, rel_rep); + Node tp_rel = NodeManager::currentNM()->mkNode(kind::TRANSPOSE, rel_rep); Node tp_rel_rep = getRepresentative(tp_rel); + if(d_terms_cache.find(tp_rel_rep) != d_terms_cache.end()) { for(unsigned int i = 0; i < m_it->second.size(); i++) { // Lazily apply transpose-occur rule. @@ -75,8 +78,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } } else { for(unsigned int i = 0; i < m_it->second.size(); i++) { - Node exp = d_membership_exp_cache[rel_rep][i]; - std::map > kind_terms = d_terms_cache[rel_rep]; + Node exp = d_membership_exp_cache[rel_rep][i]; + std::map > kind_terms = d_terms_cache[rel_rep]; if(kind_terms.find(kind::TRANSPOSE) != kind_terms.end()) { std::vector tp_terms = kind_terms[kind::TRANSPOSE]; // exp is a membership term and tp_terms contains all @@ -120,12 +123,13 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Trace("rels-debug") << "[sets-rels] Start collecting relational terms..." << std::endl; eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( d_eqEngine ); while( !eqcs_i.isFinished() ){ - Node r = (*eqcs_i); - eq::EqClassIterator eqc_i = eq::EqClassIterator( r, d_eqEngine ); + Node r = (*eqcs_i); + eq::EqClassIterator eqc_i = eq::EqClassIterator( r, d_eqEngine ); Trace("rels-ee") << "[sets-rels-ee] term representative: " << r << std::endl; while( !eqc_i.isFinished() ){ Node n = (*eqc_i); Trace("rels-ee") << " term : " << n << std::endl; + if(getRepresentative(r) == getRepresentative(d_trueNode) || getRepresentative(r) == getRepresentative(d_falseNode)) { // collect membership info @@ -137,8 +141,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P reduceTupleVar(n); } else { if(safeAddToMap(d_membership_constraints_cache, rel_rep, tup_rep)) { - bool true_eq = areEqual(r, d_trueNode); - Node reason = true_eq ? n : n.negate(); + bool true_eq = areEqual(r, d_trueNode); + Node reason = true_eq ? n : n.negate(); + addToMap(d_membership_exp_cache, rel_rep, reason); Trace("rels-mem") << "[******] exp: " << reason << " for " << rel_rep << std::endl; if(true_eq) { @@ -153,13 +158,13 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P n.getKind() == kind::JOIN || n.getKind() == kind::PRODUCT || n.getKind() == kind::TCLOSURE ) { - std::map > rel_terms; - std::vector terms; + std::map > rel_terms; + std::vector terms; // No r record is found if( d_terms_cache.find(r) == d_terms_cache.end() ) { terms.push_back(n); - rel_terms[n.getKind()] = terms; - d_terms_cache[r] = rel_terms; + rel_terms[n.getKind()] = terms; + d_terms_cache[r] = rel_terms; } else { // No n's kind record is found if( d_terms_cache[r].find(n.getKind()) == d_terms_cache[r].end() ) { @@ -174,6 +179,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } else if(n.getType().isTuple() && !n.isConst() && !n.isVar()) { for(unsigned int i = 0; i < n.getType().getTupleLength(); i++) { Node element = RelsUtils::nthElementOfTuple(n, i); + if(!element.isConst()) { makeSharedTerm(element); } @@ -206,12 +212,14 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::buildTCGraph(Node tc_r_rep, Node tc_rep, Node tc_term) { std::map< Node, std::hash_set< Node, NodeHashFunction > > tc_graph; MEM_IT mem_it = d_membership_db.find(tc_r_rep); + if(mem_it != d_membership_db.end()) { for(std::vector::iterator pair_it = mem_it->second.begin(); pair_it != mem_it->second.end(); pair_it++) { - Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(*pair_it, 0)); - Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(*pair_it, 1)); - TC_PAIR_IT pair_set_it = tc_graph.find(fst_rep); + Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(*pair_it, 0)); + Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(*pair_it, 1)); + TC_PAIR_IT pair_set_it = tc_graph.find(fst_rep); + if( pair_set_it != tc_graph.end() ) { pair_set_it->second.insert(snd_rep); } else { @@ -231,11 +239,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::applyTCRule(Node exp, Node tc_term) { Trace("rels-debug") << "\n[sets-rels] *********** Applying TRANSITIVE CLOSURE rule on " << tc_term << " with explanation " << exp << std::endl; - bool polarity = exp.getKind() != kind::NOT; - Node atom = polarity ? exp : exp[0]; - Node tup_rep = getRepresentative(atom[0]); - Node tc_rep = getRepresentative(tc_term); - Node tc_r_rep = getRepresentative(tc_term[0]); + bool polarity = exp.getKind() != kind::NOT; + Node atom = polarity ? exp : exp[0]; + Node tup_rep = getRepresentative(atom[0]); + Node tc_rep = getRepresentative(tc_term); + Node tc_r_rep = getRepresentative(tc_term[0]); // build the TC graph for tc_rep if it was not created before if( d_rel_nodes.find(tc_rep) == d_rel_nodes.end() ) { @@ -243,13 +251,17 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P buildTCGraph(tc_r_rep, tc_rep, tc_term); d_rel_nodes.insert(tc_rep); } + // insert tup_rep in the tc_graph if it is not in the graph already TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep); + if(polarity) { Node fst_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 0)); Node snd_rep = getRepresentative(RelsUtils::nthElementOfTuple(tup_rep, 1)); + if(tc_graph_it != d_membership_tc_cache.end()) { TC_PAIR_IT pair_set_it = tc_graph_it->second.find(fst_rep); + if(pair_set_it != tc_graph_it->second.end()) { pair_set_it->second.insert(snd_rep); } else { @@ -257,51 +269,26 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P pair_set.insert(snd_rep); tc_graph_it->second[fst_rep] = pair_set; } - Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); - std::map< Node, Node >::iterator exp_it = d_membership_tc_exp_cache.find(tc_rep); + + Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); + std::map< Node, Node >::iterator exp_it = d_membership_tc_exp_cache.find(tc_rep); + if(!reason.isNull() && exp_it->second != reason) { d_membership_tc_exp_cache[tc_rep] = Rewriter::rewrite(AND(exp_it->second, reason)); } } else { - std::map< Node, std::hash_set< Node, NodeHashFunction > > pair_set; - std::hash_set< Node, NodeHashFunction > snd_set; + std::map< Node, std::hash_set< Node, NodeHashFunction > > pair_set; + std::hash_set< Node, NodeHashFunction > snd_set; + snd_set.insert(snd_rep); - pair_set[fst_rep] = snd_set; - d_membership_tc_cache[tc_rep] = pair_set; + pair_set[fst_rep] = snd_set; + d_membership_tc_cache[tc_rep] = pair_set; Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); + if(!reason.isNull()) { d_membership_tc_exp_cache[tc_rep] = reason; } } - // if(!d_tc_saver.contains(exp) && - // atom[0][0].getKind() != kind::SKOLEM && - // atom[0][1].getKind() != kind::SKOLEM) { - - // TypeNode k_type = tup_rep.getType().getTupleTypes()[1]; - // Node k_0 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type); - // Node k_1 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type); - // Node k_2 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type); - // Node k_3 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type); - // Node fact = NodeManager::currentNM()->mkNode( kind::AND, MEMBER(RelsUtils::constructPair(tc_rep, tup_rep[0], k_0), tc_r_rep), - // MEMBER(RelsUtils::constructPair(tc_rep, k_0, k_1), tc_r_rep), - // MEMBER(RelsUtils::constructPair(tc_rep, k_1, k_2), tc_r_rep), - // MEMBER(RelsUtils::constructPair(tc_rep, k_2, k_3), tc_r_rep), - // MEMBER(RelsUtils::constructPair(tc_rep, k_3, tup_rep[1]), tc_r_rep) ); - // Node reason = exp; - // if(tc_rep != tc_term) { - // reason = AND(reason, explain(EQUAL(tc_rep, tc_term))); - // } - // if(tc_r_rep != tc_term[0]) { - // reason = AND(reason, explain(EQUAL(tc_r_rep, tc_term[0]))); - // } - - // makeSharedTerm(k_0); - // makeSharedTerm(k_1); - // makeSharedTerm(k_2); - // makeSharedTerm(k_3); - // sendLemma( fact, reason, "tc-decompose" ); - // d_tc_saver.insert(exp); - // } // check if tup_rep already exists in TC graph for conflict } else { if(tc_graph_it != d_membership_tc_cache.end()) { @@ -314,9 +301,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::checkTCGraphForConflict (Node atom, Node tc_rep, Node exp, Node a, Node b, std::map< Node, std::hash_set< Node, NodeHashFunction > >& pair_set) { TC_PAIR_IT pair_set_it = pair_set.find(a); + if(pair_set_it != pair_set.end()) { if(pair_set_it->second.find(b) != pair_set_it->second.end()) { Node reason = AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, b))); + if(atom[1] != tc_rep) { reason = AND(exp, explain(EQUAL(atom[1], tc_rep))); } @@ -328,6 +317,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P // d_sets_theory.d_out->conflict(AND(reason.negate(), atom)); } else { std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin(); + while(set_it != pair_set_it->second.end()) { // need to check if *set_it has been looked already if(!areEqual(*set_it, a)) { @@ -352,24 +342,25 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::applyProductRule(Node exp, Node product_term) { Trace("rels-debug") << "\n[sets-rels] *********** Applying PRODUCT rule " << std::endl; + if(d_rel_nodes.find(product_term) == d_rel_nodes.end()) { computeRels(product_term); d_rel_nodes.insert(product_term); } - bool polarity = exp.getKind() != kind::NOT; - Node atom = polarity ? exp : exp[0]; - Node r1_rep = getRepresentative(product_term[0]); - Node r2_rep = getRepresentative(product_term[1]); + bool polarity = exp.getKind() != kind::NOT; + Node atom = polarity ? exp : exp[0]; + Node r1_rep = getRepresentative(product_term[0]); + Node r2_rep = getRepresentative(product_term[1]); + Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-SPLIT rule on term: " << product_term << " with explanation: " << exp << std::endl; - std::vector r1_element; - std::vector r2_element; - - NodeManager *nm = NodeManager::currentNM(); - Datatype dt = r1_rep.getType().getSetElementType().getDatatype(); - unsigned int i = 0; - unsigned int s1_len = r1_rep.getType().getSetElementType().getTupleLength(); - unsigned int tup_len = product_term.getType().getSetElementType().getTupleLength(); + std::vector r1_element; + std::vector r2_element; + NodeManager *nm = NodeManager::currentNM(); + Datatype dt = r1_rep.getType().getSetElementType().getDatatype(); + unsigned int i = 0; + unsigned int s1_len = r1_rep.getType().getSetElementType().getTupleLength(); + unsigned int tup_len = product_term.getType().getSetElementType().getTupleLength(); r1_element.push_back(Node::fromExpr(dt[0].getConstructor())); for(; i < s1_len; ++i) { @@ -384,12 +375,12 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Node fact_1; Node fact_2; - Node reason_1 = exp; - Node reason_2 = exp; - Node t1 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element); - Node t1_rep = getRepresentative(t1); - Node t2 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element); - Node t2_rep = getRepresentative(t2); + Node reason_1 = exp; + Node reason_2 = exp; + Node t1 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element); + Node t1_rep = getRepresentative(t1); + Node t2 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element); + Node t2_rep = getRepresentative(t2); fact_1 = MEMBER( t1, r1_rep ); fact_2 = MEMBER( t2, r2_rep ); @@ -408,14 +399,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(polarity) { sendInfer(fact_1, reason_1, "product-split"); sendInfer(fact_2, reason_2, "product-split"); - // if(safeAddToMap(d_membership_db, r1_rep, t1_rep)) { - // addToMap(d_membership_exp_db, r1_rep, reason_1); - // } - // if(safeAddToMap(d_membership_db, r2_rep, t2_rep)) { - // addToMap(d_membership_exp_db, r2_rep, reason_2); - - // } - } else { sendInfer(fact_1.negate(), reason_1, "product-split"); sendInfer(fact_2.negate(), reason_2, "product-split"); @@ -441,10 +424,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P computeRels(join_term); d_rel_nodes.insert(join_term); } - bool polarity = exp.getKind() != kind::NOT; - Node atom = polarity ? exp : exp[0]; - Node r1_rep = getRepresentative(join_term[0]); - Node r2_rep = getRepresentative(join_term[1]); + + bool polarity = exp.getKind() != kind::NOT; + Node atom = polarity ? exp : exp[0]; + Node r1_rep = getRepresentative(join_term[0]); + Node r2_rep = getRepresentative(join_term[1]); if(polarity) { @@ -453,23 +437,24 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P std::vector r1_element; std::vector r2_element; - NodeManager *nm = NodeManager::currentNM(); - TypeNode shared_type = r2_rep.getType().getSetElementType().getTupleTypes()[0]; - Node shared_x = nm->mkSkolem("sde_", shared_type); - Datatype dt = r1_rep.getType().getSetElementType().getDatatype(); - unsigned int i = 0; - unsigned int s1_len = r1_rep.getType().getSetElementType().getTupleLength(); - unsigned int tup_len = join_term.getType().getSetElementType().getTupleLength(); + NodeManager *nm = NodeManager::currentNM(); + TypeNode shared_type = r2_rep.getType().getSetElementType().getTupleTypes()[0]; + Node shared_x = nm->mkSkolem("sde_", shared_type); + Datatype dt = r1_rep.getType().getSetElementType().getDatatype(); + unsigned int i = 0; + unsigned int s1_len = r1_rep.getType().getSetElementType().getTupleLength(); + unsigned int tup_len = join_term.getType().getSetElementType().getTupleLength(); r1_element.push_back(Node::fromExpr(dt[0].getConstructor())); for(; i < s1_len-1; ++i) { r1_element.push_back(RelsUtils::nthElementOfTuple(atom[0], i)); } - r1_element.push_back(shared_x); + r1_element.push_back(shared_x); dt = r2_rep.getType().getSetElementType().getDatatype(); r2_element.push_back(Node::fromExpr(dt[0].getConstructor())); r2_element.push_back(shared_x); + for(; i < tup_len; ++i) { r2_element.push_back(RelsUtils::nthElementOfTuple(atom[0], i)); } @@ -478,6 +463,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Node t2 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element); computeTupleReps(t1); computeTupleReps(t2); + std::vector elements = d_membership_trie[r1_rep].findTerms(d_tuple_reps[t1]); for(unsigned int j = 0; j < elements.size(); j++) { @@ -490,8 +476,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } Node fact; - Node reason = atom[1] == join_term ? exp : AND(exp, explain(EQUAL(atom[1], join_term))); - Node reasons = reason; + Node reason = atom[1] == join_term ? exp : AND(exp, explain(EQUAL(atom[1], join_term))); + Node reasons = reason; fact = MEMBER(t1, r1_rep); if(r1_rep != join_term[0]) { @@ -499,8 +485,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } sendInfer(fact, reasons, "join-split"); - reasons = reason; - fact = MEMBER(t2, r2_rep); + reasons = reason; + fact = MEMBER(t2, r2_rep); + if(r2_rep != join_term[1]) { reasons = Rewriter::rewrite(AND(reason, explain(EQUAL(r2_rep, join_term[1])))); } @@ -531,9 +518,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P */ void TheorySetsRels::applyTransposeRule(Node exp, Node tp_term, bool tp_occur) { Trace("rels-debug") << "\n[sets-rels] *********** Applying TRANSPOSE rule " << std::endl; - bool polarity = exp.getKind() != kind::NOT; - Node atom = polarity ? exp : exp[0]; - Node reversedTuple = getRepresentative(RelsUtils::reverseTuple(atom[0])); + bool polarity = exp.getKind() != kind::NOT; + Node atom = polarity ? exp : exp[0]; + Node reversedTuple = getRepresentative(RelsUtils::reverseTuple(atom[0])); if(tp_occur) { Trace("rels-debug") << "\n[sets-rels] Apply TRANSPOSE-OCCUR rule on term: " << tp_term @@ -543,9 +530,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P return; } - Node tp_t0_rep = getRepresentative(tp_term[0]); - Node reason = atom[1] == tp_term ? exp : Rewriter::rewrite(AND(exp, EQUAL(atom[1], tp_term))); - Node fact = MEMBER(reversedTuple, tp_t0_rep); + Node tp_t0_rep = getRepresentative(tp_term[0]); + Node reason = atom[1] == tp_term ? exp : Rewriter::rewrite(AND(exp, EQUAL(atom[1], tp_term))); + Node fact = MEMBER(reversedTuple, tp_t0_rep); if(!polarity) { // tp_term is a nested term and we eagerly compute its subterms' members @@ -575,8 +562,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::inferTC( Node exp, Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph, Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& traversed ) { - Node pair = constructPair(tc_rep, start_node, cur_node); - std::map >::iterator mem_it = d_membership_db.find(tc_rep); + Node pair = constructPair(tc_rep, start_node, cur_node); + std::map >::iterator mem_it = d_membership_db.find(tc_rep); + if(mem_it != d_membership_db.end()) { if(std::find(mem_it->second.begin(), mem_it->second.end(), pair) == mem_it->second.end()) { sendLemma( MEMBER(pair, tc_rep), Rewriter::rewrite(exp), "Transitivity" ); @@ -586,9 +574,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(traversed.find(cur_node) != traversed.end()) { return; } + traversed.insert(cur_node); - Node reason = exp; + Node reason = exp; std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator cur_set = tc_graph.find(cur_node); + if(cur_set != tc_graph.end()) { for(std::hash_set< Node, NodeHashFunction >::iterator set_it = cur_set->second.begin(); set_it != cur_set->second.end(); set_it++) { @@ -612,9 +602,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P pair_set_it != tc_graph.end(); pair_set_it++) { for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin(); set_it != pair_set_it->second.end(); set_it++) { - std::hash_set elements; - Node pair = constructPair(tc_rep, pair_set_it->first, *set_it); - Node exp = findMemExp(tc_rep, pair); + std::hash_set elements; + Node pair = constructPair(tc_rep, pair_set_it->first, *set_it); + Node exp = findMemExp(tc_rep, pair); + if(d_membership_tc_exp_cache.find(tc_rep) != d_membership_tc_exp_cache.end()) { exp = AND(d_membership_tc_exp_cache[tc_rep], exp); } @@ -674,15 +665,16 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(d_membership_db.find(getRepresentative(n[0])) == d_membership_db.end()) return; - Node n_rep = getRepresentative(n); - Node n0_rep = getRepresentative(n[0]); - std::vector tuples = d_membership_db[n0_rep]; - std::vector exps = d_membership_exp_db[n0_rep]; + Node n_rep = getRepresentative(n); + Node n0_rep = getRepresentative(n[0]); + std::vector tuples = d_membership_db[n0_rep]; + std::vector exps = d_membership_exp_db[n0_rep]; Assert(tuples.size() == exps.size()); for(unsigned int i = 0; i < tuples.size(); i++) { - Node reason = exps[i][1] == n0_rep ? exps[i] : AND(exps[i], EQUAL(exps[i][1], n0_rep)); - Node rev_tup = getRepresentative(RelsUtils::reverseTuple(tuples[i])); - Node fact = MEMBER(rev_tup, n_rep); + Node reason = exps[i][1] == n0_rep ? exps[i] : AND(exps[i], EQUAL(exps[i][1], n0_rep)); + Node rev_tup = getRepresentative(RelsUtils::reverseTuple(tuples[i])); + Node fact = MEMBER(rev_tup, n_rep); + if(holds(fact)) { Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds! Skip..." << std::endl; } else { @@ -697,11 +689,12 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P * */ void TheorySetsRels::composeTupleMemForRels( Node n ) { - Node r1 = n[0]; - Node r2 = n[1]; - Node r1_rep = getRepresentative(r1); - Node r2_rep = getRepresentative(r2); - NodeManager* nm = NodeManager::currentNM(); + Node r1 = n[0]; + Node r2 = n[1]; + Node r1_rep = getRepresentative(r1); + Node r2_rep = getRepresentative(r2); + NodeManager* nm = NodeManager::currentNM(); + Trace("rels-debug") << "[sets-rels] start composing tuples in relations " << r1 << " and " << r2 << std::endl; @@ -711,23 +704,24 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P std::vector new_tups; std::vector new_exps; - std::vector r1_elements = d_membership_db[r1_rep]; - std::vector r2_elements = d_membership_db[r2_rep]; - std::vector r1_exps = d_membership_exp_db[r1_rep]; - std::vector r2_exps = d_membership_exp_db[r2_rep]; - Node new_rel = n.getKind() == kind::JOIN ? nm->mkNode(kind::JOIN, r1_rep, r2_rep) + std::vector r1_elements = d_membership_db[r1_rep]; + std::vector r2_elements = d_membership_db[r2_rep]; + std::vector r1_exps = d_membership_exp_db[r1_rep]; + std::vector r2_exps = d_membership_exp_db[r2_rep]; + + Node new_rel = n.getKind() == kind::JOIN ? nm->mkNode(kind::JOIN, r1_rep, r2_rep) : nm->mkNode(kind::PRODUCT, r1_rep, r2_rep); - Node new_rel_rep = getRepresentative(new_rel); + Node new_rel_rep = getRepresentative(new_rel); unsigned int t1_len = r1_elements.front().getType().getTupleLength(); unsigned int t2_len = r2_elements.front().getType().getTupleLength(); for(unsigned int i = 0; i < r1_elements.size(); i++) { for(unsigned int j = 0; j < r2_elements.size(); j++) { - std::vector composed_tuple; - TypeNode tn = n.getType().getSetElementType(); - Node r1_rmost = RelsUtils::nthElementOfTuple(r1_elements[i], t1_len-1); - Node r2_lmost = RelsUtils::nthElementOfTuple(r2_elements[j], 0); + std::vector composed_tuple; + TypeNode tn = n.getType().getSetElementType(); + Node r1_rmost = RelsUtils::nthElementOfTuple(r1_elements[i], t1_len-1); + Node r2_lmost = RelsUtils::nthElementOfTuple(r2_elements[j], 0); composed_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor())); if((areEqual(r1_rmost, r2_lmost) && n.getKind() == kind::JOIN) || @@ -735,6 +729,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P bool isProduct = n.getKind() == kind::PRODUCT; unsigned int k = 0; unsigned int l = 1; + for(; k < t1_len - 1; ++k) { composed_tuple.push_back(RelsUtils::nthElementOfTuple(r1_elements[i], k)); } @@ -745,8 +740,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P for(; l < t2_len; ++l) { composed_tuple.push_back(RelsUtils::nthElementOfTuple(r2_elements[j], l)); } - Node composed_tuple_rep = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, composed_tuple)); - Node fact = MEMBER(composed_tuple_rep, new_rel_rep); + Node composed_tuple_rep = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, composed_tuple)); + Node fact = MEMBER(composed_tuple_rep, new_rel_rep); + if(holds(fact)) { Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds! Skip..." << std::endl; } else { @@ -828,12 +824,15 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P d_lemma_cache.clear(); d_membership_trie.clear(); d_tuple_reps.clear(); + d_id_node.clear(); + d_node_id.clear(); } void TheorySetsRels::sendSplit(Node a, Node b, const char * c) { - Node eq = a.eqNode( b ); - Node neq = NOT( eq ); - Node lemma_or = OR( eq, neq ); + Node eq = a.eqNode( b ); + Node neq = NOT( eq ); + Node lemma_or = OR( eq, neq ); + Trace("rels-lemma") << "[sets-lemma] Lemma " << c << " SPLIT : " << lemma_or << std::endl; d_lemma_cache.push_back( lemma_or ); } @@ -855,18 +854,18 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::doPendingFacts() { std::map::iterator map_it = d_pending_facts.begin(); while( !(*d_conflict) && map_it != d_pending_facts.end()) { - Node fact = map_it->first; - Node exp = d_pending_facts[ fact ]; + Node exp = d_pending_facts[ fact ]; + if(fact.getKind() == kind::AND) { for(size_t j=0; j >::iterator TC_P void TheorySetsRels::doPendingSplitFacts() { std::map::iterator map_it = d_pending_split_facts.begin(); while( !(*d_conflict) && map_it != d_pending_split_facts.end()) { - Node fact = map_it->first; - Node exp = d_pending_split_facts[ fact ]; + Node exp = d_pending_split_facts[ fact ]; + if(fact.getKind() == kind::AND) { for(size_t j=0; j >::iterator TC_P Trace("rels-reason") << "[sets-rels] getReason(" << tc_rep << ", " << tc_term << ", " << tc_r_rep << ", " << tc_r << std::endl; if(tc_term != tc_rep) { Node reason = explain(EQUAL(tc_term, tc_rep)); + if(tc_term[0] != tc_r_rep) { return AND(reason, explain(EQUAL(tc_term[0], tc_r_rep))); } @@ -987,14 +987,15 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P // tuple might be a member of tc_rep; or it might be a member of rels or tc_terms such that // tc_terms are transitive closure of rels and are modulo equal to tc_rep Node TheorySetsRels::findMemExp(Node tc_rep, Node pair) { - Node fst = RelsUtils::nthElementOfTuple(pair, 0); - Node snd = RelsUtils::nthElementOfTuple(pair, 1); Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_rep = " << tc_rep << ", pair = " << pair << ")" << std::endl; - std::vector tc_terms = d_terms_cache.find(tc_rep)->second[kind::TCLOSURE]; + Node fst = RelsUtils::nthElementOfTuple(pair, 0); + Node snd = RelsUtils::nthElementOfTuple(pair, 1); + std::vector tc_terms = d_terms_cache.find(tc_rep)->second[kind::TCLOSURE]; + Assert(tc_terms.size() > 0); for(unsigned int i = 0; i < tc_terms.size(); i++) { - Node tc_term = tc_terms[i]; - Node tc_r_rep = getRepresentative(tc_term[0]); + Node tc_term = tc_terms[i]; + Node tc_r_rep = getRepresentative(tc_term[0]); Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << tc_r_rep << ", pair = " << pair << ")" << std::endl; std::map< Node, std::vector< Node > >::iterator tc_r_mems = d_membership_db.find(tc_r_rep); @@ -1002,8 +1003,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P for(unsigned int i = 0; i < tc_r_mems->second.size(); i++) { Node fst_mem = RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 0); Node snd_mem = RelsUtils::nthElementOfTuple(tc_r_mems->second[i], 1); + if(areEqual(fst_mem, fst) && areEqual(snd_mem, snd)) { Node exp = MEMBER(tc_r_mems->second[i], tc_r_mems->first); + if(tc_r_rep != tc_term[0]) { exp = explain(EQUAL(tc_r_rep, tc_term[0])); } @@ -1024,12 +1027,14 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } } - Node tc_term_rep = getRepresentative(tc_terms[i]); - std::map< Node, std::vector< Node > >::iterator tc_t_mems = d_membership_db.find(tc_term_rep); + Node tc_term_rep = getRepresentative(tc_terms[i]); + std::map< Node, std::vector< Node > >::iterator tc_t_mems = d_membership_db.find(tc_term_rep); + if(tc_t_mems != d_membership_db.end()) { for(unsigned int j = 0; j < tc_t_mems->second.size(); j++) { Node fst_mem = RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 0); Node snd_mem = RelsUtils::nthElementOfTuple(tc_t_mems->second[j], 1); + if(areEqual(fst_mem, fst) && areEqual(snd_mem, snd)) { Node exp = MEMBER(tc_t_mems->second[j], tc_t_mems->first); if(tc_rep != tc_terms[i]) { @@ -1052,14 +1057,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } } } -// std::map< Node, std::vector< Node > >::iterator tc_mems = d_membership_db.find(tc_rep); -// if(tc_mems != d_membership_db.end()) { -// for(unsigned int i = 0; i < tc_mems->second.size(); i++) { -// if(tc_mems->second[i] == tuple) { -// return explain(d_membership_exp_db[tc_rep][i]); -// } -// } -// } return Node::null(); } @@ -1087,9 +1084,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P bool TheorySetsRels::holds(Node node) { Trace("rels-check") << " [sets-rels] Check if node = " << node << " already holds " << std::endl; - bool polarity = node.getKind() != kind::NOT; - Node atom = polarity ? node : node[0]; - Node polarity_atom = polarity ? d_trueNode : d_falseNode; + bool polarity = node.getKind() != kind::NOT; + Node atom = polarity ? node : node[0]; + Node polarity_atom = polarity ? d_trueNode : d_falseNode; + if(d_eqEngine->hasTerm(atom)) { Trace("rels-check") << " [sets-rels] node = " << node << " is in the EE " << std::endl; return areEqual(atom, polarity_atom); @@ -1129,6 +1127,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Trace("rels-debug") << "Reduce tuple var: " << n[0] << " to concrete one " << std::endl; std::vector tuple_elements; tuple_elements.push_back(Node::fromExpr((n[0].getType().getDatatype())[0].getConstructor())); + for(unsigned int i = 0; i < n[0].getType().getTupleLength(); i++) { Node element = RelsUtils::nthElementOfTuple(n[0], i); makeSharedTerm(element); @@ -1147,6 +1146,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P eq::EqualityEngine* eq, context::CDO* conflict, TheorySets& d_set): + d_eqEngine(eq), d_sets_theory(d_set), d_trueNode(NodeManager::currentNM()->mkConst(true)), d_falseNode(NodeManager::currentNM()->mkConst(false)), @@ -1156,7 +1156,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P d_lemma(u), d_shared_terms(u), d_tc_saver(u), - d_eqEngine(eq), d_conflict(conflict) { d_eqEngine->addFunctionKind(kind::PRODUCT); @@ -1168,8 +1167,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P TheorySetsRels::~TheorySetsRels() {} std::vector TupleTrie::findTerms( std::vector< Node >& reps, int argIndex ) { - std::vector nodes; - std::map< Node, TupleTrie >::iterator it; + std::vector nodes; + std::map< Node, TupleTrie >::iterator it; + if( argIndex==(int)reps.size()-1 ){ if(reps[argIndex].getKind() == kind::SKOLEM) { it = d_data.begin(); @@ -1231,10 +1231,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Node TheorySetsRels::explain(Node literal) { Trace("rels-exp") << "[sets-rels] TheorySetsRels::explain(" << literal << ")"<< std::endl; - - bool polarity = literal.getKind() != kind::NOT; - TNode atom = polarity ? literal : literal[0]; - std::vector assumptions; + std::vector assumptions; + bool polarity = literal.getKind() != kind::NOT; + TNode atom = polarity ? literal : literal[0]; if(atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) { d_eqEngine->explainEquality(atom[0], atom[1], polarity, assumptions); @@ -1253,7 +1252,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } TheorySetsRels::EqcInfo::EqcInfo( context::Context* c ) : - counter(0), d_mem(c), d_not_mem(c), d_in(c), d_out(c), d_tc_mem_exp(c), + d_mem(c), d_not_mem(c), d_in(c), d_out(c), d_mem_exp(c), d_tp(c), d_pt(c), d_join(c), d_tc(c), d_id_in(c), d_id_out(c) {} void TheorySetsRels::eqNotifyNewClass( Node n ) { @@ -1265,36 +1264,222 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P getOrMakeEqcInfo( n, true ); } } - void TheorySetsRels::addTCMem(EqcInfo* tc_ei, Node mem) { - Node fst = RelsUtils::nthElementOfTuple(mem, 0); - Node snd = RelsUtils::nthElementOfTuple(mem, 1); - - NodeList* in_lst; - NodeList* out_lst; - NodeListMap::iterator tc_in_mem_it = tc_ei->d_in.find(snd); - if(tc_in_mem_it == tc_ei->d_in.end()) { - in_lst = new(d_sets_theory.getSatContext()->getCMM()) NodeList( true, d_sets_theory.getSatContext(), false, + + // Create a integer id for tuple element + int TheorySetsRels::getOrMakeElementRepId(EqcInfo* ei, Node e_rep) { + Trace("rels-std") << "[sets-rels] getOrMakeElementRepId:" << " e_rep = " << e_rep << std::endl; + std::map::iterator nid_it = d_node_id.find(e_rep); + + if( nid_it == d_node_id.end() ) { + Trace("rels-std") << "[sets-rels] getOrMakeElementRepId:" << " *** 0"<< std::endl; + if(d_eqEngine->hasTerm(e_rep)) { + // it is possible that e's rep changes at this moment, thus we need to know the eqc of e's previous rep + eq::EqClassIterator rep_eqc_i = eq::EqClassIterator( e_rep, d_eqEngine ); + Trace("rels-std") << "[sets-rels] getOrMakeElementRepId:" << " *** 1"<< std::endl; + while(!rep_eqc_i.isFinished()) { + std::map::iterator id_it = d_node_id.find(*rep_eqc_i); + + if( id_it != d_node_id.end() ) { + d_id_node[id_it->second] = e_rep; + d_node_id[e_rep] = id_it->second; + return id_it->second; + } + rep_eqc_i++; + } + } + d_id_node[ei->counter] = e_rep; + d_node_id[e_rep] = ei->counter; + ei->counter++; + return ei->counter-1; + } + Trace("rels-std") << "[sets-rels] finish getOrMakeElementRepId:" << " e_rep = " << e_rep << std::endl; + return nid_it->second; + } + + bool TheorySetsRels::insertIntoIdList(IdList& idList, int mem) { + IdList::const_iterator idListIt = idList.begin(); + while(idListIt != idList.end()) { + if(*idListIt == mem) { + return false; + } + idListIt++; + } + idList.push_back(mem); + return true; + } + + void TheorySetsRels::addTCMemAndSendInfer(EqcInfo* tc_ei, Node membership, Node exp, bool fromRel) { + Trace("rels-std") << "[sets-rels] addTCMemAndSendInfer:" << " membership = " << membership << " from a relation? " << fromRel<< std::endl; + IdList* in_lst; + IdList* out_lst; + Node fst = RelsUtils::nthElementOfTuple(membership[0], 0); + Node snd = RelsUtils::nthElementOfTuple(membership[0], 1); + Node fst_rep = getRepresentative(fst); + Node snd_rep = getRepresentative(snd); + Node mem_rep = RelsUtils::constructPair(tc_ei->d_tc, fst_rep, snd_rep); + int fst_rep_id = getOrMakeElementRepId(tc_ei, fst_rep); + int snd_rep_id = getOrMakeElementRepId(tc_ei, snd_rep); + + IdListMap::iterator tc_in_mem_it = tc_ei->d_id_in.find(snd_rep_id); + + if(tc_in_mem_it == tc_ei->d_id_in.end()) { + in_lst = new(d_sets_theory.getSatContext()->getCMM()) IdList( true, d_sets_theory.getSatContext(), false, context::ContextMemoryAllocator(d_sets_theory.getSatContext()->getCMM()) ); - tc_ei->d_in.insertDataFromContextMemory(snd, in_lst); - Trace("rels-std") << "Create cache for " << snd << std::endl; + tc_ei->d_id_in.insertDataFromContextMemory(snd_rep_id, in_lst); + Trace("rels-std") << "Create in cache for " << snd_rep << std::endl; } else { in_lst = (*tc_in_mem_it).second; } - Trace("rels-std") << "Add in membership arrow for " << snd << " : " << fst << std::endl; - in_lst->push_back( fst ); + // If fst_rep is inserted into in_lst successfully, + // save rep pair's exp and send out TC inference lemmas. + // Otherwise, mem's rep is already in the TC and return. + if(insertIntoIdList(*in_lst, fst_rep_id)) { + Node reason = exp == Node::null() ? explain(membership) : exp; + if(!fromRel && tc_ei->d_tc.get() != membership[1]) { + reason = AND(reason, explain(EQUAL(tc_ei->d_tc.get(), membership[1]))); + } + if(fst != fst_rep) { + reason = AND(reason, explain(EQUAL(fst, fst_rep))); + } + if(snd != snd_rep) { + reason = AND(reason, explain(EQUAL(snd, snd_rep))); + } + tc_ei->d_mem_exp[mem_rep] = reason; + Trace("rels-std") << "Added member " << mem_rep << " for " << tc_ei->d_tc.get()<< " with reason = " << reason << std::endl; + tc_ei->d_mem.insert(mem_rep); + Trace("rels-std") << "Added in membership arrow for " << snd_rep << " from: " << fst_rep << std::endl; + } else { + // Nothing inserted into the eqc + return; + } - NodeListMap::iterator tc_out_mem_it = tc_ei->d_out.find(fst); - if(tc_out_mem_it == tc_ei->d_out.end()) { - out_lst = new(d_sets_theory.getSatContext()->getCMM()) NodeList( true, d_sets_theory.getSatContext(), false, - context::ContextMemoryAllocator(d_sets_theory.getSatContext()->getCMM()) ); - tc_ei->d_out.insertDataFromContextMemory(fst, out_lst); - Trace("rels-std") << "Create cache for " << fst << std::endl; + IdListMap::iterator tc_out_mem_it = tc_ei->d_id_out.find(fst_rep_id); + if(tc_out_mem_it == tc_ei->d_id_out.end()) { + out_lst = new(d_sets_theory.getSatContext()->getCMM()) IdList( true, d_sets_theory.getSatContext(), false, + context::ContextMemoryAllocator(d_sets_theory.getSatContext()->getCMM()) ); + tc_ei->d_id_out.insertDataFromContextMemory(fst_rep_id, out_lst); + Trace("rels-std") << "Create out arrow cache for " << snd_rep << std::endl; } else { out_lst = (*tc_out_mem_it).second; } - Trace("rels-std") << "Add out membership arrow for " << fst << " : " << snd << std::endl; - out_lst->push_back( snd ); + insertIntoIdList(*out_lst, snd_rep_id); + Trace("rels-std") << "Add out membership arrow for " << fst_rep << " to : " << snd_rep << std::endl; + sendTCInference(tc_ei, mem_rep, fst_rep, snd_rep, fst_rep_id, snd_rep_id); + } + + Node TheorySetsRels::explainTCMem(EqcInfo* ei, Node pair, Node fst, Node snd) { + if(ei->d_mem_exp.find(pair) != ei->d_mem_exp.end()) { + return (*ei->d_mem_exp.find(pair)).second; + } + NodeMap::iterator mem_exp_it = ei->d_mem_exp.begin(); + while(mem_exp_it != ei->d_mem_exp.end()) { + Node tuple = (*mem_exp_it).first; + Node fst_e = RelsUtils::nthElementOfTuple(tuple, 0); + Node snd_e = RelsUtils::nthElementOfTuple(tuple, 1); + if(areEqual(fst, fst_e) && areEqual(snd, snd_e)) { + return AND(explain(EQUAL(snd, snd_e)), AND(explain(EQUAL(fst, fst_e)), (*mem_exp_it).second)); + } + ++mem_exp_it; + } + return Node::null(); + } + + void TheorySetsRels::sendTCInference(EqcInfo* tc_ei, Node mem_rep, Node fst_rep, Node snd_rep, int id1, int id2) { + Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() << std::endl; + Node exp = explainTCMem(tc_ei, mem_rep, fst_rep, snd_rep); + Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, exp, MEMBER(mem_rep, tc_ei->d_tc.get())); + d_pending_merge.push_back(tc_lemma); + d_lemma.insert(tc_lemma); + + std::hash_set in_reachable; + std::hash_set out_reachable; + collectInReachableNodes(tc_ei, id1, in_reachable); + collectOutReachableNodes(tc_ei, id2, out_reachable); + Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() << " ***** 2" << std::endl; + std::hash_set::iterator in_reachable_it = in_reachable.begin(); + while(in_reachable_it != in_reachable.end()) { + Node in_node = d_id_node[*in_reachable_it]; + Node in_pair = RelsUtils::constructPair(tc_ei->d_tc.get(), in_node, fst_rep); + Node new_pair = RelsUtils::constructPair(tc_ei->d_tc.get(), in_node, snd_rep); + Node reason = AND(explainTCMem(tc_ei, in_pair, in_node, fst_rep), exp); + Trace("rels-std") << "***$$$$$$$$$$ Adding exp for " << new_pair << " with reason " << reason << std::endl; + tc_ei->d_mem_exp[new_pair] = reason; + tc_ei->d_mem.insert(new_pair); + Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(new_pair, tc_ei->d_tc.get())); + d_pending_merge.push_back(tc_lemma); + d_lemma.insert(tc_lemma); + in_reachable_it++; + } + Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() << " ***** 3" << std::endl; + std::hash_set::iterator out_reachable_it = out_reachable.begin(); + while(out_reachable_it != out_reachable.end()) { + Node out_node = d_id_node[*out_reachable_it]; + Node out_pair = RelsUtils::constructPair(tc_ei->d_tc.get(), snd_rep, out_node); + Node reason = explainTCMem(tc_ei, out_pair, snd_rep, out_node); + Assert(reason != Node::null()); + + std::hash_set::iterator in_reachable_it = in_reachable.begin(); + while(in_reachable_it != in_reachable.end()) { + Node in_node = d_id_node[*in_reachable_it]; + Node in_pair = RelsUtils::constructPair(tc_ei->d_tc.get(), in_node, snd_rep); + Node new_pair = RelsUtils::constructPair(tc_ei->d_tc.get(), in_node, out_node); + Node in_pair_exp = explainTCMem(tc_ei, in_pair, in_node, snd_rep); + + Assert(in_pair_exp != Node::null()); + reason = AND(reason, in_pair_exp); + + Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() << " ***** 3 9" << std::endl; + tc_ei->d_mem_exp[new_pair] = reason; + tc_ei->d_mem.insert(new_pair); + Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(new_pair, tc_ei->d_tc.get())); + d_pending_merge.push_back(tc_lemma); + d_lemma.insert(tc_lemma); + in_reachable_it++; + } + out_reachable_it++; + } + Trace("rels-std") << "Start making TC inference after adding a member " << mem_rep << " to " << tc_ei->d_tc.get() << " ***** 4" << std::endl; + } + + void TheorySetsRels::collectInReachableNodes(EqcInfo* tc_ei, int start_id, std::hash_set& in_reachable, bool firstRound) { + Trace("rels-std") << "Start collecting in-reachable nodes for node with id " << start_id << " ***** 0" << std::endl; + if(in_reachable.find(start_id) != in_reachable.end()) { + return; + } + if(!firstRound) { + in_reachable.insert(start_id); + } + IdListMap::const_iterator id_list_map_it = tc_ei->d_id_in.find(start_id); + + if(id_list_map_it != tc_ei->d_id_in.end()) { + IdList::const_iterator id_list_it = (*id_list_map_it).second->begin(); + while(id_list_it != (*id_list_map_it).second->end()) { + collectInReachableNodes(tc_ei, *id_list_it, in_reachable, false); + id_list_it++; + } + } } + + void TheorySetsRels::collectOutReachableNodes(EqcInfo* tc_ei, int start_id, std::hash_set& out_reachable, bool firstRound) { + Trace("rels-std") << "Start collecting out-reachable nodes for node with id " << start_id << " ***** 0" << std::endl; + if(out_reachable.find(start_id) != out_reachable.end()) { + return; + } + if(!firstRound) { + out_reachable.insert(start_id); + } + IdListMap::const_iterator id_list_map_it = tc_ei->d_id_out.find(start_id); + + if(id_list_map_it != tc_ei->d_id_out.end()) { + IdList::const_iterator id_list_it = (*id_list_map_it).second->begin(); + while(id_list_it != (*id_list_map_it).second->end()) { + collectOutReachableNodes(tc_ei, *id_list_it, out_reachable, false); + id_list_it++; + } + } + } + + // Merge t2 into t1, t1 will be the rep of the new eqc void TheorySetsRels::eqNotifyPostMerge( Node t1, Node t2 ) { Trace("rels-std") << "[sets-rels] eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl; @@ -1305,18 +1490,20 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P t2[0].getType().isTuple()) { Assert(t1 == d_trueNode || t1 == d_falseNode); - bool polarity = t1 == d_trueNode; - Node t2_1rep = getRepresentative(t2[1]); - EqcInfo* ei = getOrMakeEqcInfo( t2_1rep ); + bool polarity = t1 == d_trueNode; + Node t2_1rep = getRepresentative(t2[1]); + EqcInfo* ei = getOrMakeEqcInfo( t2_1rep ); if(ei == NULL) { ei = getOrMakeEqcInfo( t2_1rep, true ); } if(polarity) { ei->d_mem.insert(t2[0]); + ei->d_mem_exp[t2[0]] = explain(t2); } else { ei->d_not_mem.insert(t2[0]); } + // Process a membership constraint that a tuple is a member of transpose of rel if(!ei->d_tp.get().isNull()) { Node exp = polarity ? explain(t2) : explain(t2.negate()); if(ei->d_tp.get() != t2[1]) { @@ -1324,6 +1511,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } sendInferTranspose( polarity, t2[0], ei->d_tp.get(), exp, true ); } + // Process a membership constraint that a tuple is a member of product of rel if(!ei->d_pt.get().isNull()) { Node exp = polarity ? explain(t2) : explain(t2.negate()); if(ei->d_pt.get() != t2[1]) { @@ -1331,21 +1519,19 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } sendInferProduct(polarity, t2[0], ei->d_pt.get(), exp); } + // Process a membership constraint that a tuple is a member of transitive closure of rel if(polarity) { if(!ei->d_tc.get().isNull()) { - addTCMem(ei, t2[0]); - ei->d_tc_mem_exp.insert(t2[0], explain(t2)); - sendInferTC(ei, t2[0], explain(t2)); + addTCMemAndSendInfer(ei, t2, Node::null()); + // when we see (a, b) in R and TC(R) has not been seen, we create a EQC for TC(R) to save (a, b) } else { std::vector tup_types = t2[1].getType().getSetElementType().getTupleTypes(); + if( tup_types.size() == 2 && tup_types[0] == tup_types[1] ) { - Node tc_n = NodeManager::currentNM()->mkNode(kind::TCLOSURE, t2[1]); - EqcInfo* tc_ei = getOrMakeEqcInfo( tc_n ); - if(tc_ei != NULL) { - addTCMem(tc_ei, t2[0]); - Node exp = (tc_n == tc_ei->d_tc.get()) ? explain(t2) : AND(EQUAL(tc_n, tc_ei->d_tc.get()), explain(t2)); - tc_ei->d_tc_mem_exp.insert(t2[0], exp); - sendInferTC(tc_ei, t2[0], exp); + Node tc_n = NodeManager::currentNM()->mkNode(kind::TCLOSURE, t2[1]); + EqcInfo* tc_ei = getOrMakeEqcInfo( tc_n ); + if( tc_ei != NULL ) { + addTCMemAndSendInfer(tc_ei, t2, Node::null(), true); } } } @@ -1363,137 +1549,43 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Trace("rels-std") << "[sets-rels] done with eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl; } - void TheorySetsRels::sendInferTC(EqcInfo* tc_ei, Node mem, Node exp) { - Trace("rels-std") << "[sets-rels] sendInferTC member = " << mem << " with explanation = " << exp << std::endl; - if(!tc_ei->d_mem.contains(mem)) { - Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, exp, MEMBER(mem, tc_ei->d_tc.get())); - d_pending_merge.push_back(tc_lemma); - d_lemma.insert(tc_lemma); - tc_ei->d_mem.insert(mem); - } - std::hash_set seen; - seen.insert(RelsUtils::nthElementOfTuple(mem, 0)); - sendInferInTC(tc_ei, RelsUtils::nthElementOfTuple(mem, 0), RelsUtils::nthElementOfTuple(mem, 1), seen, exp); - sendInferOutTC(tc_ei, RelsUtils::nthElementOfTuple(mem, 0), RelsUtils::nthElementOfTuple(mem, 1), seen, exp); - Trace("rels-std") << "[sets-rels] *** done with sendInferTC member = " << mem << " with explanation = " << exp << std::endl; - } - - void TheorySetsRels::sendInferInTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set seen, Node exp) { - for(NodeListMap::iterator nl_it = tc_ei->d_in.begin(); nl_it != tc_ei->d_in.end(); nl_it++) { - if((*nl_it).first == fst) { - for(NodeList::const_iterator in_itr = (*nl_it).second->begin(); in_itr != (*nl_it).second->end(); in_itr++) { - Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, snd); - if(!tc_ei->d_mem.contains(pair)) { - Node reason = ((*nl_it).first == fst) ? - Rewriter::rewrite(AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, fst)))): - Rewriter::rewrite(AND(EQUAL((*nl_it).first, fst), AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, fst))))); - Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(pair, tc_ei->d_tc.get())); - d_pending_merge.push_back(tc_lemma); - d_lemma.insert(tc_lemma); - tc_ei->d_mem.insert(pair); - tc_ei->d_tc_mem_exp.insert(pair, reason); - } - if(seen.find(*in_itr) == seen.end()) { - seen.insert(*in_itr); - sendInferInTC(tc_ei, *in_itr, snd, seen, AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *in_itr, fst)))); - } - } - } - } - } - - void TheorySetsRels::sendInferOutTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set seen, Node exp) { - for(NodeListMap::iterator nl_it = tc_ei->d_out.begin(); nl_it != tc_ei->d_out.end(); nl_it++) { - if((*nl_it).first == snd) { - for(NodeList::const_iterator itr = (*nl_it).second->begin(); itr != (*nl_it).second->end(); itr++) { - Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), fst, *itr); - if(!tc_ei->d_mem.contains(pair)) { - Node reason = ((*nl_it).first == snd) ? - Rewriter::rewrite(AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), snd, *itr)))) : - Rewriter::rewrite(AND(EQUAL((*nl_it).first, snd), AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), snd, *itr))))); - Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(pair, tc_ei->d_tc.get())); - d_pending_merge.push_back(tc_lemma); - d_lemma.insert(tc_lemma); - tc_ei->d_mem.insert(pair); - tc_ei->d_tc_mem_exp.insert(pair, reason); - } - if(seen.find(*itr) == seen.end()) { - seen.insert(*itr); - sendInferOutTC(tc_ei, snd, *itr, seen, AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), snd, *itr)))); - } - } - } - } - } - - Node TheorySetsRels::findTCMemExp(EqcInfo* tc_ei, Node mem) { - NodeMap::iterator exp_it = tc_ei->d_tc_mem_exp.find(mem); - Assert(exp_it != tc_ei->d_tc_mem_exp.end()); - return (*exp_it).second; - } - - void TheorySetsRels::mergeTCEqcExp(EqcInfo* ei_1, EqcInfo* ei_2) { - for(NodeMap::iterator itr = ei_2->d_tc_mem_exp.begin(); itr != ei_2->d_tc_mem_exp.end(); itr++) { - NodeMap::iterator exp_it = ei_1->d_tc_mem_exp.find((*itr).first); - if(exp_it != ei_1->d_tc_mem_exp.end()) { - ei_1->d_tc_mem_exp.insert((*itr).first, OR((*itr).second, (*exp_it).second)); - } else { - ei_1->d_tc_mem_exp.insert((*itr).first, (*itr).second); - } - } - } - - void TheorySetsRels::buildTCAndExp(Node n, EqcInfo* ei) { - for(NodeSet::key_iterator mem_it = ei->d_mem.key_begin(); mem_it != ei->d_mem.key_end(); mem_it++) { - addTCMem(ei, *mem_it); - Node exp = (!ei->d_tc.get().isNull() && n == ei->d_tc.get()) ? - AND(MEMBER(*mem_it, n), explain(EQUAL(n, ei->d_tc.get()))) : - MEMBER(*mem_it, n); - ei->d_tc_mem_exp.insert(*mem_it, exp); - } - } - void TheorySetsRels::mergeTCEqcs(Node t1, Node t2) { Trace("rels-std") << "[sets-rels] Merge TC eqcs t1 = " << t1 << " and t2 = " << t2 << std::endl; EqcInfo* t1_ei = getOrMakeEqcInfo(t1); EqcInfo* t2_ei = getOrMakeEqcInfo(t2); + if(t1_ei != NULL && t2_ei != NULL) { - // Apply TC rule on members of t2 and t1->tc + NodeSet::const_iterator non_mem_it = t2_ei->d_not_mem.begin(); + while(non_mem_it != t2_ei->d_not_mem.end()) { + t1_ei->d_not_mem.insert(*non_mem_it); + non_mem_it++; + } if(!t1_ei->d_tc.get().isNull()) { - mergeTCEqcExp(t1_ei, t2_ei); - for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) { - sendInferTC(t1_ei, *itr, findTCMemExp(t1_ei, *itr)); - if(!t1_ei->d_mem.contains(*itr)) { - - } + NodeSet::const_iterator mem_it = t2_ei->d_mem.begin(); + while(mem_it != t2_ei->d_mem.end()) { + addTCMemAndSendInfer(t1_ei, MEMBER(*mem_it, t2_ei->d_tc.get()), (*t2_ei->d_mem_exp.find(*mem_it)).second); + mem_it++; } } else if(!t2_ei->d_tc.get().isNull()) { t1_ei->d_tc.set(t2_ei->d_tc); - buildTCAndExp(t1, t1_ei); - mergeTCEqcExp(t1_ei, t2_ei); - for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) { - sendInferTC(t1_ei, *itr, findTCMemExp(t1_ei, *itr)); - if(!t1_ei->d_mem.contains(*itr) && !t2_ei->d_mem.contains(*itr)) { - - } + NodeSet::const_iterator t1_mem_it = t1_ei->d_mem.begin(); + while(t1_mem_it != t1_ei->d_mem.end()) { + addTCMemAndSendInfer(t1_ei, MEMBER(*t1_mem_it, t1_ei->d_tc.get()), (*t1_ei->d_mem_exp.find(*t1_mem_it)).second); + t1_mem_it++; + } + NodeSet::const_iterator t2_mem_it = t2_ei->d_mem.begin(); + while(t2_mem_it != t2_ei->d_mem.end()) { + addTCMemAndSendInfer(t1_ei, MEMBER(*t2_mem_it, t2_ei->d_tc.get()), (*t2_ei->d_mem_exp.find(*t2_mem_it)).second); + t2_mem_it++; } - } - // t1 was created already and t2 was not - } else if(t1_ei != NULL) { - if(t1_ei->d_tc.get().isNull() && t2.getKind() == kind::TCLOSURE) { - t1_ei->d_tc.set( t2 ); - } - } else if(t2_ei != NULL){ - t1_ei = getOrMakeEqcInfo(t1, true); - for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) { - t1_ei->d_mem.insert(*itr); - } - if(t1_ei->d_tc.get().isNull() && !t2_ei->d_tc.get().isNull()) { - t1_ei->d_tc.set(t2_ei->d_tc); } } + Trace("rels-std") << "[sets-rels] Done with merging TC eqcs t1 = " << t1 << " and t2 = " << t2 << std::endl; } + + + void TheorySetsRels::mergeProductEqcs(Node t1, Node t2) { Trace("rels-std") << "[sets-rels] Merge PRODUCT eqcs t1 = " << t1 << " and t2 = " << t2 << std::endl; EqcInfo* t1_ei = getOrMakeEqcInfo(t1); @@ -1503,7 +1595,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(!t1_ei->d_pt.get().isNull() && !t2_ei->d_pt.get().isNull()) { sendInferProduct( true, t1_ei->d_pt.get(), t2_ei->d_pt.get(), explain(EQUAL(t1, t2)) ); } - // Apply Product rule on (non)members of t2 and t1->tp + // Apply Product rule on (non)members of t2 and t1->pt if(!t1_ei->d_pt.get().isNull()) { for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) { if(!t1_ei->d_mem.contains(*itr)) { @@ -1551,6 +1643,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Trace("rels-std") << "[sets-rels] Merge TRANSPOSE eqcs t1 = " << t1 << " and t2 = " << t2 << std::endl; EqcInfo* t1_ei = getOrMakeEqcInfo(t1); EqcInfo* t2_ei = getOrMakeEqcInfo(t2); + if(t1_ei != NULL && t2_ei != NULL) { // TP(t1) = TP(t2) -> t1 = t2; if(!t1_ei->d_tp.get().isNull() && !t2_ei->d_tp.get().isNull()) { @@ -1662,16 +1755,15 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P return; } - std::vector r1_element; - std::vector r2_element; - Node r1 = t2[0]; - Node r2 = t2[1]; - - NodeManager *nm = NodeManager::currentNM(); - Datatype dt = r1.getType().getSetElementType().getDatatype(); - unsigned int i = 0; - unsigned int s1_len = r1.getType().getSetElementType().getTupleLength(); - unsigned int tup_len = t2.getType().getSetElementType().getTupleLength(); + std::vector r1_element; + std::vector r2_element; + Node r1 = t2[0]; + Node r2 = t2[1]; + NodeManager *nm = NodeManager::currentNM(); + Datatype dt = r1.getType().getSetElementType().getDatatype(); + unsigned int i = 0; + unsigned int s1_len = r1.getType().getSetElementType().getTupleLength(); + unsigned int tup_len = t2.getType().getSetElementType().getTupleLength(); r1_element.push_back(Node::fromExpr(dt[0].getConstructor())); for(; i < s1_len; ++i) { @@ -1683,10 +1775,12 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P for(; i < tup_len; ++i) { r2_element.push_back(RelsUtils::nthElementOfTuple(t1, i)); } - Node tuple_1 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element)); - Node tuple_2 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element)); + Node n1; Node n2; + Node tuple_1 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element)); + Node tuple_2 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element)); + if(polarity) { n1 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(tuple_1, r1) ); n2 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(tuple_2, r2) ); @@ -1752,9 +1846,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P all.insert(t); } } - Assert(all.size() > 0); - if (all.size() == 1) { // All the same, or just one return conjunctions[0]; diff --git a/src/theory/sets/theory_sets_rels.h b/src/theory/sets/theory_sets_rels.h index faee651b7..c38e027c7 100644 --- a/src/theory/sets/theory_sets_rels.h +++ b/src/theory/sets/theory_sets_rels.h @@ -35,8 +35,8 @@ public: /** the data */ std::map< Node, TupleTrie > d_data; public: - Node existsTerm( std::vector< Node >& reps, int argIndex = 0 ); std::vector findTerms( std::vector< Node >& reps, int argIndex = 0 ); + Node existsTerm( std::vector< Node >& reps, int argIndex = 0 ); bool addTerm( Node n, std::vector< Node >& reps, int argIndex = 0 ); void debugPrint( const char * c, Node n, unsigned depth = 0 ); void clear() { d_data.clear(); } @@ -44,14 +44,14 @@ public: class TheorySetsRels { - typedef context::CDChunkList NodeList; - typedef context::CDChunkList IdList; - typedef context::CDHashSet NodeSet; - typedef context::CDHashMap NodeBoolMap; + typedef context::CDChunkList NodeList; + typedef context::CDChunkList IdList; + typedef context::CDHashMap IdListMap; + typedef context::CDHashSet NodeSet; + typedef context::CDHashMap NodeBoolMap; typedef context::CDHashMap NodeListMap; - typedef context::CDHashMap IdListMap; - typedef context::CDHashMap NodeSetMap; - typedef context::CDHashMap NodeMap; + typedef context::CDHashMap NodeSetMap; + typedef context::CDHashMap NodeMap; public: TheorySetsRels(context::Context* c, @@ -78,64 +78,59 @@ private: public: EqcInfo( context::Context* c ); ~EqcInfo(){} - int counter; - NodeSet d_mem; - NodeSet d_not_mem; - NodeListMap d_in; - NodeListMap d_out; - NodeMap d_tc_mem_exp; - context::CDO< Node > d_tp; - context::CDO< Node > d_pt; - context::CDO< Node > d_join; - context::CDO< Node > d_tc; - IdListMap d_id_in; - IdListMap d_id_out; - std::hash_map d_id_node; - std::hash_map d_node_id; + static int counter; + NodeSet d_mem; + NodeSet d_not_mem; + NodeListMap d_in; + NodeListMap d_out; + NodeMap d_mem_exp; + IdListMap d_id_in; // mapping from a element rep id to a list of rep ids that pointed by + IdListMap d_id_out; // mapping from a element rep id to a list of rep ids that point to + context::CDO< Node > d_tp; + context::CDO< Node > d_pt; + context::CDO< Node > d_join; + context::CDO< Node > d_tc; }; +private: + std::map d_id_node; // mapping between integer id and tuple element rep + std::map d_node_id; // mapping between tuple element rep and integer id + /** has eqc info */ bool hasEqcInfo( TNode n ) { return d_eqc_info.find( n )!=d_eqc_info.end(); } private: - - TheorySets& d_sets_theory; - + eq::EqualityEngine *d_eqEngine; + context::CDO *d_conflict; + TheorySets& d_sets_theory; /** True and false constant nodes */ - Node d_trueNode; - Node d_falseNode; - + Node d_trueNode; + Node d_falseNode; // Facts and lemmas to be sent to EE - std::map< Node, Node > d_pending_facts; - std::map< Node, Node > d_pending_split_facts; - std::vector< Node > d_lemma_cache; - - NodeList d_pending_merge; - + std::map< Node, Node > d_pending_facts; + std::map< Node, Node > d_pending_split_facts; + std::vector< Node > d_lemma_cache; + NodeList d_pending_merge; /** inferences: maintained to ensure ref count for internally introduced nodes */ - NodeList d_infer; - NodeList d_infer_exp; - NodeSet d_lemma; - NodeSet d_shared_terms; - + NodeList d_infer; + NodeList d_infer_exp; + NodeSet d_lemma; + NodeSet d_shared_terms; // tc terms that have been decomposed - NodeSet d_tc_saver; - - std::hash_set< Node, NodeHashFunction > d_rel_nodes; - std::map< Node, std::vector > d_tuple_reps; - std::map< Node, TupleTrie > d_membership_trie; - std::hash_set< Node, NodeHashFunction > d_symbolic_tuples; - std::map< Node, std::vector > d_membership_constraints_cache; - std::map< Node, std::vector > d_membership_exp_cache; - std::map< Node, std::map > > d_terms_cache; - std::map< Node, std::vector > d_membership_db; - std::map< Node, std::vector > d_membership_exp_db; - std::map< Node, std::map< Node, std::hash_set > > d_membership_tc_cache; - std::map< Node, Node > d_membership_tc_exp_cache; - - eq::EqualityEngine *d_eqEngine; - context::CDO *d_conflict; + NodeSet d_tc_saver; + + std::hash_set< Node, NodeHashFunction > d_rel_nodes; + std::map< Node, std::vector > d_tuple_reps; + std::map< Node, TupleTrie > d_membership_trie; + std::hash_set< Node, NodeHashFunction > d_symbolic_tuples; + std::map< Node, std::vector > d_membership_constraints_cache; + std::map< Node, std::vector > d_membership_exp_cache; + std::map< Node, std::vector > d_membership_db; + std::map< Node, std::vector > d_membership_exp_db; + std::map< Node, Node > d_membership_tc_exp_cache; + std::map< Node, std::map > > d_terms_cache; + std::map< Node, std::map< Node, std::hash_set > > d_membership_tc_cache; /** information necessary for equivalence classes */ public: @@ -152,13 +147,15 @@ private: void mergeTCEqcs(Node t1, Node t2); void sendInferTranspose(bool, Node, Node, Node, bool reverseOnly = false); void sendInferProduct(bool, Node, Node, Node); - void sendInferTC(EqcInfo* tc_ei, Node mem, Node exp); - void sendInferInTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set seen, Node exp); - void sendInferOutTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set seen, Node exp); - void addTCMem(EqcInfo* tc_ei, Node mem); + void sendTCInference(EqcInfo* tc_ei, Node mem_rep, Node fst_rep, Node snd_rep, int id1, int id2); + void addTCMemAndSendInfer(EqcInfo* tc_ei, Node mem, Node exp, bool fromRel = false); Node findTCMemExp(EqcInfo*, Node); void mergeTCEqcExp(EqcInfo*, EqcInfo*); void buildTCAndExp(Node, EqcInfo*); + int getOrMakeElementRepId(EqcInfo*, Node); + void collectInReachableNodes(EqcInfo* tc_ei, int start_id, std::hash_set& in_reachable, bool firstRound = true); + void collectOutReachableNodes(EqcInfo* tc_ei, int start_id, std::hash_set& out_reachable, bool firstRound = true); + Node explainTCMem(EqcInfo*, Node, Node, Node); void check(); @@ -189,6 +186,7 @@ private: bool checkCycles( Node ); // Helper functions + bool insertIntoIdList(IdList&, int); inline Node getReason(Node tc_rep, Node tc_term, Node tc_r_rep, Node tc_r); inline Node constructPair(Node tc_rep, Node a, Node b); Node findMemExp(Node r, Node pair); @@ -206,6 +204,13 @@ private: inline void addToMembershipDB( Node, Node, Node ); bool isRel( Node n ) {return n.getType().isSet() && n.getType().getSetElementType().isTuple();} Node mkAnd( std::vector< TNode >& assumptions ); + void printNodeMap(char* fst, char* snd, NodeMap map) { + NodeMap::iterator map_it = map.begin(); + while(map_it != map.end()) { + Trace("rels-debug") << fst << " "<< (*map_it).first << " " << snd << " " << (*map_it).second<< std::endl; + map_it++; + } + } }; -- 2.30.2