From: PaulMeng Date: Sun, 28 Feb 2016 22:22:43 +0000 (-0600) Subject: implemented a basic solving procedure for finite relations (only for X-Git-Tag: cvc5-1.0.0~6079 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=9f5a29e3ec43821c37f8557f9215cb52a80c1b0b;p=cvc5.git implemented a basic solving procedure for finite relations (only for join, product, transpose operators) --- diff --git a/src/theory/sets/theory_sets.h b/src/theory/sets/theory_sets.h index bc39fcbbd..9e08b597d 100644 --- a/src/theory/sets/theory_sets.h +++ b/src/theory/sets/theory_sets.h @@ -33,6 +33,7 @@ class TheorySets : public Theory { private: friend class TheorySetsPrivate; friend class TheorySetsScrutinize; + friend class TheorySetsRels; TheorySetsPrivate* d_internal; public: diff --git a/src/theory/sets/theory_sets_private.cpp b/src/theory/sets/theory_sets_private.cpp index 5e328b4fd..4cb82b66d 100644 --- a/src/theory/sets/theory_sets_private.cpp +++ b/src/theory/sets/theory_sets_private.cpp @@ -94,9 +94,11 @@ void TheorySetsPrivate::check(Theory::Effort level) { if(d_conflict) { return; } Debug("sets") << "[sets] is complete = " << isComplete() << std::endl; - d_rels->check(level); } - + d_rels->check(level); +// if( level == Theory::EFFORT_FULL ) { +// d_rels->doPendingLemmas(); +// } if( (level == Theory::EFFORT_FULL || options::setsEagerLemmas() ) && !isComplete()) { d_external.d_out->lemma(getLemma()); return; @@ -1111,7 +1113,7 @@ TheorySetsPrivate::TheorySetsPrivate(TheorySets& external, d_rels(NULL) { d_termInfoManager = new TermInfoManager(*this, c, &d_equalityEngine); - d_rels = new TheorySetsRels(c, u, &d_equalityEngine, &d_conflict); + d_rels = new TheorySetsRels(c, u, &d_equalityEngine, &d_conflict, external); d_equalityEngine.addFunctionKind(kind::UNION); d_equalityEngine.addFunctionKind(kind::INTERSECTION); diff --git a/src/theory/sets/theory_sets_private.h b/src/theory/sets/theory_sets_private.h index 8cbc17ae3..ad04ff273 100644 --- a/src/theory/sets/theory_sets_private.h +++ b/src/theory/sets/theory_sets_private.h @@ -71,7 +71,6 @@ public: private: TheorySets& d_external; - TheorySetsRels* d_rels; class Statistics { public: @@ -200,6 +199,7 @@ private: // more debugging stuff friend class TheorySetsScrutinize; TheorySetsScrutinize* d_scrutinize; + TheorySetsRels* d_rels; void dumpAssertionsHumanified() const; /** do some formatting to make them more readable */ };/* class TheorySetsPrivate */ diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index fcab5b5ca..de70e6a52 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -17,6 +17,9 @@ #include "theory/sets/theory_sets_rels.h" #include "expr/datatype.h" +#include "theory/sets/expr_patterns.h" +#include "theory/sets/theory_sets_private.h" +#include "theory/sets/theory_sets.h" //#include "options/sets_options.h" //#include "smt/smt_statistics_registry.h" //#include "theory/sets/expr_patterns.h" // ONLY included here @@ -27,106 +30,130 @@ using namespace std; -using namespace CVC4::kind; +using namespace CVC4::expr::pattern; namespace CVC4 { namespace theory { namespace sets { - TheorySetsRels::TheorySetsRels(context::Context* c, - context::UserContext* u, - eq::EqualityEngine* eq, - context::CDO* conflict): - d_trueNode(NodeManager::currentNM()->mkConst(true)), - d_falseNode(NodeManager::currentNM()->mkConst(false)), - d_eqEngine(eq), - d_conflict(conflict), - d_relsSaver(c) - { - d_eqEngine->addFunctionKind(kind::PRODUCT); - d_eqEngine->addFunctionKind(kind::JOIN); - d_eqEngine->addFunctionKind(kind::TRANSPOSE); - d_eqEngine->addFunctionKind(kind::TRANSCLOSURE); +typedef std::map > >::iterator term_it; +typedef std::map >::iterator mem_it; + + void TheorySetsRels::check(Theory::Effort level) { + Trace("rels-debug") << "[sets-rels] Start the relational solver..." << std::endl; + collectRelationalInfo(); + check(); +// doPendingFacts(); + doPendingLemmas(); + Assert(d_lemma_cache.empty()); + Assert(d_pending_facts.empty()); + Trace("rels-debug") << "[sets-rels] Done with the relational solver..." << std::endl; } - TheorySetsRels::~TheorySetsRels() {} + void TheorySetsRels::check() { + mem_it m_it = d_membership_cache.begin(); + while(m_it != d_membership_cache.end()) { + std::vector tuples = m_it->second; + Node rel_rep = m_it->first; + // No relational terms found with rel_rep as its representative + if(d_terms_cache.find(rel_rep) == d_terms_cache.end()) { + m_it++; + continue; + } + for(unsigned int i = 0; i < tuples.size(); i++) { + Node tup_rep = tuples[i]; + Node exp = d_membership_exp_cache[rel_rep][i]; + std::map > kind_terms = d_terms_cache[rel_rep]; + + if(kind_terms.find(kind::TRANSPOSE) != kind_terms.end()) { + std::vector tp_terms = kind_terms.at(kind::TRANSPOSE); + // exp is a membership term and tp_terms contains all + // transposed terms that are equal to the right hand side of exp + for(unsigned int j = 0; j < tp_terms.size(); j++) { + applyTransposeRule(exp, rel_rep, tp_terms[j]); + } + } + if(kind_terms.find(kind::JOIN) != kind_terms.end()) { + std::vector conj; + std::vector join_terms = kind_terms.at(kind::JOIN); + // exp is a membership term and join_terms contains all + // joined terms that are in the same equivalence class with the right hand side of exp + for(unsigned int j = 0; j < join_terms.size(); j++) { + applyJoinRule(exp, rel_rep, join_terms[j]); + } + } + if(kind_terms.find(kind::PRODUCT) != kind_terms.end()) { + std::vector product_terms = kind_terms.at(kind::PRODUCT); + for(unsigned int j = 0; j < product_terms.size(); j++) { + applyProductRule(exp, rel_rep, product_terms[j]); + } + } + } + m_it++; + } + } - void TheorySetsRels::check(Theory::Effort level) { - Debug("rels-eqc") << "\nStart iterating equivalence classes......\n" << std::endl; - if (!d_eqEngine->consistent()) - return; + void TheorySetsRels::collectRelationalInfo() { + Trace("rels-debug") << "[sets-rels] Start collecting relational terms..." << std::endl; eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( d_eqEngine ); - while( !eqcs_i.isFinished() ){ - TNode r = (*eqcs_i); + Node r = (*eqcs_i); eq::EqClassIterator eqc_i = eq::EqClassIterator( r, d_eqEngine ); - + Trace("rels-ee") << "[sets-rels] term representative: " << r << std::endl; while( !eqc_i.isFinished() ){ - TNode n = (*eqc_i); - - // only consider membership constraints that involving relatioinal operators - if((d_eqEngine->getRepresentative(r) == d_eqEngine->getRepresentative(d_trueNode) - || d_eqEngine->getRepresentative(r) == d_eqEngine->getRepresentative(d_falseNode)) - && !d_relsSaver.contains(n)) { - - // case: [NOT] (b, a) IS_IN (TRANSPOSE X) - // => [NOT] (a, b) IS_IN X - if(n.getKind() == kind::MEMBER) { - d_relsSaver.insert(n); - if(kind::TRANSPOSE == n[1].getKind()) { - Node reversedTuple = reverseTuple(n[0]); - Node fact = NodeManager::currentNM()->mkNode(kind::MEMBER, reversedTuple, n[1][0]); + Node n = (*eqc_i); + Trace("rels-ee") << " term : " << n << std::endl; + if(getRepresentative(r) == getRepresentative(d_trueNode) || + getRepresentative(r) == getRepresentative(d_falseNode)) { + // collect membership info + if(n.getKind() == kind::MEMBER && n[0].getType().isTuple()) { + Node tup_rep = getRepresentative(n[0]); + Node rel_rep = getRepresentative(n[1]); + // No rel_rep is found + if(d_membership_cache.find(rel_rep) == d_membership_cache.end()) { + std::vector tups; + tups.push_back(tup_rep); + d_membership_cache[rel_rep] = tups; Node exp = n; - if(d_eqEngine->getRepresentative(r) == d_eqEngine->getRepresentative(d_falseNode)) { - fact = fact.negate(); + if(getRepresentative(r) == getRepresentative(d_falseNode)) exp = n.negate(); + tups.clear(); + tups.push_back(exp); + d_membership_exp_cache[rel_rep] = tups; + } else if(std::find(d_membership_cache.at(rel_rep).begin(), + d_membership_cache.at(rel_rep).end(), tup_rep) + == d_membership_cache.at(rel_rep).end()) { + d_membership_cache[rel_rep].push_back(tup_rep); + Node exp = n; + if(getRepresentative(r) == getRepresentative(d_falseNode)) + exp = n.negate(); + d_membership_exp_cache.at(rel_rep).push_back(exp); + } + } + // collect term info + } else if(r.getType().isSet() && r.getType().getSetElementType().isTuple()) { + if(n.getKind() == kind::TRANSPOSE || + n.getKind() == kind::JOIN || + n.getKind() == kind::PRODUCT || + n.getKind() == kind::TRANSCLOSURE) { + std::map > rel_terms; + std::vector terms; + // No r record is found + if(d_terms_cache.find(r) == d_terms_cache.end()) { + terms.push_back(n); + rel_terms[n.getKind()] = terms; + d_terms_cache[r] = rel_terms; + } else { + rel_terms = d_terms_cache[r]; + // No n's kind record is found + if(rel_terms.find(n.getKind()) == rel_terms.end()) { + terms.push_back(n); + rel_terms[n.getKind()] = terms; + } else { + rel_terms.at(n.getKind()).push_back(n); } - d_pending_facts[fact] = exp; - } else if(kind::JOIN == n[1].getKind()) { - TNode r1 = n[1][0]; - TNode r2 = n[1][1]; - // Need to do this efficiently... Join relations after collecting all of them - // So that we would just need to iterate over EE once - joinRelations(r1, r2, n[1].getType().getSetElementType()); - - // case: (a, b) IS_IN (X JOIN Y) - // => (a, z) IS_IN X && (z, b) IS_IN Y - if(d_eqEngine->getRepresentative(r) == d_eqEngine->getRepresentative(d_trueNode)) { - Debug("rels-join") << "Join rules (a, b) IS_IN (X JOIN Y) => ((a, z) IS_IN X && (z, b) IS_IN Y)"<< std::endl; - Assert((r1.getType().getSetElementType()).isDatatype()); - Assert((r2.getType().getSetElementType()).isDatatype()); - - unsigned int i = 0; - std::vector r1_tuple; - std::vector r2_tuple; - Node::iterator child_it = n[0].begin(); - unsigned int s1_len = r1.getType().getSetElementType().getTupleLength(); - Node shared_x = NodeManager::currentNM()->mkSkolem("sde_", r2.getType().getSetElementType().getTupleTypes()[0]); - Datatype dt = r1.getType().getSetElementType().getDatatype(); - - r1_tuple.push_back(Node::fromExpr(dt[0].getConstructor())); - for(; i < s1_len-1; ++child_it) { - r1_tuple.push_back(*child_it); - ++i; - } - r1_tuple.push_back(shared_x); - dt = r2.getType().getSetElementType().getDatatype(); - r2_tuple.push_back(Node::fromExpr(dt[0].getConstructor())); - r2_tuple.push_back(shared_x); - for(; child_it != n[0].end(); ++child_it) { - r2_tuple.push_back(*child_it); - } - Node t1 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_tuple); - Node t2 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r2_tuple); - Node f1 = NodeManager::currentNM()->mkNode(kind::MEMBER, t1, r1); - Node f2 = NodeManager::currentNM()->mkNode(kind::MEMBER, t2, r2); - d_pending_facts[f1] = n; - d_pending_facts[f2] = n; - } - }else if(kind::PRODUCT == n[1].getKind()) { - } } } @@ -134,135 +161,412 @@ namespace sets { } ++eqcs_i; } - doPendingFacts(); + Trace("rels-debug") << "[sets-rels] Done with collecting relational terms!" << std::endl; } - // Join all explicitly specified tuples in r1, r2 - // e.g. If (a, b) in X and (b, c) in Y, (a, c) in (X JOIN Y) - void TheorySetsRels::joinRelations(TNode r1, TNode r2, TypeNode tn) { - if (!d_eqEngine->consistent()) - return; - Debug("rels-join") << "start joining tuples in " - << r1 << " and " << r2 << std::endl; - - std::vector r1_elements; - std::vector r2_elements; - eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( d_eqEngine ); + void TheorySetsRels::doPendingFacts() { + std::map::iterator map_it = d_pending_facts.begin(); + while( !(*d_conflict) && map_it != d_pending_facts.end()) { - // collect all tuples that are in r1, r2 - while( !eqcs_i.isFinished() ){ - TNode r = (*eqcs_i); - eq::EqClassIterator eqc_i = eq::EqClassIterator( r, d_eqEngine ); - while( !eqc_i.isFinished() ){ - TNode n = (*eqc_i); - if(d_eqEngine->getRepresentative(r) == d_eqEngine->getRepresentative(d_trueNode) - && n.getKind() == kind::MEMBER && n[0].getType().isTuple()) { - if(n[1] == r1) { - Debug("rels-join") << "r1 tuple: " << n[0] << std::endl; - r1_elements.push_back(n[0]); - } else if (n[1] == r2) { - Debug("rels-join") << "r2 tuple: " << n[0] << std::endl; - r2_elements.push_back(n[0]); - } + Node fact = map_it->first; + Node exp = d_pending_facts[ fact ]; + if(fact.getKind() == kind::AND) { + for(size_t j=0; j (a, z) IS_IN X && (z, b) IS_IN Y + if(polarity) { + Debug("rels-join") << "[sets-rels] Join rules (a, b) IS_IN (X JOIN Y) => " + "((a, z) IS_IN X && (z, b) IS_IN Y)"<< std::endl; + Assert((r1.getType().getSetElementType()).isDatatype()); + Assert((r2.getType().getSetElementType()).isDatatype()); + + unsigned int i = 0; + std::vector r1_tuple; + std::vector r2_tuple; + Node::iterator child_it = atom[0].begin(); + unsigned int s1_len = r1.getType().getSetElementType().getTupleLength(); + Node shared_x = NodeManager::currentNM()->mkSkolem("sde_", r2.getType().getSetElementType().getTupleTypes()[0]); + Datatype dt = r1.getType().getSetElementType().getDatatype(); + + r1_tuple.push_back(Node::fromExpr(dt[0].getConstructor())); + for(; i < s1_len-1; ++child_it) { + r1_tuple.push_back(*child_it); + ++i; + } + r1_tuple.push_back(shared_x); + dt = r2.getType().getSetElementType().getDatatype(); + r2_tuple.push_back(Node::fromExpr(dt[0].getConstructor())); + r2_tuple.push_back(shared_x); + for(; child_it != atom[0].end(); ++child_it) { + r2_tuple.push_back(*child_it); + } + Node t1 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r1_tuple); + Node t2 = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, r2_tuple); + Node f1 = NodeManager::currentNM()->mkNode(kind::MEMBER, t1, r1); + Node f2 = NodeManager::currentNM()->mkNode(kind::MEMBER, t2, r2); + Node reason = exp; + if(atom[1] != join_term) + reason = AND(reason, EQUAL(atom[1], join_term)); + sendInfer(f1, reason, "join-split"); + sendInfer(f2, reason, "join-split"); + } else { + // ONLY need to explicitly compute joins if there are negative literals involving JOIN + computeJoinOrProductRelations(join_term); + } + } + void TheorySetsRels::applyTransposeRule(Node exp, Node rel_rep, Node tp_term) { + Trace("rels-debug") << "\n[sets-rels] Apply transpose rule on term: " << tp_term + << " with explaination: " << exp << std::endl; + bool polarity = exp.getKind() != kind::NOT; + Node atom = polarity ? exp : exp[0]; + Node reversedTuple = reverseTuple(atom[0]); + Node reason = exp; + + if(atom[1] != tp_term) + reason = AND(reason, EQUAL(rel_rep, tp_term)); + Node fact = MEMBER(reversedTuple, tp_term[0]); + + // when the term is nested like (not tup is_in tp(x join/product y)), + // we need to compute what is inside x join/product y + if(!polarity) { + if(d_terms_cache[getRepresentative(fact[1])].find(kind::JOIN) + != d_terms_cache[getRepresentative(fact[1])].end()) { + computeJoinOrProductRelations(fact[1]); + } + if(d_terms_cache[getRepresentative(fact[1])].find(kind::PRODUCT) + != d_terms_cache[getRepresentative(fact[1])].end()) { + computeJoinOrProductRelations(fact[1]); + } + fact = fact.negate(); + } + sendInfer(fact, exp, "transpose-rule"); } - void TheorySetsRels::joinTuples(TNode r1, TNode r2, std::vector& r1_elements, std::vector& r2_elements, TypeNode tn) { + void TheorySetsRels::computeJoinOrProductRelations(Node n) { + switch(n[0].getKind()) { + case kind::JOIN: + computeJoinOrProductRelations(n[0]); + break; + case kind::TRANSPOSE: + computeTransposeRelations(n[0]); + break; + case kind::PRODUCT: + computeJoinOrProductRelations(n[0]); + break; + default: + break; + } + + switch(n[1].getKind()) { + case kind::JOIN: + computeJoinOrProductRelations(n[1]); + break; + case kind::TRANSPOSE: + computeTransposeRelations(n[1]); + break; + case kind::PRODUCT: + computeJoinOrProductRelations(n[1]); + break; + default: + break; + } + + if(d_membership_cache.find(getRepresentative(n[0])) == d_membership_cache.end() || + d_membership_cache.find(getRepresentative(n[1])) == d_membership_cache.end()) + return; + composeRelations(n); + } + + void TheorySetsRels::computeTransposeRelations(Node n) { + switch(n[0].getKind()) { + case kind::JOIN: + computeJoinOrProductRelations(n[0]); + break; + case kind::TRANSPOSE: + computeTransposeRelations(n[0]); + break; + case kind::PRODUCT: + computeJoinOrProductRelations(n[0]); + break; + default: + break; + } + + if(d_membership_cache.find(getRepresentative(n[0])) == d_membership_cache.end()) + return; + std::vector rev_tuples; + std::vector rev_exps; + Node n_rep = getRepresentative(n); + Node n0_rep = getRepresentative(n[0]); + + if(d_membership_cache.find(n_rep) != d_membership_cache.end()) { + rev_tuples = d_membership_cache[n_rep]; + rev_exps = d_membership_exp_cache[n_rep]; + } + std::vector tuples = d_membership_cache[n0_rep]; + std::vector exps = d_membership_exp_cache[n0_rep]; + for(unsigned int i = 0; i < tuples.size(); i++) { + // Todo: Need to consider duplicates + Node reason = exps[i]; + Node rev_tup = reverseTuple(tuples[i]); + if(exps[i][1] != n0_rep) + reason = AND(reason, EQUAL(exps[i][1], n0_rep)); + rev_tuples.push_back(rev_tup); + rev_exps.push_back(Rewriter::rewrite(reason)); + sendInfer(MEMBER(rev_tup, n_rep), Rewriter::rewrite(reason), "transpose-rule"); +// if(std::find(rev_tuples.begin(), rev_tuples.end(), reverseTuple(tuples[i])) == rev_tuples.end()) { +// +// } + } + d_membership_cache[n_rep] = rev_tuples; + d_membership_exp_cache[n_rep] = rev_exps; + } + + // Join all explicitly specified tuples in r1, r2 + // e.g. If (a, b) in X and (b, c) in Y, (a, c) in (X JOIN Y) + void TheorySetsRels::composeRelations(Node n) { + Node r1 = n[0]; + Node r2 = n[1]; + Node r1_rep = getRepresentative(r1); + Node r2_rep = getRepresentative(r2); + Trace("rels-debug") << "[sets-rels] start joining tuples in " + << r1 << " and " << r2 + << "\n r1_rep: " << r1_rep + << "\n r2_rep: " << r2_rep << std::endl; + + if(d_membership_cache.find(r1_rep) == d_membership_cache.end() || + d_membership_cache.find(r2_rep) == d_membership_cache.end()) + return; + + TypeNode tn = n.getType().getSetElementType(); Datatype dt = tn.getDatatype(); + std::vector new_tups; + std::vector new_exps; + std::vector r1_elements = d_membership_cache[r1_rep]; + std::vector r2_elements = d_membership_cache[r2_rep]; + std::vector r1_exps = d_membership_exp_cache[r1_rep]; + std::vector r2_exps = d_membership_exp_cache[r2_rep]; + Node new_rel = n.getKind() == kind::JOIN ? NodeManager::currentNM()->mkNode(kind::JOIN, r1_rep, r2_rep) + : NodeManager::currentNM()->mkNode(kind::PRODUCT, r1_rep, r2_rep); unsigned int t1_len = r1_elements.front().getType().getTupleLength(); unsigned int t2_len = r2_elements.front().getType().getTupleLength(); for(unsigned int i = 0; i < r1_elements.size(); i++) { for(unsigned int j = 0; j < r2_elements.size(); j++) { - if(r1_elements[i][t1_len-1] == r2_elements[j][0]) { - std::vector joinedTuple; - joinedTuple.push_back(Node::fromExpr(dt[0].getConstructor())); - for(unsigned int k = 0; k < t1_len - 1; ++k) { + std::vector joinedTuple; + joinedTuple.push_back(Node::fromExpr(dt[0].getConstructor())); + Debug("rels-debug") << "areEqual(r1_elements[i][t1_len-1], r2_elements[j][0]):\n" + << " r1_elements[i][t1_len-1] = " << r1_elements[i][t1_len-1] + << " r2_elements[j][0]) = " << r2_elements[j][0] + << " are equal? " << areEqual(r1_elements[i][t1_len-1], r2_elements[j][0]) << std::endl; + if((areEqual(r1_elements[i][t1_len-1], r2_elements[j][0]) && n.getKind() == kind::JOIN) || + n.getKind() == kind::PRODUCT) { + unsigned int k = 0; + unsigned int l = 1; + for(; k < t1_len - 1; ++k) { joinedTuple.push_back(r1_elements[i][k]); } - for(unsigned int l = 1; l < t2_len; ++l) { + if(kind::PRODUCT == n.getKind()) { + joinedTuple.push_back(r1_elements[i][k]); + joinedTuple.push_back(r1_elements[j][0]); + } + for(; l < t2_len; ++l) { joinedTuple.push_back(r2_elements[j][l]); } Node fact = NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, joinedTuple); - fact = NodeManager::currentNM()->mkNode(kind::MEMBER, fact, NodeManager::currentNM()->mkNode(kind::JOIN, r1, r2)); - Node reason = NodeManager::currentNM()->mkNode(kind::AND, - NodeManager::currentNM()->mkNode(kind::MEMBER, r1_elements[i], r1), - NodeManager::currentNM()->mkNode(kind::MEMBER, r2_elements[j], r2)); - Debug("rels-join") << "join tuples: " << r1_elements[i] + new_tups.push_back(fact); + fact = MEMBER(fact, new_rel); + std::vector reasons; + reasons.push_back(r1_exps[i]); + reasons.push_back(r2_exps[j]); + + //Todo: think about how to deal with shared terms(?) + if(n.getKind() == kind::JOIN) + reasons.push_back(EQUAL(r1_elements[i][t1_len-1], r2_elements[j][0])); + + if(r1 != r1_rep) { + reasons.push_back(EQUAL(r1, r1_rep)); + } + if(r2 != r2_rep) { + reasons.push_back(EQUAL(r2, r2_rep)); + } + Node reason = theory::Rewriter::rewrite(NodeManager::currentNM()->mkNode(kind::AND, reasons)); + new_exps.push_back(reason); + Trace("rels-debug") << "[sets-rels] compose tuples: " << r1_elements[i] << " and " << r2_elements[j] - << "\nnew fact: " << fact - << "\nreason: " << reason<< std::endl; - d_pending_facts[fact] = reason; + << "\n new fact: " << fact + << "\n reason: " << reason<< std::endl; + if(kind::JOIN == n.getKind()) + sendInfer(fact, reason, "join-compose"); + else if(kind::PRODUCT == n.getKind()) + sendInfer(fact, reason, "product-compose"); } } } - } - - - void TheorySetsRels::sendLemma(TNode fact, TNode reason, bool polarity) { + Node new_rel_rep = getRepresentative( new_rel ); + if(d_membership_cache.find( new_rel_rep ) != d_membership_cache.end()) { + std::vector tups = d_membership_cache[new_rel_rep]; + std::vector exps = d_membership_exp_cache[new_rel_rep]; + // Todo: Need to take care of duplicate tuples + tups.insert( tups.end(), new_tups.begin(), new_tups.end() ); + exps.insert( exps.end(), new_exps.begin(), new_exps.end() ); + } else { + d_membership_cache[new_rel_rep] = new_tups; + d_membership_exp_cache[new_rel_rep] = new_exps; + } + Trace("rels-debug") << "[sets-rels] Done with joining tuples !" << std::endl; } - void TheorySetsRels::doPendingFacts() { - std::map::iterator map_it = d_pending_facts.begin(); - while( !(*d_conflict) && map_it != d_pending_facts.end()) { - - Node fact = map_it->first; - Node exp = d_pending_facts[ fact ]; - Debug("rels") << "sending out pending fact: " << fact - << " reason: " << exp - << std::endl; - if(fact.getKind() == kind::AND) { - for(size_t j=0; jlemma( d_lemma_cache[i] ); + } + for( std::map::iterator child_it = d_pending_facts.begin(); + child_it != d_pending_facts.end(); child_it++ ) { + Trace("rels-debug") << "[sets-rels] Process pending fact as lemma : " << child_it->first << std::endl; + d_sets.d_out->lemma(child_it->first); } - map_it++; } d_pending_facts.clear(); + d_lemma_cache.clear(); } + void TheorySetsRels::sendSplit(Node a, Node b, const char * c) { + Node eq = a.eqNode( b ); + Node neq = NOT( eq ); + Node lemma_or = OR( eq, neq ); + Trace("rels-lemma") << "[sets-rels] Lemma " << c << " SPLIT : " << lemma_or << std::endl; + d_lemma_cache.push_back( lemma_or ); + } + void TheorySetsRels::sendLemma(Node fact, Node reason, bool polarity) { - Node TheorySetsRels::reverseTuple(TNode tuple) { - Assert(tuple.getType().isTuple()); + } + void TheorySetsRels::sendInfer( Node fact, Node exp, const char * c ) { + Trace("rels-lemma") << "[sets-rels] Infer " << fact << " from " << exp << " by " << c << std::endl; + d_pending_facts[fact] = exp; + d_infer.push_back( fact ); + d_infer_exp.push_back( exp ); + } + + Node TheorySetsRels::reverseTuple( Node tuple ) { + Assert( tuple.getType().isTuple() ); std::vector elements; std::vector tuple_types = tuple.getType().getTupleTypes(); - std::reverse(tuple_types.begin(), tuple_types.end()); - TypeNode tn = NodeManager::currentNM()->mkTupleType(tuple_types); + std::reverse( tuple_types.begin(), tuple_types.end() ); + TypeNode tn = NodeManager::currentNM()->mkTupleType( tuple_types ); Datatype dt = tn.getDatatype(); - elements.push_back(Node::fromExpr(dt[0].getConstructor())); + elements.push_back( Node::fromExpr(dt[0].getConstructor() ) ); for(Node::iterator child_it = tuple.end()-1; - child_it != tuple.begin()-1; --child_it) { - elements.push_back(*child_it); + child_it != tuple.begin()-1; --child_it) { + elements.push_back( *child_it ); + } + return NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, elements ); + } + + void TheorySetsRels::assertMembership( Node fact, Node reason, bool polarity ) { + d_eqEngine->assertPredicate( fact, polarity, reason ); + } + + Node TheorySetsRels::getRepresentative( Node t ) { + if( d_eqEngine->hasTerm( t ) ){ + return d_eqEngine->getRepresentative( t ); + }else{ + return t; } - return NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, elements); } - void TheorySetsRels::assertMembership(TNode fact, TNode reason, bool polarity) { - Debug("rels") << "fact: " << fact - << "\npolarity : " << polarity - << "\nreason: " << reason << std::endl; - d_eqEngine->assertPredicate(fact, polarity, reason); + bool TheorySetsRels::hasTerm( Node a ){ + return d_eqEngine->hasTerm( a ); + } + + bool TheorySetsRels::areEqual( Node a, Node b ){ + if( hasTerm( a ) && hasTerm( b ) ){ +// Trace("rels-debug") << "has a and b " << a << " " << b << " are equal? "<< d_eqEngine->areEqual( a, b ) << std::endl; + return d_eqEngine->areEqual( a, b ); + }else if( a.isConst() && b.isConst() ){ + return a == b; + }else { +// Trace("rels-debug") << "to split a and b " << a << " " << b << std::endl; + addSharedTerm(a); + addSharedTerm(b); + sendSplit(a, b, "tuple-element-equality"); + return false; + } + } + + void TheorySetsRels::addSharedTerm(TNode n) { + Trace("rels-debug") << "[sets-rels] Add a shared term: " << n << std::endl; + d_sets.addSharedTerm(n); + d_eqEngine->addTriggerTerm(n, THEORY_SETS); + } + + bool TheorySetsRels::exists( std::vector& v, Node n ){ + return std::find(v.begin(), v.end(), n) != v.end(); } + + TheorySetsRels::TheorySetsRels(context::Context* c, + context::UserContext* u, + eq::EqualityEngine* eq, + context::CDO* conflict, + TheorySets& d_set): + d_sets(d_set), + d_trueNode(NodeManager::currentNM()->mkConst(true)), + d_falseNode(NodeManager::currentNM()->mkConst(false)), + d_infer(c), + d_infer_exp(c), + d_eqEngine(eq), + d_conflict(conflict) + { + d_eqEngine->addFunctionKind(kind::PRODUCT); + d_eqEngine->addFunctionKind(kind::JOIN); + d_eqEngine->addFunctionKind(kind::TRANSPOSE); + d_eqEngine->addFunctionKind(kind::TRANSCLOSURE); + } + + TheorySetsRels::~TheorySetsRels() {} + + } } } diff --git a/src/theory/sets/theory_sets_rels.h b/src/theory/sets/theory_sets_rels.h index 537fc2d43..4eb30ab12 100644 --- a/src/theory/sets/theory_sets_rels.h +++ b/src/theory/sets/theory_sets_rels.h @@ -20,53 +20,77 @@ #include "theory/theory.h" #include "theory/uf/equality_engine.h" #include "context/cdhashset.h" +#include "context/cdchunk_list.h" namespace CVC4 { namespace theory { namespace sets { +class TheorySets; + class TheorySetsRels { + typedef context::CDChunkList NodeList; + public: TheorySetsRels(context::Context* c, context::UserContext* u, eq::EqualityEngine*, - context::CDO* ); + context::CDO*, + TheorySets&); ~TheorySetsRels(); - void check(Theory::Effort); + void doPendingLemmas(); + private: + TheorySets& d_sets; + /** True and false constant nodes */ Node d_trueNode; Node d_falseNode; // Facts and lemmas to be sent to EE - std::map< Node, Node> d_pending_facts; + std::map< Node, Node > d_pending_facts; std::vector< Node > d_lemma_cache; - // Relation pairs to be joined -// std::map d_rel_pairs; -// std::hash_set d_rels; + /** inferences: maintained to ensure ref count for internally introduced nodes */ + NodeList d_infer; + NodeList d_infer_exp; + + std::map< Node, std::vector > d_membership_cache; + std::map< Node, std::vector > d_membership_exp_cache; + std::map< Node, std::map > > d_terms_cache; eq::EqualityEngine *d_eqEngine; context::CDO *d_conflict; - // save all the relational terms seen so far - context::CDHashSet d_relsSaver; - - void assertMembership(TNode fact, TNode reason, bool polarity); - - void joinRelations(TNode, TNode, TypeNode); - void joinTuples(TNode, TNode, std::vector&, std::vector&, TypeNode tn); - - Node reverseTuple(TNode); - - void sendLemma(TNode fact, TNode reason, bool polarity); - void doPendingLemmas(); + void check(); + void collectRelationalInfo(); + void assertMembership( Node fact, Node reason, bool polarity ); + void composeProductRelations( Node ); + void composeJoinRelations( Node ); + void composeRelations( Node ); + void applyTransposeRule( Node, Node, Node ); + void applyJoinRule( Node, Node, Node ); + void applyProductRule( Node, Node, Node ); + void computeJoinOrProductRelations( Node ); + void computeTransposeRelations( Node ); + Node reverseTuple( Node ); + + void sendInfer( Node fact, Node exp, const char * c ); + void sendLemma( Node fact, Node reason, bool polarity ); + void sendSplit( Node a, Node b, const char * c ); void doPendingFacts(); + void addSharedTerm( TNode n ); + + // Helper functions + Node getRepresentative( Node t ); + bool hasTerm( Node a ); + bool areEqual( Node a, Node b ); + bool exists( std::vector&, Node ); }; diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index 635f9856a..dac554d4f 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -62,6 +62,27 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { bool isMember = checkConstantMembership(node[0], S); return RewriteResponse(REWRITE_DONE, nm->mkConst(isMember)); } + if(node[1].getKind() == kind::TRANSPOSE) { + // only work for node[0] is an actual tuple like (a, b), won't work for tuple variables + if(node[0].isVar()) + return RewriteResponse(REWRITE_DONE, node); + std::vector elements; + std::vector tuple_types = node[0].getType().getTupleTypes(); + std::reverse(tuple_types.begin(), tuple_types.end()); + TypeNode tn = NodeManager::currentNM()->mkTupleType(tuple_types); + Datatype dt = tn.getDatatype(); + elements.push_back(Node::fromExpr(dt[0].getConstructor())); + for(Node::iterator child_it = node[0].end()-1; + child_it != node[0].begin()-1; --child_it) { + elements.push_back(*child_it); + } + Node new_node = NodeManager::currentNM()->mkNode(kind::MEMBER, + NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, elements), + node[1][0]); + if(node.getKind() == kind::NOT) + new_node = NodeManager::currentNM()->mkNode(kind::NOT, new_node); + return RewriteResponse(REWRITE_AGAIN, new_node); + } break; }//kind::MEMBER @@ -176,6 +197,17 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { break; }//kind::UNION + case kind::TRANSPOSE: { + if(node[0].getKind() != kind::TRANSPOSE) { + Trace("sets-postrewrite") << "Sets::postRewrite returning " << node << std::endl; + return RewriteResponse(REWRITE_DONE, node); + } + if(node[0].getKind() == kind::TRANSPOSE) { + return RewriteResponse(REWRITE_AGAIN, node[0][0]); + } + break; + } + default: break; }//switch(node.getKind()) diff --git a/test/regress/regress0/sets/rels/rel_join_0.cvc b/test/regress/regress0/sets/rels/rel_join_0.cvc index a251218c6..406b8d312 100644 --- a/test/regress/regress0/sets/rels/rel_join_0.cvc +++ b/test/regress/regress0/sets/rels/rel_join_0.cvc @@ -1,3 +1,4 @@ +% EXPECT: unsat OPTION "logic" "ALL_SUPPORTED"; IntPair: TYPE = [INT, INT]; x : SET OF IntPair; @@ -18,8 +19,6 @@ ASSERT (7, 5) IS_IN y; ASSERT z IS_IN x; ASSERT zt IS_IN y; -%ASSERT a IS_IN (x JOIN y); -%ASSERT NOT (v IS_IN (x JOIN y)); ASSERT NOT (a IS_IN (x JOIN y)); CHECKSAT; diff --git a/test/regress/regress0/sets/rels/rel_transpose_0.cvc b/test/regress/regress0/sets/rels/rel_transpose_0.cvc index d06528fd2..95c27edf0 100644 --- a/test/regress/regress0/sets/rels/rel_transpose_0.cvc +++ b/test/regress/regress0/sets/rels/rel_transpose_0.cvc @@ -1,3 +1,4 @@ +% EXPECT: unsat OPTION "logic" "ALL_SUPPORTED"; IntPair: TYPE = [INT, INT]; x : SET OF IntPair;