From f65c5c4cbc59527dc0c9c57283a373ef501792c5 Mon Sep 17 00:00:00 2001 From: Clark Barrett Date: Mon, 11 Jul 2011 19:53:44 +0000 Subject: [PATCH] Clark's work on array theory - can now solve all QF_AX problems --- src/prop/cnf_stream.h | 16 +- src/prop/prop_engine.cpp | 4 + src/prop/prop_engine.h | 4 + src/theory/arrays/Makefile.am | 4 +- src/theory/arrays/theory_arrays.cpp | 244 +++++++++++++++++++ src/theory/arrays/theory_arrays.h | 17 ++ src/theory/arrays/theory_arrays_rewriter.h | 76 +++--- src/theory/booleans/theory_bool_rewriter.cpp | 37 +++ src/theory/valuation.cpp | 4 + src/theory/valuation.h | 5 + src/util/ntuple.h | 9 +- 11 files changed, 367 insertions(+), 53 deletions(-) diff --git a/src/prop/cnf_stream.h b/src/prop/cnf_stream.h index ef75e635b..e53b46d9b 100644 --- a/src/prop/cnf_stream.h +++ b/src/prop/cnf_stream.h @@ -146,14 +146,6 @@ protected: */ bool isTranslated(TNode node) const; - /** - * Returns true if the node has an assigned literal (it might not be translated). - * Caches the pair of the node and the literal corresponding to the - * translation. - * @param node the node - */ - bool hasLiteral(TNode node) const; - /** * Acquires a new variable from the SAT solver to represent the node * and inserts the necessary data it into the mapping tables. @@ -207,6 +199,14 @@ public: */ TNode getNode(const SatLiteral& literal); + /** + * Returns true if the node has an assigned literal (it might not be translated). + * Caches the pair of the node and the literal corresponding to the + * translation. + * @param node the node + */ + bool hasLiteral(TNode node) const; + /** * Returns the literal that represents the given node in the SAT CNF * representation. diff --git a/src/prop/prop_engine.cpp b/src/prop/prop_engine.cpp index 4c9b66020..3aa014782 100644 --- a/src/prop/prop_engine.cpp +++ b/src/prop/prop_engine.cpp @@ -170,6 +170,10 @@ Node PropEngine::getValue(TNode node) { } } +bool PropEngine::isSatLiteral(TNode node) { + return d_cnfStream->hasLiteral(node); +} + bool PropEngine::hasValue(TNode node, bool& value) { Assert(node.getType().isBoolean()); SatLiteral lit = d_cnfStream->getLiteral(node); diff --git a/src/prop/prop_engine.h b/src/prop/prop_engine.h index f44ad16f7..f6e66bef1 100644 --- a/src/prop/prop_engine.h +++ b/src/prop/prop_engine.h @@ -114,6 +114,10 @@ public: */ Node getValue(TNode node); + /* + * Return true if node has an associated SAT literal + */ + bool isSatLiteral(TNode node); /** * Check if the node has a value and return it if yes. */ diff --git a/src/theory/arrays/Makefile.am b/src/theory/arrays/Makefile.am index 1e070cdaf..3dde70145 100644 --- a/src/theory/arrays/Makefile.am +++ b/src/theory/arrays/Makefile.am @@ -13,6 +13,8 @@ libarrays_la_SOURCES = \ union_find.h \ union_find.cpp \ array_info.h \ - array_info.cpp + array_info.cpp \ + static_fact_manager.h \ + static_fact_manager.cpp EXTRA_DIST = kinds diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 37c49b341..dab78c17a 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -21,6 +21,7 @@ #include "theory/valuation.h" #include "expr/kind.h" #include +#include "theory/rewriter.h" using namespace std; using namespace CVC4; @@ -184,6 +185,208 @@ Node TheoryArrays::getValue(TNode n) { } } +Theory::SolveStatus TheoryArrays::solve(TNode in, SubstitutionMap& outSubstitutions) { + switch(in.getKind()) { + case kind::EQUAL: + { + d_staticFactManager.addEq(in); + if (in[0].getMetaKind() == kind::metakind::VARIABLE && !in[1].hasSubterm(in[0])) { + outSubstitutions.addSubstitution(in[0], in[1]); + return SOLVE_STATUS_SOLVED; + } + if (in[1].getMetaKind() == kind::metakind::VARIABLE && !in[0].hasSubterm(in[1])) { + outSubstitutions.addSubstitution(in[1], in[0]); + return SOLVE_STATUS_SOLVED; + } + break; + } + case kind::NOT: + { + Assert(in[0].getKind() == kind::EQUAL || + in[0].getKind() == kind::IFF ); + Node a = in[0][0]; + Node b = in[0][1]; + d_staticFactManager.addDiseq(in[0]); + break; + } + default: + break; + } + return SOLVE_STATUS_UNSOLVED; +} + +Node TheoryArrays::preprocessTerm(TNode term) { + switch (term.getKind()) { + case kind::SELECT: { + // select(store(a,i,v),j) = select(a,j) + // IF i != j + if (term[0].getKind() == kind::STORE && + d_staticFactManager.areDiseq(term[0][1], term[1])) { + return NodeBuilder<2>(kind::SELECT) << term[0][0] << term[1]; + } + break; + } + case kind::STORE: { + // store(store(a,i,v),j,w) = store(store(a,j,w),i,v) + // IF i != j and j comes before i in the ordering + if (term[0].getKind() == kind::STORE && + (term[1] < term[0][1]) && + d_staticFactManager.areDiseq(term[1], term[0][1])) { + Node inner = NodeBuilder<3>(kind::STORE) << term[0][0] << term[1] << term[2]; + Node outer = NodeBuilder<3>(kind::STORE) << inner << term[0][1] << term[0][2]; + return outer; + } + break; + } + case kind::EQUAL: { + if (term[0].getKind() == kind::STORE || + term[1].getKind() == kind::STORE) { + TNode left = term[0]; + TNode right = term[1]; + int leftWrites = 0, rightWrites = 0; + + // Count nested writes + TNode e1 = left; + while (e1.getKind() == kind::STORE) { + ++leftWrites; + e1 = e1[0]; + } + + TNode e2 = right; + while (e2.getKind() == kind::STORE) { + ++rightWrites; + e2 = e2[0]; + } + + if (rightWrites > leftWrites) { + TNode tmp = left; + left = right; + right = tmp; + int tmpWrites = leftWrites; + leftWrites = rightWrites; + rightWrites = tmpWrites; + } + + NodeManager* nm = NodeManager::currentNM(); + if (rightWrites == 0) { + if (e1 == e2) { + // write(store, index_0, v_0, index_1, v_1, ..., index_n, v_n) = store IFF + // + // read(store, index_n) = v_n & + // index_{n-1} != index_n -> read(store, index_{n-1}) = v_{n-1} & + // (index_{n-2} != index_{n-1} & index_{n-2} != index_n) -> read(store, index_{n-2}) = v_{n-2} & + // ... + // (index_1 != index_2 & ... & index_1 != index_n) -> read(store, index_1) = v_1 + // (index_0 != index_1 & index_0 != index_2 & ... & index_0 != index_n) -> read(store, index_0) = v_0 + TNode write_i, write_j, index_i, index_j; + Node conc; + NodeBuilder<> result(kind::AND); + int i, j; + write_i = left; + for (i = leftWrites-1; i >= 0; --i) { + index_i = write_i[1]; + + // build: [index_i /= index_n && index_i /= index_(n-1) && + // ... && index_i /= index_(i+1)] -> read(store, index_i) = v_i + write_j = left; + { + NodeBuilder<> hyp(kind::AND); + for (j = leftWrites - 1; j > i; --j) { + index_j = write_j[1]; + if (d_staticFactManager.areDiseq(index_i, index_j)) { + continue; + } + Node hyp2(index_i.getType() == nm->booleanType()? + index_i.iffNode(index_j) : index_i.eqNode(index_j)); + hyp << hyp2.notNode(); + write_j = write_j[0]; + } + + Node r1 = nm->mkNode(kind::SELECT, e1, index_i); + conc = (r1.getType() == nm->booleanType())? + r1.iffNode(write_i[2]) : r1.eqNode(write_i[2]); + if (hyp.getNumChildren() != 0) { + if (hyp.getNumChildren() == 1) { + conc = hyp.getChild(0).impNode(conc); + } + else { + r1 = hyp; + conc = r1.impNode(conc); + } + } + + // And into result + result << conc; + + // Prepare for next iteration + write_i = write_i[0]; + } + } + Assert(result.getNumChildren() > 0); + if (result.getNumChildren() == 1) { + return result.getChild(0); + } + return result; + } + break; + } + else { + // store(...) = store(a,i,v) ==> + // store(store(...),i,select(a,i)) = a && select(store(...),i)=v + Node l = left; + Node tmp; + NodeBuilder<> nb(kind::AND); + while (right.getKind() == STORE) { + tmp = nm->mkNode(kind::SELECT, l, right[1]); + nb << tmp.eqNode(right[2]); + tmp = nm->mkNode(kind::SELECT, right[0], right[1]); + l = nm->mkNode(kind::STORE, l, right[1], tmp); + right = right[0]; + } + nb << l.eqNode(right); + return nb; + } + } + break; + } + default: + break; + } + return term; +} + +Node TheoryArrays::recursivePreprocessTerm(TNode term) { + unsigned nc = term.getNumChildren(); + if (nc == 0 || + (theoryOf(term) != theory::THEORY_ARRAY && + term.getType() != NodeManager::currentNM()->booleanType())) { + return term; + } + NodeMap::iterator find = d_ppCache.find(term); + if (find != d_ppCache.end()) { + return (*find).second; + } + NodeBuilder<> newNode(term.getKind()); + unsigned i; + for (i = 0; i < nc; ++i) { + newNode << recursivePreprocessTerm(term[i]); + } + Node newTerm = Rewriter::rewrite(newNode); + Node newTerm2 = preprocessTerm(newTerm); + if (newTerm != newTerm2) { + newTerm = recursivePreprocessTerm(Rewriter::rewrite(newTerm2)); + } + d_ppCache[term] = newTerm; + return newTerm; +} + +Node TheoryArrays::preprocess(TNode atom) { + if (d_donePreregister) return atom; + Assert(atom.getKind() == kind::EQUAL); + return recursivePreprocessTerm(atom); +} + + void TheoryArrays::merge(TNode a, TNode b) { Assert(d_conflict.isNull()); @@ -508,7 +711,48 @@ bool TheoryArrays::isRedundantInContext(TNode a, TNode b, TNode i, TNode j) { checkRowForIndex(j,b); // why am i doing this? checkRowForIndex(i,a); return true; + } + Node literal1 = Rewriter::rewrite(i.eqNode(j)); + bool hasValue1, satValue1; + Node ff = nm->mkConst(false); + Node tt = nm->mkConst(true); + if (literal1 == ff) { + hasValue1 = true; + satValue1 = false; + } + else if (literal1 == tt) { + hasValue1 = true; + satValue1 = true; + } + else hasValue1 = (d_valuation.isSatLiteral(literal1) && d_valuation.hasSatValue(literal1, satValue1)); + if (hasValue1) { + if (satValue1) return true; + Node literal2 = Rewriter::rewrite(aj.eqNode(bj)); + bool hasValue2, satValue2; + if (literal2 == ff) { + hasValue2 = true; + satValue2 = false; } + else if (literal2 == tt) { + hasValue2 = true; + satValue2 = true; + } + else hasValue2 = (d_valuation.isSatLiteral(literal2) && d_valuation.hasSatValue(literal2, satValue2)); + if (hasValue2) { + if (satValue2) return true; + // conflict + Assert(!satValue1 && !satValue2); + Assert(literal1.getKind() == kind::EQUAL && literal2.getKind() == kind::EQUAL); + NodeBuilder<2> nb(kind::AND); + literal1 = areDisequal(literal1[0],literal1[1]); + literal2 = areDisequal(literal2[0],literal2[1]); + Assert(!literal1.isNull() && !literal2.isNull()); + nb << literal1.notNode() << literal2.notNode(); + literal1 = nb; + d_out->conflict(literal1, false); + return true; + } + } if(alreadyAddedRow(a,b,i,j)) { // Debug("arrays-lem")<<"isRedundantInContext already added "< #include @@ -113,6 +114,18 @@ private: CongruenceClosure d_cc; + /** + * (Temporary) fact manager for preprocessing - eventually handle this with + * something more standard (like congruence closure module) + */ + StaticFactManager d_staticFactManager; + + /** + * Cache for proprocessing of atoms. + */ + typedef std::hash_map NodeMap; + NodeMap d_ppCache; + /** * Union find for storing the equalities. */ @@ -347,6 +360,8 @@ private: bool d_donePreregister; + Node preprocessTerm(TNode term); + Node recursivePreprocessTerm(TNode term); public: TheoryArrays(context::Context* c, OutputChannel& out, Valuation valuation); @@ -464,6 +479,8 @@ public: void explain(TNode n); Node getValue(TNode n); + SolveStatus solve(TNode in, SubstitutionMap& outSubstitutions); + Node preprocess(TNode atom); void shutdown() { } std::string identify() const { return std::string("TheoryArrays"); } diff --git a/src/theory/arrays/theory_arrays_rewriter.h b/src/theory/arrays/theory_arrays_rewriter.h index c37cbe68c..059b7ce8b 100644 --- a/src/theory/arrays/theory_arrays_rewriter.h +++ b/src/theory/arrays/theory_arrays_rewriter.h @@ -34,51 +34,51 @@ public: static RewriteResponse postRewrite(TNode node) { Debug("arrays-postrewrite") << "Arrays::postRewrite start " << node << std::endl; - if(node.getKind() == kind::EQUAL || node.getKind() == kind::IFF) { - if(node[0] == node[1]) { - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); - } - // checks for RoW axiom: (select ( store a i v) i) = v and rewrites it - // to true - if(node[0].getKind()==kind::SELECT) { - TNode a = node[0][0]; - TNode j = node[0][1]; - if(a.getKind()==kind::STORE) { - TNode b = a[0]; - TNode i = a[1]; - TNode v = a[2]; - if(v == node[1] && i == j) { - Debug("arrays-postrewrite") << "Arrays::postRewrite true" << std::endl; - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); - } + switch (node.getKind()) { + case kind::SELECT: { + // select(store(a,i,v),i) = v + TNode store = node[0]; + if (store.getKind() == kind::STORE && + store[1] == node[1]) { + return RewriteResponse(REWRITE_DONE, store[2]); } + break; } - - if (node[0] > node[1]) { - Node newNode = NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]); - // If we've switched theories, we need to rewrite again (TODO: THIS IS HACK, once theories accept eq, change) - if (Theory::theoryOf(newNode[0]) != Theory::theoryOf(newNode[1])) { + case kind::STORE: { + TNode store = node[0]; + TNode value = node[2]; + // store(a,i,select(a,i)) = a + if (value.getKind() == kind::SELECT && + value[0] == store && + value[1] == node[1]) { + return RewriteResponse(REWRITE_DONE, store); + } + // store(store(a,i,v),i,w) = store(a,i,w) + if (store.getKind() == kind::STORE && + store[1] == node[1]) { + Node newNode = NodeManager::currentNM()->mkNode(kind::STORE, store[0], store[1], value); return RewriteResponse(REWRITE_AGAIN_FULL, newNode); - } else { - return RewriteResponse(REWRITE_DONE, newNode); } + break; } - } - // FIXME: would it be better to move in preRewrite? - // if yes don't need the above case - if (node.getKind()==kind::SELECT) { - // we are rewriting (select (store a i v) i) to v - TNode a = node[0]; - TNode i = node[1]; - if(a.getKind() == kind::STORE) { - TNode b = a[0]; - TNode j = a[1]; - TNode v = a[2]; - if(i==j) { - Debug("arrays-postrewrite") << "Arrays::postrewrite to " << v << std::endl; - return RewriteResponse(REWRITE_AGAIN_FULL, v); + case kind::EQUAL: + case kind::IFF: { + if(node[0] == node[1]) { + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); + } + if (node[0] > node[1]) { + Node newNode = NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]); + // If we've switched theories, we need to rewrite again (TODO: THIS IS HACK, once theories accept eq, change) + if (Theory::theoryOf(newNode[0]) != Theory::theoryOf(newNode[1])) { + return RewriteResponse(REWRITE_AGAIN_FULL, newNode); + } else { + return RewriteResponse(REWRITE_DONE, newNode); + } } + break; } + default: + break; } return RewriteResponse(REWRITE_DONE, node); diff --git a/src/theory/booleans/theory_bool_rewriter.cpp b/src/theory/booleans/theory_bool_rewriter.cpp index 18aa71667..d2693268f 100644 --- a/src/theory/booleans/theory_bool_rewriter.cpp +++ b/src/theory/booleans/theory_bool_rewriter.cpp @@ -42,14 +42,51 @@ RewriteResponse TheoryBoolRewriter::preRewrite(TNode n) { if (n[0] == ff) return RewriteResponse(REWRITE_AGAIN, n[1]); if (n[1] == ff) return RewriteResponse(REWRITE_AGAIN, n[0]); } + else { + bool done = true; + TNode::iterator i = n.begin(), iend = n.end(); + for(; i != iend; ++i) { + if (*i == tt) return RewriteResponse(REWRITE_DONE, tt); + if (*i == ff) done = false; + } + if (!done) { + NodeBuilder<> nb(kind::OR); + for(i = n.begin(); i != iend; ++i) { + if (*i != ff) nb << *i; + } + if (nb.getNumChildren() == 0) return RewriteResponse(REWRITE_DONE, ff); + if (nb.getNumChildren() == 1) return RewriteResponse(REWRITE_AGAIN, nb.getChild(0)); + return RewriteResponse(REWRITE_AGAIN, nb.constructNode()); + } + } break; } case kind::AND: { + //TODO: Why REWRITE_AGAIN here? if (n.getNumChildren() == 2) { if (n[0] == ff || n[1] == ff) return RewriteResponse(REWRITE_DONE, ff); if (n[0] == tt) return RewriteResponse(REWRITE_AGAIN, n[1]); if (n[1] == tt) return RewriteResponse(REWRITE_AGAIN, n[0]); } + else { + bool done = true; + TNode::iterator i = n.begin(), iend = n.end(); + for(; i != iend; ++i) { + if (*i == ff) return RewriteResponse(REWRITE_DONE, ff); + if (*i == tt) done = false; + } + if (!done) { + NodeBuilder<> nb(kind::AND); + for(i = n.begin(); i != iend; ++i) { + if (*i != tt) { + nb << *i; + } + } + if (nb.getNumChildren() == 0) return RewriteResponse(REWRITE_DONE, tt); + if (nb.getNumChildren() == 1) return RewriteResponse(REWRITE_AGAIN, nb.getChild(0)); + return RewriteResponse(REWRITE_AGAIN, nb.constructNode()); + } + } break; } case kind::IMPLIES: { diff --git a/src/theory/valuation.cpp b/src/theory/valuation.cpp index 0aefd7f21..5002c8a59 100644 --- a/src/theory/valuation.cpp +++ b/src/theory/valuation.cpp @@ -27,6 +27,10 @@ Node Valuation::getValue(TNode n) const { return d_engine->getValue(n); } +bool Valuation::isSatLiteral(TNode n) const { + return d_engine->getPropEngine()->isSatLiteral(n); +} + bool Valuation::hasSatValue(TNode n, bool& value) const { return d_engine->getPropEngine()->hasValue(n, value); } diff --git a/src/theory/valuation.h b/src/theory/valuation.h index ea6772ce8..58615f481 100644 --- a/src/theory/valuation.h +++ b/src/theory/valuation.h @@ -41,6 +41,11 @@ public: Node getValue(TNode n) const; + /* + * Return true if n has an associated SAT literal + */ + bool isSatLiteral(TNode n) const; + /** * Get the current SAT assignment to the node n. * diff --git a/src/util/ntuple.h b/src/util/ntuple.h index a3b0dfdf4..4c9a033a1 100644 --- a/src/util/ntuple.h +++ b/src/util/ntuple.h @@ -45,12 +45,9 @@ public: T2 second; T3 third; T4 fourth; - quad(const T1& t1, const T2& t2, const T3& t3, const T4& t4) { - first = t1; - second = t2; - third = t3; - fourth = t4; - } + quad(const T1& t1, const T2& t2, const T3& t3, const T4& t4) + : first(t1), second(t2), third(t3), fourth(t4) + { } };/* class quad<> */ template -- 2.30.2