From 47af9913d01b735155cced5c8186b1a0cc60c56b Mon Sep 17 00:00:00 2001 From: PaulMeng Date: Wed, 4 May 2016 10:21:18 -0500 Subject: [PATCH] implemented TC for standard effort --- src/theory/sets/theory_sets_private.cpp | 5 + src/theory/sets/theory_sets_rels.cpp | 351 +++++++++++++++++++----- src/theory/sets/theory_sets_rels.h | 30 +- 3 files changed, 319 insertions(+), 67 deletions(-) diff --git a/src/theory/sets/theory_sets_private.cpp b/src/theory/sets/theory_sets_private.cpp index bc9227e54..aec2c119c 100644 --- a/src/theory/sets/theory_sets_private.cpp +++ b/src/theory/sets/theory_sets_private.cpp @@ -696,6 +696,11 @@ const TheorySetsPrivate::Elements& TheorySetsPrivate::getElements std::inserter(cur, cur.begin()) ); break; } + case kind::JOIN: + case kind::TCLOSURE: + case kind::TRANSPOSE: + case kind::PRODUCT: + break; default: Assert(theory::kindToTheoryId(k) != theory::THEORY_SETS, (std::string("Kind belonging to set theory not explicitly handled: ") + kindToString(k)).c_str()); diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index 75e3d4831..428027acc 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -230,6 +230,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P << tc_term << " with explanation " << exp << std::endl; bool polarity = exp.getKind() != kind::NOT; Node atom = polarity ? exp : exp[0]; + Node tup_rep = getRepresentative(atom[0]); Node tc_rep = getRepresentative(tc_term); Node tc_r_rep = getRepresentative(tc_term[0]); @@ -239,17 +240,17 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P buildTCGraph(tc_r_rep, tc_rep, tc_term); d_rel_nodes.insert(tc_rep); } - // insert atom[0] in the tc_graph if it is not in the graph already + // insert tup_rep in the tc_graph if it is not in the graph already TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep); if(polarity) { if(tc_graph_it != d_membership_tc_cache.end()) { - TC_PAIR_IT pair_set_it = tc_graph_it->second.find(RelsUtils::nthElementOfTuple(atom[0], 0)); + TC_PAIR_IT pair_set_it = tc_graph_it->second.find(RelsUtils::nthElementOfTuple(tup_rep, 0)); if(pair_set_it != tc_graph_it->second.end()) { - pair_set_it->second.insert(RelsUtils::nthElementOfTuple(atom[0], 1)); + pair_set_it->second.insert(RelsUtils::nthElementOfTuple(tup_rep, 1)); } else { std::hash_set< Node, NodeHashFunction > pair_set; - pair_set.insert(RelsUtils::nthElementOfTuple(atom[0], 1)); - tc_graph_it->second[RelsUtils::nthElementOfTuple(atom[0], 0)] = pair_set; + pair_set.insert(RelsUtils::nthElementOfTuple(tup_rep, 1)); + tc_graph_it->second[RelsUtils::nthElementOfTuple(tup_rep, 0)] = pair_set; } Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); std::map< Node, Node >::iterator exp_it = d_membership_tc_exp_cache.find(tc_rep); @@ -259,19 +260,49 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } else { std::map< Node, std::hash_set< Node, NodeHashFunction > > pair_set; std::hash_set< Node, NodeHashFunction > set; - set.insert(RelsUtils::nthElementOfTuple(atom[0], 1)); - pair_set[RelsUtils::nthElementOfTuple(atom[0], 0)] = set; + set.insert(RelsUtils::nthElementOfTuple(tup_rep, 1)); + pair_set[RelsUtils::nthElementOfTuple(tup_rep, 0)] = set; d_membership_tc_cache[tc_rep] = pair_set; Node reason = getReason(tc_rep, tc_term, tc_r_rep, tc_term[0]); if(!reason.isNull()) { d_membership_tc_exp_cache[tc_rep] = reason; } } - // check if atom[0] already exists in TC graph for conflict + // if(!d_tc_saver.contains(exp) && + // atom[0][0].getKind() != kind::SKOLEM && + // atom[0][1].getKind() != kind::SKOLEM) { + + // TypeNode k_type = tup_rep.getType().getTupleTypes()[1]; + // Node k_0 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type); + // Node k_1 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type); + // Node k_2 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type); + // Node k_3 = NodeManager::currentNM()->mkSkolem("tc_sde_", k_type); + // Node fact = NodeManager::currentNM()->mkNode( kind::AND, MEMBER(RelsUtils::constructPair(tc_rep, tup_rep[0], k_0), tc_r_rep), + // MEMBER(RelsUtils::constructPair(tc_rep, k_0, k_1), tc_r_rep), + // MEMBER(RelsUtils::constructPair(tc_rep, k_1, k_2), tc_r_rep), + // MEMBER(RelsUtils::constructPair(tc_rep, k_2, k_3), tc_r_rep), + // MEMBER(RelsUtils::constructPair(tc_rep, k_3, tup_rep[1]), tc_r_rep) ); + // Node reason = exp; + // if(tc_rep != tc_term) { + // reason = AND(reason, explain(EQUAL(tc_rep, tc_term))); + // } + // if(tc_r_rep != tc_term[0]) { + // reason = AND(reason, explain(EQUAL(tc_r_rep, tc_term[0]))); + // } + + // makeSharedTerm(k_0); + // makeSharedTerm(k_1); + // makeSharedTerm(k_2); + // makeSharedTerm(k_3); + // sendLemma( fact, reason, "tc-decompose" ); + // d_tc_saver.insert(exp); + // } + // check if tup_rep already exists in TC graph for conflict } else { if(tc_graph_it != d_membership_tc_cache.end()) { - checkTCGraphForConflict(atom, tc_rep, d_trueNode, RelsUtils::nthElementOfTuple(atom[0], 0), - RelsUtils::nthElementOfTuple(atom[0], 1), tc_graph_it->second); + Trace("rels-debug") << "********** tc reach here 0" << std::endl; + checkTCGraphForConflict(atom, tc_rep, d_trueNode, RelsUtils::nthElementOfTuple(tup_rep, 0), + RelsUtils::nthElementOfTuple(tup_rep, 1), tc_graph_it->second); } } } @@ -279,8 +310,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::checkTCGraphForConflict (Node atom, Node tc_rep, Node exp, Node a, Node b, std::map< Node, std::hash_set< Node, NodeHashFunction > >& pair_set) { TC_PAIR_IT pair_set_it = pair_set.find(a); + Trace("rels-debug") << "********** tc reach here 1" << " a = " << a << " b = " << b << std::endl; if(pair_set_it != pair_set.end()) { + Trace("rels-debug") << "********** tc reach here 2" << std::endl; if(pair_set_it->second.find(b) != pair_set_it->second.end()) { + Trace("rels-debug") << "********** tc reach here 3" << std::endl; Node reason = AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, b))); if(atom[1] != tc_rep) { reason = AND(exp, explain(EQUAL(atom[1], tc_rep))); @@ -292,14 +326,20 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P // << AND(reason.negate(), atom) << std::endl; // d_sets_theory.d_out->conflict(AND(reason.negate(), atom)); } else { + Trace("rels-debug") << "********** tc reach here 4" << std::endl; std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin(); while(set_it != pair_set_it->second.end()) { - checkTCGraphForConflict(atom, tc_rep, AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, *set_it))), - *set_it, b, pair_set); + // need to check if *set_it has been looked already + if(!areEqual(*set_it, a)) { + checkTCGraphForConflict(atom, tc_rep, AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, *set_it))), + *set_it, b, pair_set); + } + Trace("rels-debug") << "********** looping here 6 *set_it = " << *set_it << std::endl; set_it++; } } } + Trace("rels-debug") << "********** tc reach here 5" << std::endl; } @@ -368,14 +408,15 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P reason_2 = Rewriter::rewrite(AND(reason_2, explain(EQUAL(t2, t2_rep)))); } if(polarity) { - if(safeAddToMap(d_membership_db, r1_rep, t1_rep)) { - addToMap(d_membership_exp_db, r1_rep, reason_1); - sendInfer(fact_1, reason_1, "product-split"); - } - if(safeAddToMap(d_membership_db, r2_rep, t2_rep)) { - addToMap(d_membership_exp_db, r2_rep, reason_2); - sendInfer(fact_2, reason_2, "product-split"); - } + sendInfer(fact_1, reason_1, "product-split"); + sendInfer(fact_2, reason_2, "product-split"); + // if(safeAddToMap(d_membership_db, r1_rep, t1_rep)) { + // addToMap(d_membership_exp_db, r1_rep, reason_1); + // } + // if(safeAddToMap(d_membership_db, r2_rep, t2_rep)) { + // addToMap(d_membership_exp_db, r2_rep, reason_2); + + // } } else { sendInfer(fact_1.negate(), reason_1, "product-split"); @@ -457,8 +498,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P fact = MEMBER(t1, r1_rep); if(r1_rep != join_term[0]) { reasons = Rewriter::rewrite(AND(reason, explain(EQUAL(r1_rep, join_term[0])))); - } - addToMembershipDB(r1_rep, t1, reasons); + } sendInfer(fact, reasons, "join-split"); reasons = reason; @@ -466,7 +506,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(r2_rep != join_term[1]) { reasons = Rewriter::rewrite(AND(reason, explain(EQUAL(r2_rep, join_term[1])))); } - addToMembershipDB(r2_rep, t2, reasons); sendInfer(fact, reasons, "join-split"); // Need to make the skolem "shared_x" as shared term @@ -502,14 +541,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Trace("rels-debug") << "\n[sets-rels] Apply TRANSPOSE-OCCUR rule on term: " << tp_term << " with explanation: " << exp << std::endl; Node fact = polarity ? MEMBER(reversedTuple, tp_term) : MEMBER(reversedTuple, tp_term).negate(); - if(holds(fact)) { - Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds. Skip...." << std::endl; - } else { - sendInfer(fact, exp, "transpose-occur"); - if(polarity) { - addToMembershipDB(tp_term, reversedTuple, exp); - } - } + sendInfer(fact, exp, "transpose-occur"); return; } @@ -539,14 +571,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } fact = fact.negate(); } - if(holds(fact)) { - Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds. Skip...." << std::endl; - } else { - sendInfer(fact, reason, "transpose-rule"); - if(polarity) { - addToMembershipDB(tp_t0_rep, reversedTuple, reason); - } - } + sendInfer(fact, reason, "transpose-rule"); } @@ -676,7 +701,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(holds(fact)) { Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds! Skip..." << std::endl; } else { - addToMembershipDB(n_rep, rev_tup, Rewriter::rewrite(reason)); sendInfer(fact, Rewriter::rewrite(reason), "transpose-rule"); } } @@ -717,8 +741,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P for(unsigned int j = 0; j < r2_elements.size(); j++) { std::vector composed_tuple; TypeNode tn = n.getType().getSetElementType(); - Node r2_lmost = RelsUtils::nthElementOfTuple(r2_elements[j], 0); Node r1_rmost = RelsUtils::nthElementOfTuple(r1_elements[i], t1_len-1); + Node r2_lmost = RelsUtils::nthElementOfTuple(r2_elements[j], 0); composed_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor())); if((areEqual(r1_rmost, r2_lmost) && n.getKind() == kind::JOIN) || @@ -742,24 +766,35 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Trace("rels-debug") << "[sets-rels] New fact: " << fact << " already holds! Skip..." << std::endl; } else { std::vector reasons; - reasons.push_back(r1_exps[i]); - reasons.push_back(r2_exps[j]); - if(!isProduct) - reasons.push_back(EQUAL(r1_rmost, r2_lmost)); + //Todo: need more explanation + reasons.push_back(explain(r1_exps[i])); + reasons.push_back(explain(r2_exps[j])); + if(r1_exps[i].getKind() == kind::MEMBER && r1_exps[i][0] != r1_elements[i]) { + Trace("rels-debug") << "************* $ r1 ele = " << r1_elements[i] << " r1 exp ele = " << r1_exps[i][0] << std::endl; + reasons.push_back(explain(EQUAL(r1_elements[i], r1_exps[i][0]))); + } + if(r2_exps[j].getKind() == kind::MEMBER && r2_exps[j][0] != r2_elements[j]) { + Trace("rels-debug") << "************* $ r2 ele = " << r2_elements[j] << " r2 exp ele = " << r2_exps[j][0] << std::endl; + reasons.push_back(explain(EQUAL(r2_elements[j], r2_exps[j][0]))); + } + if(!isProduct) { + if(r1_rmost != r2_lmost) { + reasons.push_back(explain(EQUAL(r1_rmost, r2_lmost))); + } + } if(r1 != r1_rep) { - reasons.push_back(EQUAL(r1, r1_rep)); + reasons.push_back(explain(EQUAL(r1, r1_rep))); } if(r2 != r2_rep) { - reasons.push_back(EQUAL(r2, r2_rep)); + reasons.push_back(explain(EQUAL(r2, r2_rep))); } Node reason = Rewriter::rewrite(nm->mkNode(kind::AND, reasons)); - addToMembershipDB(new_rel_rep, composed_tuple_rep, reason); - - if(isProduct) + if(isProduct) { sendInfer( fact, reason, "product-compose" ); - else + } else { sendInfer( fact, reason, "join-compose" ); + } Trace("rels-debug") << "[sets-rels] Compose tuples: " << r1_elements[i] << " and " << r2_elements[j] @@ -1120,8 +1155,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P context::UserContext* u, eq::EqualityEngine* eq, context::CDO* conflict, - TheorySets& d_set): - d_c(c), + TheorySets& d_set): d_sets_theory(d_set), d_trueNode(NodeManager::currentNM()->mkConst(true)), d_falseNode(NodeManager::currentNM()->mkConst(false)), @@ -1130,6 +1164,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P d_infer_exp(c), d_lemma(u), d_shared_terms(u), + d_tc_saver(u), d_eqEngine(eq), d_conflict(conflict) { @@ -1227,15 +1262,48 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } TheorySetsRels::EqcInfo::EqcInfo( context::Context* c ) : - d_mem(c), d_not_mem(c), d_tp(c), d_pt(c) {} + d_mem(c), d_not_mem(c), d_in(c), d_out(c), d_tc_mem_exp(c), d_tp(c), d_pt(c), d_join(c), d_tc(c) {} void TheorySetsRels::eqNotifyNewClass( Node n ) { Trace("rels-std") << "[sets-rels] eqNotifyNewClass:" << " t = " << n << std::endl; - if(isRel(n) && (n.getKind() == kind::TRANSPOSE || n.getKind() == kind::PRODUCT)) { + if(isRel(n) && (n.getKind() == kind::TRANSPOSE || + n.getKind() == kind::PRODUCT || + n.getKind() == kind::JOIN || + n.getKind() == kind::TCLOSURE)) { getOrMakeEqcInfo( n, true ); } + Trace("rels-std") << "[sets-rels] eqNotifyNewClass*****:" << " t = " << n << std::endl; } + void TheorySetsRels::addTCMem(EqcInfo* tc_ei, Node mem) { + Node fst = RelsUtils::nthElementOfTuple(mem, 0); + Node snd = RelsUtils::nthElementOfTuple(mem, 1); + + NodeList* in_lst; + NodeList* out_lst; + NodeListMap::iterator tc_in_mem_it = tc_ei->d_in.find(snd); + if(tc_in_mem_it == tc_ei->d_in.end()) { + in_lst = new(d_sets_theory.getSatContext()->getCMM()) NodeList( true, d_sets_theory.getSatContext(), false, + context::ContextMemoryAllocator(d_sets_theory.getSatContext()->getCMM()) ); + tc_ei->d_in.insertDataFromContextMemory(snd, in_lst); + Trace("rels-std") << "Create cache for " << snd << std::endl; + } else { + in_lst = (*tc_in_mem_it).second; + } + Trace("rels-std") << "Add in membership arrow for " << snd << " : " << fst << std::endl; + in_lst->push_back( fst ); + NodeListMap::iterator tc_out_mem_it = tc_ei->d_out.find(fst); + if(tc_out_mem_it == tc_ei->d_out.end()) { + out_lst = new(d_sets_theory.getSatContext()->getCMM()) NodeList( true, d_sets_theory.getSatContext(), false, + context::ContextMemoryAllocator(d_sets_theory.getSatContext()->getCMM()) ); + tc_ei->d_out.insertDataFromContextMemory(fst, out_lst); + Trace("rels-std") << "Create cache for " << fst << std::endl; + } else { + out_lst = (*tc_out_mem_it).second; + } + Trace("rels-std") << "Add out membership arrow for " << fst << " : " << snd << std::endl; + out_lst->push_back( snd ); + } void TheorySetsRels::eqNotifyPostMerge( Node t1, Node t2 ) { Trace("rels-std") << "[sets-rels] eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl; @@ -1274,18 +1342,170 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } sendInferProduct(polarity, t2[0], ei->d_pt.get(), exp); } + if(polarity) { + if(!ei->d_tc.get().isNull()) { + addTCMem(ei, t2[0]); + ei->d_tc_mem_exp.insert(t2[0], t2); + sendInferTC(ei, t2[0], t2); + } else { + std::vector tup_types = t2[1].getType().getSetElementType().getTupleTypes(); + if( tup_types.size() == 2 && tup_types[0] == tup_types[1] ) { + Node tc_n = NodeManager::currentNM()->mkNode(kind::TCLOSURE, t2[1]); + EqcInfo* tc_ei = getOrMakeEqcInfo( tc_n ); + if(tc_ei != NULL) { + addTCMem(tc_ei, t2[0]); + Node exp = (tc_n == tc_ei->d_tc.get()) ? t2 : AND(EQUAL(tc_n, tc_ei->d_tc.get()), t2); + tc_ei->d_tc_mem_exp.insert(t2[0], exp); + sendInferTC(tc_ei, t2[0], exp); + } + } + } + } + // Merge two relation eqcs } else if(t1.getType().isSet() && t2.getType().isSet() && t1.getType().getSetElementType().isTuple()) { mergeTransposeEqcs(t1, t2); mergeProductEqcs(t1, t2); + mergeTCEqcs(t1, t2); } Trace("rels-std") << "[sets-rels] done with eqNotifyPostMerge:" << " t1 = " << t1 << " t2 = " << t2 << std::endl; } + void TheorySetsRels::sendInferTC(EqcInfo* tc_ei, Node mem, Node exp) { + Trace("rels-std") << "[sets-rels] sendInferTC member = " << mem << " with explanation = " << exp << std::endl; + if(!tc_ei->d_mem.contains(mem)) { + Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, exp, MEMBER(mem, tc_ei->d_tc.get())); + d_pending_merge.push_back(tc_lemma); + d_lemma.insert(tc_lemma); + tc_ei->d_mem.insert(mem); + } + std::hash_set seen; + seen.insert(RelsUtils::nthElementOfTuple(mem, 0)); + sendInferInTC(tc_ei, RelsUtils::nthElementOfTuple(mem, 0), RelsUtils::nthElementOfTuple(mem, 1), seen, exp); + sendInferOutTC(tc_ei, RelsUtils::nthElementOfTuple(mem, 0), RelsUtils::nthElementOfTuple(mem, 1), seen, exp); + } + + void TheorySetsRels::sendInferInTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set seen, Node exp) { + for(NodeListMap::iterator nl_it = tc_ei->d_in.begin(); nl_it != tc_ei->d_in.end(); nl_it++) { + if(areEqual((*nl_it).first, fst)) { + for(NodeList::const_iterator itr = (*nl_it).second->begin(); itr != (*nl_it).second->end(); itr++) { + Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, snd); + if(!tc_ei->d_mem.contains(pair)) { + Node reason = ((*nl_it).first == fst) ? + Rewriter::rewrite(AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst)))): + Rewriter::rewrite(AND(EQUAL((*nl_it).first, fst), AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst))))); + Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(pair, tc_ei->d_tc.get())); + d_pending_merge.push_back(tc_lemma); + d_lemma.insert(tc_lemma); + tc_ei->d_mem.insert(pair); + tc_ei->d_tc_mem_exp.insert(pair, reason); + } + if(seen.find(*itr) == seen.end()) { + seen.insert(*itr); + sendInferInTC(tc_ei, *itr, snd, seen, AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), *itr, fst)))); + } + } + } + } + } + + void TheorySetsRels::sendInferOutTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set seen, Node exp) { + for(NodeListMap::iterator nl_it = tc_ei->d_out.begin(); nl_it != tc_ei->d_out.end(); nl_it++) { + if(areEqual((*nl_it).first, snd)) { + for(NodeList::const_iterator itr = (*nl_it).second->begin(); itr != (*nl_it).second->end(); itr++) { + Node pair = RelsUtils::constructPair(tc_ei->d_tc.get(), fst, *itr); + if(!tc_ei->d_mem.contains(pair)) { + Node reason = ((*nl_it).first == snd) ? + Rewriter::rewrite(AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), snd, *itr)))) : + Rewriter::rewrite(AND(EQUAL((*nl_it).first, snd), AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), snd, *itr))))); + Node tc_lemma = NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, MEMBER(pair, tc_ei->d_tc.get())); + d_pending_merge.push_back(tc_lemma); + d_lemma.insert(tc_lemma); + tc_ei->d_mem.insert(pair); + tc_ei->d_tc_mem_exp.insert(pair, reason); + } + if(seen.find(*itr) == seen.end()) { + seen.insert(*itr); + sendInferOutTC(tc_ei, snd, *itr, seen, AND(exp, findTCMemExp(tc_ei, RelsUtils::constructPair(tc_ei->d_tc.get(), snd, *itr)))); + } + } + } + } + } + + Node TheorySetsRels::findTCMemExp(EqcInfo* tc_ei, Node mem) { + NodeMap::iterator exp_it = tc_ei->d_tc_mem_exp.find(mem); + Assert(exp_it != tc_ei->d_tc_mem_exp.end()); + return (*exp_it).second; + } + + void TheorySetsRels::mergeTCEqcExp(EqcInfo* ei_1, EqcInfo* ei_2) { + for(NodeMap::iterator itr = ei_2->d_tc_mem_exp.begin(); itr != ei_2->d_tc_mem_exp.end(); itr++) { + NodeMap::iterator exp_it = ei_1->d_tc_mem_exp.find((*itr).first); + if(exp_it != ei_1->d_tc_mem_exp.end()) { + ei_1->d_tc_mem_exp.insert((*itr).first, OR((*itr).second, (*exp_it).second)); + } else { + ei_1->d_tc_mem_exp.insert((*itr).first, (*itr).second); + } + } + } + + void TheorySetsRels::buildTCAndExp(Node n, EqcInfo* ei) { + for(NodeSet::key_iterator mem_it = ei->d_mem.key_begin(); mem_it != ei->d_mem.key_end(); mem_it++) { + addTCMem(ei, *mem_it); + Node exp = (!ei->d_tc.get().isNull() && n == ei->d_tc.get()) ? + AND(MEMBER(*mem_it, n), explain(EQUAL(n, ei->d_tc.get()))) : + MEMBER(*mem_it, n); + ei->d_tc_mem_exp.insert(*mem_it, exp); + } + } + + void TheorySetsRels::mergeTCEqcs(Node t1, Node t2) { + Trace("rels-std") << "[sets-rels] Merge TC eqcs t1 = " << t1 << " and t2 = " << t2 << std::endl; + EqcInfo* t1_ei = getOrMakeEqcInfo(t1); + EqcInfo* t2_ei = getOrMakeEqcInfo(t2); + if(t1_ei != NULL && t2_ei != NULL) { + // Apply TC rule on members of t2 and t1->tc + if(!t1_ei->d_tc.get().isNull()) { + mergeTCEqcExp(t1_ei, t2_ei); + for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) { + sendInferTC(t1_ei, *itr, findTCMemExp(t1_ei, *itr)); + if(!t1_ei->d_mem.contains(*itr)) { + + } + } + } else if(!t2_ei->d_tc.get().isNull()) { + t1_ei->d_tc.set(t2_ei->d_tc); + buildTCAndExp(t1, t1_ei); + mergeTCEqcExp(t1_ei, t2_ei); + for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) { + sendInferTC(t1_ei, *itr, findTCMemExp(t1_ei, *itr)); + if(!t1_ei->d_mem.contains(*itr) && !t2_ei->d_mem.contains(*itr)) { + + } + } + } + // t1 was created already and t2 was not + } else if(t1_ei != NULL) { + if(t1_ei->d_tc.get().isNull() && t2.getKind() == kind::TCLOSURE) { + t1_ei->d_tc.set( t2 ); + } + } else if(t2_ei != NULL){ + t1_ei = getOrMakeEqcInfo(t1, true); + for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) { + t1_ei->d_mem.insert(*itr); + } + if(t1_ei->d_tc.get().isNull() && !t2_ei->d_tc.get().isNull()) { + t1_ei->d_tc.set(t2_ei->d_tc); + } + } + } + void TheorySetsRels::mergeProductEqcs(Node t1, Node t2) { + Trace("rels-std") << "[sets-rels] Merge PRODUCT eqcs t1 = " << t1 << " and t2 = " << t2 << std::endl; EqcInfo* t1_ei = getOrMakeEqcInfo(t1); EqcInfo* t2_ei = getOrMakeEqcInfo(t2); if(t1_ei != NULL && t2_ei != NULL) { @@ -1305,7 +1525,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P sendInferProduct( false, *itr, t1_ei->d_pt.get(), AND(explain(EQUAL(t1_ei->d_pt.get(), t2)), explain(MEMBER(*itr, t2).negate())) ); } } - // Apply transpose rule on (non)members of t1 and t2->tp } else if(!t2_ei->d_pt.get().isNull()) { t1_ei->d_pt.set(t2_ei->d_pt); for(NodeSet::key_iterator itr = t1_ei->d_mem.key_begin(); itr != t1_ei->d_mem.key_end(); itr++) { @@ -1326,7 +1545,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } } else if(t2_ei != NULL){ t1_ei = getOrMakeEqcInfo(t1, true); - for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) { + for(NodeSet::key_iterator itr = t2_ei->d_mem.key_begin(); itr != t2_ei->d_mem.key_end(); itr++) { t1_ei->d_mem.insert(*itr); } for(NodeSet::key_iterator itr = t2_ei->d_not_mem.key_begin(); itr != t2_ei->d_not_mem.key_end(); itr++) { @@ -1339,6 +1558,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } void TheorySetsRels::mergeTransposeEqcs(Node t1, Node t2) { + Trace("rels-std") << "[sets-rels] Merge TRANSPOSE eqcs t1 = " << t1 << " and t2 = " << t2 << std::endl; EqcInfo* t1_ei = getOrMakeEqcInfo(t1); EqcInfo* t2_ei = getOrMakeEqcInfo(t2); if(t1_ei != NULL && t2_ei != NULL) { @@ -1442,9 +1662,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::sendInferProduct( bool polarity, Node t1, Node t2, Node exp ) { Assert(t2.getKind() == kind::PRODUCT); if(polarity && isRel(t1) && isRel(t2)) { + //PRODUCT(x) = PRODUCT(y) => x = y; Assert(t1.getKind() == kind::PRODUCT); Node n = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, EQUAL(t1[0], t2[0]) ); - Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying transpose rule: " + Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying product rule: " << n << std::endl; d_pending_merge.push_back(n); d_lemma.insert(n); @@ -1483,14 +1704,14 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P n1 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(tuple_1, r1).negate() ); n2 = NodeManager::currentNM()->mkNode( kind::IMPLIES, exp, MEMBER(tuple_2, r2).negate() ); } - Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying product rule: " - << n2 << std::endl; - d_pending_merge.push_back(n2); - d_lemma.insert(n2); - Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying product rule: " + Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying product-split rule: " << n1 << std::endl; d_pending_merge.push_back(n1); d_lemma.insert(n1); + Trace("rels-std") << "[sets-rels-lemma] Generate a lemma by applying product-split rule: " + << n2 << std::endl; + d_pending_merge.push_back(n2); + d_lemma.insert(n2); } @@ -1509,6 +1730,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P ei->d_tp = n; } else if(n.getKind() == kind::PRODUCT) { ei->d_pt = n; + } else if(n.getKind() == kind::TCLOSURE) { + ei->d_tc = n; + } else if(n.getKind() == kind::JOIN) { + ei->d_join = n; } return ei; }else{ diff --git a/src/theory/sets/theory_sets_rels.h b/src/theory/sets/theory_sets_rels.h index ff62b67ab..0d24c65b3 100644 --- a/src/theory/sets/theory_sets_rels.h +++ b/src/theory/sets/theory_sets_rels.h @@ -47,6 +47,9 @@ class TheorySetsRels { typedef context::CDChunkList NodeList; typedef context::CDHashSet NodeSet; typedef context::CDHashMap NodeBoolMap; + typedef context::CDHashMap NodeListMap; + typedef context::CDHashMap NodeSetMap; + typedef context::CDHashMap NodeMap; public: TheorySetsRels(context::Context* c, @@ -58,13 +61,15 @@ public: ~TheorySetsRels(); void check(Theory::Effort); void doPendingLemmas(); - context::Context * d_c; private: /** equivalence class info * d_mem tuples that are members of this equivalence class * d_not_mem tuples that are not members of this equivalence class * d_tp is a node of kind TRANSPOSE (if any) in this equivalence class, + * d_pt is a node of kind PRODUCT (if any) in this equivalence class, + * d_join is a node of kind JOIN (if any) in this equivalence class, + * d_tc is a node of kind TCLOSURE (if any) in this equivalence class, */ class EqcInfo { @@ -73,8 +78,13 @@ private: ~EqcInfo(){} NodeSet d_mem; NodeSet d_not_mem; + NodeListMap d_in; + NodeListMap d_out; + NodeMap d_tc_mem_exp; context::CDO< Node > d_tp; context::CDO< Node > d_pt; + context::CDO< Node > d_join; + context::CDO< Node > d_tc; }; /** has eqc info */ @@ -101,6 +111,9 @@ private: NodeList d_infer_exp; NodeSet d_lemma; NodeSet d_shared_terms; + + // tc terms that have been decomposed + NodeSet d_tc_saver; std::hash_set< Node, NodeHashFunction > d_rel_nodes; std::map< Node, std::vector > d_tuple_reps; @@ -123,13 +136,22 @@ public: void eqNotifyPostMerge(Node t1, Node t2); private: - void mergeTransposeEqcs(Node t1, Node t2); - void mergeProductEqcs(Node t1, Node t2); - std::map< Node, EqcInfo* > d_eqc_info; + void doPendingMerge(); + std::map< Node, EqcInfo* > d_eqc_info; EqcInfo* getOrMakeEqcInfo( Node n, bool doMake = false ); + void mergeTransposeEqcs(Node t1, Node t2); + void mergeProductEqcs(Node t1, Node t2); + void mergeTCEqcs(Node t1, Node t2); void sendInferTranspose(bool, Node, Node, Node, bool reverseOnly = false); void sendInferProduct(bool, Node, Node, Node); + void sendInferTC(EqcInfo* tc_ei, Node mem, Node exp); + void sendInferInTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set seen, Node exp); + void sendInferOutTC(EqcInfo* tc_ei, Node fst, Node snd, std::hash_set seen, Node exp); + void addTCMem(EqcInfo* tc_ei, Node mem); + Node findTCMemExp(EqcInfo*, Node); + void mergeTCEqcExp(EqcInfo*, EqcInfo*); + void buildTCAndExp(Node, EqcInfo*); void check(); -- 2.30.2