From b8cce053839961e89ce71d7862f60b5c745258ee Mon Sep 17 00:00:00 2001 From: PaulMeng Date: Tue, 12 Apr 2016 10:02:26 -0500 Subject: [PATCH] fixed explanation for transitive closure inferences --- src/theory/sets/theory_sets_rels.cpp | 184 +++++++++++++++++---------- src/theory/sets/theory_sets_rels.h | 1 + 2 files changed, 120 insertions(+), 65 deletions(-) diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index 5df44d9f8..0e20b9bfa 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -59,7 +59,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P MEM_IT m_it = d_membership_constraints_cache.begin(); while(m_it != d_membership_constraints_cache.end()) { Node rel_rep = m_it->first; - Trace("rels-debug") << "[sets-rels] Processing rel_rep = " << rel_rep << std::endl; // No relational terms found with rel_rep as its representative // But TRANSPOSE(rel_rep) may occur in the context @@ -201,7 +200,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P * ----------------------------------------------------------- * x <= TRANSCLOSURE(x) && (x JOIN x) <= TRANSCLOSURE(x) .... * - * TC(x) = TC(y) => x = y + * TC(x) = TC(y) => x = y ? * */ @@ -237,10 +236,12 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Node tc_r_rep = getRepresentative(tc_term[0]); // build the TC graph for tc_rep if it was not created before - if( d_membership_tc_cache.find(tc_rep) == d_membership_tc_cache.end() ) { + if( d_tc_nodes.find(tc_rep) == d_tc_nodes.end() ) { + Trace("rels-debug") << "[sets-rels] Start building the TC graph!" << std::endl; buildTCGraph(tc_r_rep, tc_rep, tc_term); + d_tc_nodes.insert(tc_rep); } - // insert atom[0] in the tc_graph + // insert atom[0] in the tc_graph if it is not in the graph already TC_IT tc_graph_it = d_membership_tc_cache.find(tc_rep); if(polarity) { if(tc_graph_it != d_membership_tc_cache.end()) { @@ -268,7 +269,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P d_membership_tc_exp_cache[tc_rep] = reason; } } - // check if atom[0] exists in TC graph for conflict + // check if atom[0] already exists in TC graph for conflict } else { if(tc_graph_it != d_membership_tc_cache.end()) { checkTCGraphForConflict(atom, tc_rep, d_trueNode, nthElementOfTuple(atom[0], 0), @@ -284,11 +285,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P if(pair_set_it->second.find(b) != pair_set_it->second.end()) { Node reason = AND(exp, findMemExp(tc_rep, constructPair(tc_rep, a, b))); if(atom[1] != tc_rep) { - reason = AND(exp, EQUAL(atom[1], tc_rep)); + reason = AND(exp, explain(EQUAL(atom[1], tc_rep))); } Trace("rels-debug") << "[sets-rels] found a conflict and send out lemma : " - << NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, atom) << std::endl; - d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, reason, atom)); + << NodeManager::currentNM()->mkNode(kind::IMPLIES, Rewriter::rewrite(reason), atom) << std::endl; + d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, Rewriter::rewrite(reason), atom)); // Trace("rels-debug") << "[sets-rels] found a conflict and send out lemma : " // << AND(reason.negate(), atom) << std::endl; // d_sets_theory.d_out->conflict(AND(reason.negate(), atom)); @@ -319,53 +320,67 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Node atom = polarity ? exp : exp[0]; Node r1_rep = getRepresentative(product_term[0]); Node r2_rep = getRepresentative(product_term[1]); + Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-SPLIT rule on term: " << product_term + << " with explanation: " << exp << std::endl; + std::vector r1_element; + std::vector r2_element; - if(polarity) { - Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-SPLIT rule on term: " << product_term - << " with explanation: " << exp << std::endl; - std::vector r1_element; - std::vector r2_element; - - NodeManager *nm = NodeManager::currentNM(); - Datatype dt = r1_rep.getType().getSetElementType().getDatatype(); - unsigned int i = 0; - unsigned int s1_len = r1_rep.getType().getSetElementType().getTupleLength(); - unsigned int tup_len = product_term.getType().getSetElementType().getTupleLength(); + NodeManager *nm = NodeManager::currentNM(); + Datatype dt = r1_rep.getType().getSetElementType().getDatatype(); + unsigned int i = 0; + unsigned int s1_len = r1_rep.getType().getSetElementType().getTupleLength(); + unsigned int tup_len = product_term.getType().getSetElementType().getTupleLength(); - r1_element.push_back(Node::fromExpr(dt[0].getConstructor())); - for(; i < s1_len; ++i) { - r1_element.push_back(nthElementOfTuple(atom[0], i)); - } + r1_element.push_back(Node::fromExpr(dt[0].getConstructor())); + for(; i < s1_len; ++i) { + r1_element.push_back(nthElementOfTuple(atom[0], i)); + } - dt = r2_rep.getType().getSetElementType().getDatatype(); - r2_element.push_back(Node::fromExpr(dt[0].getConstructor())); - for(; i < tup_len; ++i) { - r2_element.push_back(nthElementOfTuple(atom[0], i)); - } + dt = r2_rep.getType().getSetElementType().getDatatype(); + r2_element.push_back(Node::fromExpr(dt[0].getConstructor())); + for(; i < tup_len; ++i) { + r2_element.push_back(nthElementOfTuple(atom[0], i)); + } - Node fact; - Node reason = exp; - Node t1 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element)); - Node t2 = getRepresentative(nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element)); - - if(!hasMember(r1_rep, t1)) { - fact = MEMBER( t1, r1_rep ); - if(r1_rep != product_term[0]) - reason = Rewriter::rewrite(AND(reason, EQUAL(r1_rep, product_term[0]))); - addToMap(d_membership_db, r1_rep, t1); - addToMap(d_membership_exp_db, r1_rep, reason); - sendInfer(fact, reason, "product-split"); + Node fact_1; + Node fact_2; + Node reason_1 = exp; + Node reason_2 = exp; + Node t1 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r1_element); + Node t1_rep = getRepresentative(t1); + Node t2 = nm->mkNode(kind::APPLY_CONSTRUCTOR, r2_element); + Node t2_rep = getRepresentative(t2); + + fact_1 = MEMBER( t1, r1_rep ); + fact_2 = MEMBER( t2, r2_rep ); + if(r1_rep != product_term[0]) { + reason_1 = AND(reason_1, explain(EQUAL(r1_rep, product_term[0]))); + } + if(t1 != t1_rep) { + reason_1 = Rewriter::rewrite(AND(reason_1, explain(EQUAL(t1, t1_rep)))); + } + if(r2_rep != product_term[1]) { + reason_2 = AND(reason_2, explain(EQUAL(r2_rep, product_term[1]))); + } + if(t2 != t2_rep) { + reason_2 = Rewriter::rewrite(AND(reason_2, explain(EQUAL(t2, t2_rep)))); + } + if(polarity) { + if(!hasMember(r1_rep, t1_rep)) { + addToMap(d_membership_db, r1_rep, t1_rep); + addToMap(d_membership_exp_db, r1_rep, reason_1); + sendInfer(fact_1, reason_1, "product-split"); } - if(!hasMember(r2_rep, t2)) { - fact = MEMBER( t2, r2_rep ); - if(r2_rep != product_term[1]) - reason = Rewriter::rewrite(AND(reason, EQUAL(r2_rep, product_term[1]))); addToMap(d_membership_db, r2_rep, t2); - addToMap(d_membership_exp_db, r2_rep, reason); - sendInfer(fact, reason, "product-split"); + addToMap(d_membership_exp_db, r2_rep, reason_2); + sendInfer(fact_2, reason_2, "product-split"); } + } else { +// sendInfer(fact_1.negate(), reason_1, "product-split"); +// sendInfer(fact_2.negate(), reason_2, "product-split"); + // ONLY need to explicitly compute joins if there are negative literals involving PRODUCT Trace("rels-debug") << "\n[sets-rels] Apply PRODUCT-COMPOSE rule on term: " << product_term << " with explanation: " << exp << std::endl; @@ -528,15 +543,16 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } } - // Todo: need to add equality between two pair's left and right elements as explanation + void TheorySetsRels::inferTC( Node exp, Node tc_rep, std::map< Node, std::hash_set< Node, NodeHashFunction > >& tc_graph, Node start_node, Node cur_node, std::hash_set< Node, NodeHashFunction >& elements, bool first_round ) { Node pair = constructPair(tc_rep, start_node, cur_node); if(safeAddToMap(d_membership_db, tc_rep, pair)) { - addToMap(d_membership_exp_db, tc_rep, exp); - sendLemma( MEMBER(pair, tc_rep), exp, "Transitivity" ); + addToMap(d_membership_exp_cache, tc_rep, Rewriter::rewrite(exp)); + sendLemma( MEMBER(pair, tc_rep), Rewriter::rewrite(exp), "Transitivity" ); } + // check if cur_node has been traversed or not if(!first_round) { std::hash_set< Node, NodeHashFunction >::iterator ele_it = elements.begin(); while(ele_it != elements.end()) { @@ -547,8 +563,10 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } } std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator pair_set_it = tc_graph.begin(); + Node reason = exp; while(pair_set_it != tc_graph.end()) { if(areEqual(pair_set_it->first, cur_node)) { + reason = AND(exp, EQUAL(pair_set_it->first, cur_node)); break; } pair_set_it++; @@ -557,10 +575,9 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P for(std::hash_set< Node, NodeHashFunction >::iterator set_it = pair_set_it->second.begin(); set_it != pair_set_it->second.end(); set_it++) { Node p = constructPair( tc_rep, cur_node, *set_it ); - Node reason = AND( findMemExp(tc_rep, p), exp ); Assert(!reason.isNull()); elements.insert(*set_it); - inferTC( reason, tc_rep, tc_graph, start_node, *set_it, elements, false ); + inferTC( AND( findMemExp(tc_rep, p), reason ), tc_rep, tc_graph, start_node, *set_it, elements, false ); } } } @@ -574,7 +591,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P std::hash_set elements; Node pair = constructPair(tc_rep, pair_set_it->first, *set_it); Node exp = findMemExp(tc_rep, pair); - Trace("rels-debug") << "[sets-rels] pair = " << pair << std::endl; if(d_membership_tc_exp_cache.find(tc_rep) != d_membership_tc_exp_cache.end()) { exp = AND(d_membership_tc_exp_cache[tc_rep], exp); } @@ -753,7 +769,8 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P void TheorySetsRels::doPendingLemmas() { if( !(*d_conflict) && (!d_lemma_cache.empty() || !d_pending_facts.empty())){ for( unsigned i=0; i < d_lemma_cache.size(); i++ ){ - if(holds( d_lemma_cache[i] )) { + Assert(d_lemma_cache[i].getKind() == kind::IMPLIES); + if(holds( d_lemma_cache[i][1] )) { Trace("rels-lemma") << "[sets-rels-lemma-skip] Skip the already held lemma: " << d_lemma_cache[i]<< std::endl; continue; @@ -775,6 +792,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P d_sets_theory.d_out->lemma(NodeManager::currentNM()->mkNode(kind::IMPLIES, child_it->second, child_it->first)); } } + d_tc_nodes.clear(); d_pending_facts.clear(); d_membership_constraints_cache.clear(); d_membership_tc_cache.clear(); @@ -890,7 +908,6 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } bool TheorySetsRels::areEqual( Node a, Node b ){ - Trace("rels-debug") << "[sets-rels] areEqual( a = " << a << ", b = " << b << ")" << std::endl; if(a == b) { return true; } else if( hasTerm( a ) && hasTerm( b ) ){ @@ -936,28 +953,49 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } inline Node TheorySetsRels::getReason(Node tc_rep, Node tc_term, Node tc_r_rep, Node tc_r) { + Trace("rels-reason") << "[sets-rels] getReason(" << tc_rep << ", " << tc_term << ", " << tc_r_rep << ", " << tc_r << std::endl; if(tc_term != tc_rep) { Node reason = explain(EQUAL(tc_term, tc_rep)); if(tc_term[0] != tc_r_rep) { return AND(reason, explain(EQUAL(tc_term[0], tc_r_rep))); } } + Trace("rels-reason") << "[sets-rels] done getReason(" << tc_rep << ", " << tc_term << ", " << tc_r_rep << ", " << tc_r << std::endl; return Node::null(); } - // tuple might be a member of tc_rep; or it might be a member of tc_terms + // tuple might be a member of tc_rep; or it might be a member of rels or tc_terms such that + // tc_terms are transitive closure of rels and are modulo equal to tc_rep Node TheorySetsRels::findMemExp(Node tc_rep, Node tuple) { Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_rep = " << tc_rep << ", tuple = " << tuple << ")" << std::endl; std::vector tc_terms = d_terms_cache.find(tc_rep)->second[kind::TRANSCLOSURE]; Assert(tc_terms.size() > 0); for(unsigned int i = 0; i < tc_terms.size(); i++) { - Node r_rep = getRepresentative(tc_terms[i][0]); - Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << r_rep << ", tuple = " << tuple << ")" << std::endl; - std::map< Node, std::vector< Node > >::iterator tc_r_mems = d_membership_db.find(r_rep); + Node tc_term = tc_terms[i]; + Node tc_r_rep = getRepresentative(tc_term[0]); + + Trace("rels-exp") << "TheorySetsRels::findMemExp ( r_rep = " << tc_r_rep << ", tuple = " << tuple << ")" << std::endl; + std::map< Node, std::vector< Node > >::iterator tc_r_mems = d_membership_db.find(tc_r_rep); if(tc_r_mems != d_membership_db.end()) { for(unsigned int i = 0; i < tc_r_mems->second.size(); i++) { if(areEqual(tc_r_mems->second[i], tuple)) { - return explain(d_membership_exp_db[r_rep][i]); + Node exp = d_trueNode; + if(tc_r_rep != tc_term[0]) { + exp = explain(EQUAL(tc_r_rep, tc_term[0])); + } + if(tc_rep != tc_term) { + exp = AND(exp, explain(EQUAL(tc_rep, tc_term))); + } + if(tc_r_mems->second[i] != tuple) { + if(nthElementOfTuple(tc_r_mems->second[i], 0) != nthElementOfTuple(tuple, 0)) { + exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_r_mems->second[i], 0), nthElementOfTuple(tuple, 0)))); + } + if(nthElementOfTuple(tc_r_mems->second[i], 1) != nthElementOfTuple(tuple, 1)) { + exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_r_mems->second[i], 1), nthElementOfTuple(tuple, 1)))); + } + exp = AND(exp, EQUAL(tc_r_mems->second[i], tuple)); + } + return Rewriter::rewrite(AND(exp, explain(d_membership_exp_db[tc_r_rep][i]))); } } } @@ -966,9 +1004,25 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P std::map< Node, std::vector< Node > >::iterator tc_t_mems = d_membership_db.find(tc_term_rep); Trace("rels-exp") << "TheorySetsRels::findMemExp ( tc_t_rep = " << tc_term_rep << ", tuple = " << tuple << ")" << std::endl; if(tc_t_mems != d_membership_db.end()) { - for(unsigned int i = 0; i < tc_t_mems->second.size(); i++) { - if(areEqual(tc_t_mems->second[i], tuple)) { - return explain(d_membership_exp_db[tc_term_rep][i]); + for(unsigned int j = 0; j < tc_t_mems->second.size(); j++) { + if(areEqual(tc_t_mems->second[j], tuple)) { + Node exp = d_trueNode; + if(tc_rep != tc_terms[i]) { + exp = AND(exp, explain(EQUAL(tc_rep, tc_terms[i]))); + } + if(tc_term_rep != tc_terms[i]) { + exp = AND(exp, explain(EQUAL(tc_term_rep, tc_terms[i]))); + } + if(tc_t_mems->second[j] != tuple) { + if(nthElementOfTuple(tc_t_mems->second[j], 0) != nthElementOfTuple(tuple, 0)) { + exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_t_mems->second[j], 0), nthElementOfTuple(tuple, 0)))); + } + if(nthElementOfTuple(tc_t_mems->second[j], 1) != nthElementOfTuple(tuple, 1)) { + exp = AND(exp, explain(EQUAL(nthElementOfTuple(tc_t_mems->second[j], 1), nthElementOfTuple(tuple, 1)))); + } + exp = AND(exp, EQUAL(tc_t_mems->second[j], tuple)); + } + return Rewriter::rewrite(AND(exp, explain(d_membership_exp_db[tc_term_rep][j]))); } } } @@ -1155,7 +1209,7 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P Node TheorySetsRels::explain(Node literal) { - Trace("rels-debug") << "[sets-rels] TheorySetsRels::explain(" << literal << ")"<< std::endl; + Trace("rels-exp") << "[sets-rels] TheorySetsRels::explain(" << literal << ")"<< std::endl; bool polarity = literal.getKind() != kind::NOT; TNode atom = polarity ? literal : literal[0]; @@ -1169,11 +1223,11 @@ typedef std::map< Node, std::hash_set< Node, NodeHashFunction > >::iterator TC_P } d_eqEngine->explainPredicate(atom, polarity, assumptions); } else { - Trace("rels-debug") << "unhandled: " << literal << "; (" << atom << ", " + Trace("rels-exp") << "unhandled: " << literal << "; (" << atom << ", " << polarity << "); kind" << atom.getKind() << std::endl; Unhandled(); } - Trace("rels-debug") << "[sets-rels] ****** done with TheorySetsRels::explain(" << literal << ")"<< std::endl; + Trace("rels-exp") << "[sets-rels] ****** done with TheorySetsRels::explain(" << literal << ")"<< std::endl; return mkAnd(assumptions); } diff --git a/src/theory/sets/theory_sets_rels.h b/src/theory/sets/theory_sets_rels.h index 8fc107a82..0876cc5b3 100644 --- a/src/theory/sets/theory_sets_rels.h +++ b/src/theory/sets/theory_sets_rels.h @@ -100,6 +100,7 @@ private: NodeSet d_lemma; NodeSet d_shared_terms; + std::hash_set< Node, NodeHashFunction > d_tc_nodes; std::map< Node, std::vector > d_tuple_reps; std::map< Node, TupleTrie > d_membership_trie; std::hash_set< Node, NodeHashFunction > d_symbolic_tuples; -- 2.30.2