From 89620a0d73e7134437a39d742e91de11a08a4962 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 21 Apr 2021 21:42:08 -0500 Subject: [PATCH] Move expand definition from Theory to TheoryRewriter (#6408) This is work towards eliminating global calls to getCurrentSmtEngine()->expandDefinition. The next step will be to add Rewriter::expandDefinition. --- src/smt/expand_definitions.cpp | 3 +- src/theory/arith/arith_preprocess.cpp | 4 +- src/theory/arith/arith_preprocess.h | 5 +- src/theory/arith/arith_rewriter.cpp | 12 + src/theory/arith/arith_rewriter.h | 10 + src/theory/arith/theory_arith.cpp | 13 +- src/theory/arith/theory_arith.h | 7 +- src/theory/arrays/theory_arrays.cpp | 58 +- src/theory/arrays/theory_arrays.h | 2 - src/theory/arrays/theory_arrays_rewriter.cpp | 616 +++++++++++++++++++ src/theory/arrays/theory_arrays_rewriter.h | 458 +------------- src/theory/bags/bags_rewriter.h | 1 - src/theory/bags/theory_bags.cpp | 6 - src/theory/bags/theory_bags.h | 1 - src/theory/bv/theory_bv.cpp | 25 +- src/theory/bv/theory_bv.h | 2 - src/theory/bv/theory_bv_rewriter.cpp | 20 + src/theory/bv/theory_bv_rewriter.h | 2 + src/theory/datatypes/datatypes_rewriter.cpp | 106 ++++ src/theory/datatypes/datatypes_rewriter.h | 2 + src/theory/datatypes/theory_datatypes.cpp | 118 +--- src/theory/datatypes/theory_datatypes.h | 1 - src/theory/fp/theory_fp.cpp | 5 - src/theory/fp/theory_fp.h | 1 - src/theory/fp/theory_fp_rewriter.h | 4 +- src/theory/sets/theory_sets.cpp | 6 - src/theory/sets/theory_sets.h | 2 - src/theory/strings/sequences_rewriter.cpp | 25 + src/theory/strings/sequences_rewriter.h | 2 + src/theory/strings/theory_strings.cpp | 23 - src/theory/strings/theory_strings.h | 2 - src/theory/theory.h | 33 - src/theory/theory_rewriter.cpp | 6 + src/theory/theory_rewriter.h | 24 + 34 files changed, 852 insertions(+), 753 deletions(-) diff --git a/src/smt/expand_definitions.cpp b/src/smt/expand_definitions.cpp index c5080db81..d331e8e78 100644 --- a/src/smt/expand_definitions.cpp +++ b/src/smt/expand_definitions.cpp @@ -255,9 +255,10 @@ TrustNode ExpandDefs::expandDefinitions( // do not do any theory stuff if expandOnly is true theory::Theory* t = d_smt.getTheoryEngine()->theoryOf(node); + theory::TheoryRewriter* tr = t->getTheoryRewriter(); Assert(t != NULL); - TrustNode trn = t->expandDefinition(n); + TrustNode trn = tr->expandDefinition(n); if (!trn.isNull()) { node = trn.getNode(); diff --git a/src/theory/arith/arith_preprocess.cpp b/src/theory/arith/arith_preprocess.cpp index a33d802f1..d5533de24 100644 --- a/src/theory/arith/arith_preprocess.cpp +++ b/src/theory/arith/arith_preprocess.cpp @@ -26,8 +26,8 @@ namespace arith { ArithPreprocess::ArithPreprocess(ArithState& state, InferenceManager& im, ProofNodeManager* pnm, - const LogicInfo& info) - : d_im(im), d_opElim(pnm, info), d_reduced(state.getUserContext()) + OperatorElim& oe) + : d_im(im), d_opElim(oe), d_reduced(state.getUserContext()) { } TrustNode ArithPreprocess::eliminate(TNode n, diff --git a/src/theory/arith/arith_preprocess.h b/src/theory/arith/arith_preprocess.h index ea24e5066..63b4515e7 100644 --- a/src/theory/arith/arith_preprocess.h +++ b/src/theory/arith/arith_preprocess.h @@ -31,6 +31,7 @@ namespace arith { class ArithState; class InferenceManager; +class OperatorElim; /** * This module can be used for (on demand) elimination of extended arithmetic @@ -45,7 +46,7 @@ class ArithPreprocess ArithPreprocess(ArithState& state, InferenceManager& im, ProofNodeManager* pnm, - const LogicInfo& info); + OperatorElim& oe); ~ArithPreprocess() {} /** * Call eliminate operators on formula n, return the resulting trust node, @@ -80,7 +81,7 @@ class ArithPreprocess /** Reference to the inference manager */ InferenceManager& d_im; /** The operator elimination utility */ - OperatorElim d_opElim; + OperatorElim& d_opElim; /** The set of assertions that were reduced */ context::CDHashMap d_reduced; }; diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index 83aaaadd8..b8135127d 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -26,6 +26,7 @@ #include "theory/arith/arith_rewriter.h" #include "theory/arith/arith_utilities.h" #include "theory/arith/normal_form.h" +#include "theory/arith/operator_elim.h" #include "theory/theory.h" #include "util/iand.h" @@ -33,6 +34,8 @@ namespace cvc5 { namespace theory { namespace arith { +ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {} + bool ArithRewriter::isAtom(TNode n) { Kind k = n.getKind(); return arith::isRelationOperator(k) || k == kind::IS_INTEGER @@ -893,6 +896,15 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre) return RewriteResponse(REWRITE_DONE, t); } +TrustNode ArithRewriter::expandDefinition(Node node) +{ + // call eliminate operators, to eliminate partial operators only + std::vector lems; + TrustNode ret = d_opElim.eliminate(node, lems, true); + Assert(lems.empty()); + return ret; +} + RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r) { Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by " diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index e476fbd62..6a92ba1cc 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -28,11 +28,19 @@ namespace cvc5 { namespace theory { namespace arith { +class OperatorElim; + class ArithRewriter : public TheoryRewriter { public: + ArithRewriter(OperatorElim& oe); RewriteResponse preRewrite(TNode n) override; RewriteResponse postRewrite(TNode n) override; + /** + * Expand definition, which eliminates extended operators like div/mod in + * the given node. + */ + TrustNode expandDefinition(Node node) override; private: static Node makeSubtractionNode(TNode l, TNode r); @@ -70,6 +78,8 @@ class ArithRewriter : public TheoryRewriter } /** return rewrite */ static RewriteResponse returnRewrite(TNode t, Node ret, Rewrite r); + /** The operator elimination utility */ + OperatorElim& d_opElim; }; /* class ArithRewriter */ } // namespace arith diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index 181a816c2..1843ddb8a 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -49,7 +49,9 @@ TheoryArith::TheoryArith(context::Context* c, d_astate(*d_internal, c, u, valuation), d_im(*this, d_astate, pnm), d_nonlinearExtension(nullptr), - d_arithPreproc(d_astate, d_im, pnm, logicInfo) + d_opElim(pnm, logicInfo), + d_arithPreproc(d_astate, d_im, pnm, d_opElim), + d_rewriter(d_opElim) { // indicate we are using the theory state object and inference manager d_theoryState = &d_astate; @@ -103,15 +105,6 @@ void TheoryArith::preRegisterTerm(TNode n) d_internal->preRegisterTerm(n); } -TrustNode TheoryArith::expandDefinition(Node node) -{ - // call eliminate operators, to eliminate partial operators only - std::vector lems; - TrustNode ret = d_arithPreproc.eliminate(node, lems, true); - Assert(lems.empty()); - return ret; -} - void TheoryArith::notifySharedTerm(TNode n) { d_internal->notifySharedTerm(n); } TrustNode TheoryArith::ppRewrite(TNode atom, std::vector& lems) diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index 26a33e247..43c962e30 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -74,11 +74,6 @@ class TheoryArith : public Theory { /** finish initialization */ void finishInit() override; //--------------------------------- end initialization - /** - * Expand definition, which eliminates extended operators like div/mod in - * the given node. - */ - TrustNode expandDefinition(Node node) override; /** * Does non-context dependent setup for a node connected to a theory. */ @@ -158,6 +153,8 @@ class TheoryArith : public Theory { * arithmetic. */ std::unique_ptr d_nonlinearExtension; + /** The operator elimination utility */ + OperatorElim d_opElim; /** The preprocess utility */ ArithPreprocess d_arithPreproc; /** The theory rewriter for this theory. */ diff --git a/src/theory/arrays/theory_arrays.cpp b/src/theory/arrays/theory_arrays.cpp index 1a1090f68..e887feccb 100644 --- a/src/theory/arrays/theory_arrays.cpp +++ b/src/theory/arrays/theory_arrays.cpp @@ -300,7 +300,7 @@ Node TheoryArrays::solveWrite(TNode term, bool solve1, bool solve2, bool ppCheck TrustNode TheoryArrays::ppRewrite(TNode term, std::vector& lems) { // first, see if we need to expand definitions - TrustNode texp = expandDefinition(term); + TrustNode texp = d_rewriter.expandDefinition(term); if (!texp.isNull()) { return texp; @@ -2068,62 +2068,6 @@ std::string TheoryArrays::TheoryArraysDecisionStrategy::identify() const return std::string("th_arrays_dec"); } -TrustNode TheoryArrays::expandDefinition(Node node) -{ - NodeManager* nm = NodeManager::currentNM(); - Kind kind = node.getKind(); - - /* Expand - * - * (eqrange a b i j) - * - * to - * - * forall k . i <= k <= j => a[k] = b[k] - * - */ - if (kind == kind::EQ_RANGE) - { - TNode a = node[0]; - TNode b = node[1]; - TNode i = node[2]; - TNode j = node[3]; - Node k = nm->mkBoundVar(i.getType()); - Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, k); - TypeNode type = k.getType(); - - Kind kle; - Node range; - if (type.isBitVector()) - { - kle = kind::BITVECTOR_ULE; - } - else if (type.isFloatingPoint()) - { - kle = kind::FLOATINGPOINT_LEQ; - } - else if (type.isInteger() || type.isReal()) - { - kle = kind::LEQ; - } - else - { - Unimplemented() << "Type " << type << " is not supported for predicate " - << kind; - } - - range = nm->mkNode(kind::AND, nm->mkNode(kle, i, k), nm->mkNode(kle, k, j)); - - Node eq = nm->mkNode(kind::EQUAL, - nm->mkNode(kind::SELECT, a, k), - nm->mkNode(kind::SELECT, b, k)); - Node implies = nm->mkNode(kind::IMPLIES, range, eq); - Node ret = nm->mkNode(kind::FORALL, bvl, implies); - return TrustNode::mkTrustRewrite(node, ret, nullptr); - } - return TrustNode::null(); -} - void TheoryArrays::computeRelevantTerms(std::set& termSet) { NodeManager* nm = NodeManager::currentNM(); diff --git a/src/theory/arrays/theory_arrays.h b/src/theory/arrays/theory_arrays.h index 7cf8d52e3..f9813cd3f 100644 --- a/src/theory/arrays/theory_arrays.h +++ b/src/theory/arrays/theory_arrays.h @@ -158,8 +158,6 @@ class TheoryArrays : public Theory { std::string identify() const override { return std::string("TheoryArrays"); } - TrustNode expandDefinition(Node node) override; - ///////////////////////////////////////////////////////////////////////////// // PREPROCESSING ///////////////////////////////////////////////////////////////////////////// diff --git a/src/theory/arrays/theory_arrays_rewriter.cpp b/src/theory/arrays/theory_arrays_rewriter.cpp index 323dd0046..6269cb5dd 100644 --- a/src/theory/arrays/theory_arrays_rewriter.cpp +++ b/src/theory/arrays/theory_arrays_rewriter.cpp @@ -45,6 +45,622 @@ void setMostFrequentValueCount(TNode store, uint64_t count) { return store.setAttribute(ArrayConstantMostFrequentValueCountAttr(), count); } +Node TheoryArraysRewriter::normalizeConstant(TNode node) +{ + return normalizeConstant(node, node[1].getType().getCardinality()); +} + +// this function is called by printers when using the option "--model-u-dt-enum" +Node TheoryArraysRewriter::normalizeConstant(TNode node, Cardinality indexCard) +{ + TNode store = node[0]; + TNode index = node[1]; + TNode value = node[2]; + + std::vector indices; + std::vector elements; + + // Normal form for nested stores is just ordering by index - but also need + // to check if we are writing to default value + + // Go through nested stores looking for where to insert index + // Also check whether we are replacing an existing store + TNode replacedValue; + uint32_t depth = 1; + uint32_t valCount = 1; + while (store.getKind() == kind::STORE) + { + if (index == store[1]) + { + replacedValue = store[2]; + store = store[0]; + break; + } + else if (index >= store[1]) + { + break; + } + if (value == store[2]) + { + valCount += 1; + } + depth += 1; + indices.push_back(store[1]); + elements.push_back(store[2]); + store = store[0]; + } + Node n = store; + + // Get the default value at the bottom of the nested stores + while (store.getKind() == kind::STORE) + { + if (value == store[2]) + { + valCount += 1; + } + depth += 1; + store = store[0]; + } + Assert(store.getKind() == kind::STORE_ALL); + ArrayStoreAll storeAll = store.getConst(); + Node defaultValue = storeAll.getValue(); + NodeManager* nm = NodeManager::currentNM(); + + // Check if we are writing to default value - if so the store + // to index can be ignored + if (value == defaultValue) + { + if (replacedValue.isNull()) + { + // Quick exit - if writing to default value and nothing was + // replaced, we can just return node[0] + return node[0]; + } + // else rebuild the store without the replaced write and then exit + } + else + { + n = nm->mkNode(kind::STORE, n, index, value); + } + + // Build the rest of the store after inserting/deleting + while (!indices.empty()) + { + n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); + indices.pop_back(); + elements.pop_back(); + } + + // Ready to exit if write was to the default value (see previous comment) + if (value == defaultValue) + { + return n; + } + + if (indexCard.isInfinite()) + { + return n; + } + + // When index sort is finite, we have to check whether there is any value + // that is written to more than the default value. If so, it must become + // the new default value + + TNode mostFrequentValue; + uint32_t mostFrequentValueCount = 0; + store = node[0]; + if (store.getKind() == kind::STORE) + { + mostFrequentValue = getMostFrequentValue(store); + mostFrequentValueCount = getMostFrequentValueCount(store); + } + + // Compute the most frequently written value for n + if (valCount > mostFrequentValueCount + || (valCount == mostFrequentValueCount && value < mostFrequentValue)) + { + mostFrequentValue = value; + mostFrequentValueCount = valCount; + } + + // Need to make sure the default value count is larger, or the same and the + // default value is expression-order-less-than nextValue + Cardinality::CardinalityComparison compare = + indexCard.compare(mostFrequentValueCount + depth); + Assert(compare != Cardinality::UNKNOWN); + if (compare == Cardinality::GREATER + || (compare == Cardinality::EQUAL && (defaultValue < mostFrequentValue))) + { + return n; + } + + // Bad case: have to recompute value counts and/or possibly switch out + // default value + store = n; + std::unordered_set indexSet; + std::unordered_map elementsMap; + std::unordered_map::iterator it; + uint32_t count; + uint32_t max = 0; + TNode maxValue; + while (store.getKind() == kind::STORE) + { + indices.push_back(store[1]); + indexSet.insert(store[1]); + elements.push_back(store[2]); + it = elementsMap.find(store[2]); + if (it != elementsMap.end()) + { + (*it).second = (*it).second + 1; + count = (*it).second; + } + else + { + elementsMap[store[2]] = 1; + count = 1; + } + if (count > max || (count == max && store[2] < maxValue)) + { + max = count; + maxValue = store[2]; + } + store = store[0]; + } + + Assert(depth == indices.size()); + compare = indexCard.compare(max + depth); + Assert(compare != Cardinality::UNKNOWN); + if (compare == Cardinality::GREATER + || (compare == Cardinality::EQUAL && (defaultValue < maxValue))) + { + Assert(!replacedValue.isNull() && mostFrequentValue == replacedValue); + return n; + } + + // Out of luck: have to swap out default value + + // Enumerate values from index type into newIndices and sort + std::vector newIndices; + TypeEnumerator te(index.getType()); + bool needToSort = false; + uint32_t numTe = 0; + while (!te.isFinished() + && (!indexCard.isFinite() + || numTe < indexCard.getFiniteCardinality().toUnsignedInt())) + { + if (indexSet.find(*te) == indexSet.end()) + { + if (!newIndices.empty() && (!(newIndices.back() < (*te)))) + { + needToSort = true; + } + newIndices.push_back(*te); + } + ++numTe; + ++te; + } + Assert(indexCard.compare(newIndices.size() + depth) == Cardinality::EQUAL); + if (needToSort) + { + std::sort(newIndices.begin(), newIndices.end()); + } + + n = nm->mkConst(ArrayStoreAll(node.getType(), maxValue)); + std::vector::iterator itNew = newIndices.begin(), + it_end = newIndices.end(); + while (itNew != it_end || !indices.empty()) + { + if (itNew != it_end && (indices.empty() || (*itNew) < indices.back())) + { + n = nm->mkNode(kind::STORE, n, (*itNew), defaultValue); + ++itNew; + } + else if (itNew == it_end || indices.back() < (*itNew)) + { + if (elements.back() != maxValue) + { + n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); + } + indices.pop_back(); + elements.pop_back(); + } + } + return n; +} + +RewriteResponse TheoryArraysRewriter::postRewrite(TNode node) +{ + Trace("arrays-postrewrite") + << "Arrays::postRewrite start " << node << std::endl; + switch (node.getKind()) + { + case kind::SELECT: + { + TNode store = node[0]; + TNode index = node[1]; + Node n; + bool val; + while (store.getKind() == kind::STORE) + { + if (index == store[1]) + { + val = true; + } + else if (index.isConst() && store[1].isConst()) + { + val = false; + } + else + { + n = Rewriter::rewrite(mkEqNode(store[1], index)); + if (n.getKind() != kind::CONST_BOOLEAN) + { + break; + } + val = n.getConst(); + } + if (val) + { + // select(store(a,i,v),j) = v if i = j + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << store[2] << std::endl; + return RewriteResponse(REWRITE_DONE, store[2]); + } + // select(store(a,i,v),j) = select(a,j) if i /= j + store = store[0]; + } + if (store.getKind() == kind::STORE_ALL) + { + // select(store_all(v),i) = v + ArrayStoreAll storeAll = store.getConst(); + n = storeAll.getValue(); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << n << std::endl; + Assert(n.isConst()); + return RewriteResponse(REWRITE_DONE, n); + } + else if (store != node[0]) + { + n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << n << std::endl; + return RewriteResponse(REWRITE_DONE, n); + } + break; + } + case kind::STORE: + { + TNode store = node[0]; + TNode value = node[2]; + // store(a,i,select(a,i)) = a + if (value.getKind() == kind::SELECT && value[0] == store + && value[1] == node[1]) + { + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << store << std::endl; + return RewriteResponse(REWRITE_DONE, store); + } + TNode index = node[1]; + if (store.isConst() && index.isConst() && value.isConst()) + { + // normalize constant + Node n = normalizeConstant(node); + Assert(n.isConst()); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << n << std::endl; + return RewriteResponse(REWRITE_DONE, n); + } + if (store.getKind() == kind::STORE) + { + // store(store(a,i,v),j,w) + bool val; + if (index == store[1]) + { + val = true; + } + else if (index.isConst() && store[1].isConst()) + { + val = false; + } + else + { + Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index)); + if (eqRewritten.getKind() != kind::CONST_BOOLEAN) + { + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << node << std::endl; + return RewriteResponse(REWRITE_DONE, node); + } + val = eqRewritten.getConst(); + } + NodeManager* nm = NodeManager::currentNM(); + if (val) + { + // store(store(a,i,v),i,w) = store(a,i,w) + Node result = nm->mkNode(kind::STORE, store[0], index, value); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << result << std::endl; + return RewriteResponse(REWRITE_AGAIN, result); + } + else if (index < store[1]) + { + // store(store(a,i,v),j,w) = store(store(a,j,w),i,v) + // IF i != j and j comes before i in the ordering + std::vector indices; + std::vector elements; + indices.push_back(store[1]); + elements.push_back(store[2]); + store = store[0]; + Node n; + while (store.getKind() == kind::STORE) + { + if (index == store[1]) + { + val = true; + } + else if (index.isConst() && store[1].isConst()) + { + val = false; + } + else + { + n = Rewriter::rewrite(mkEqNode(store[1], index)); + if (n.getKind() != kind::CONST_BOOLEAN) + { + break; + } + val = n.getConst(); + } + if (val) + { + store = store[0]; + break; + } + else if (!(index < store[1])) + { + break; + } + indices.push_back(store[1]); + elements.push_back(store[2]); + store = store[0]; + } + if (value.getKind() == kind::SELECT && value[0] == store + && value[1] == index) + { + n = store; + } + else + { + n = nm->mkNode(kind::STORE, store, index, value); + } + while (!indices.empty()) + { + n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); + indices.pop_back(); + elements.pop_back(); + } + Assert(n != node); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << n << std::endl; + return RewriteResponse(REWRITE_AGAIN, n); + } + } + break; + } + case kind::EQUAL: + { + if (node[0] == node[1]) + { + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning true" << std::endl; + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConst(true)); + } + else if (node[0].isConst() && node[1].isConst()) + { + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning false" << std::endl; + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConst(false)); + } + if (node[0] > node[1]) + { + Node newNode = + NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]); + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << newNode << std::endl; + return RewriteResponse(REWRITE_DONE, newNode); + } + break; + } + default: break; + } + Trace("arrays-postrewrite") + << "Arrays::postRewrite returning " << node << std::endl; + return RewriteResponse(REWRITE_DONE, node); +} + +RewriteResponse TheoryArraysRewriter::preRewrite(TNode node) +{ + Trace("arrays-prerewrite") + << "Arrays::preRewrite start " << node << std::endl; + switch (node.getKind()) + { + case kind::SELECT: + { + TNode store = node[0]; + TNode index = node[1]; + Node n; + bool val; + while (store.getKind() == kind::STORE) + { + if (index == store[1]) + { + val = true; + } + else if (index.isConst() && store[1].isConst()) + { + val = false; + } + else + { + n = Rewriter::rewrite(mkEqNode(store[1], index)); + if (n.getKind() != kind::CONST_BOOLEAN) + { + break; + } + val = n.getConst(); + } + if (val) + { + // select(store(a,i,v),j) = v if i = j + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << store[2] << std::endl; + return RewriteResponse(REWRITE_AGAIN, store[2]); + } + // select(store(a,i,v),j) = select(a,j) if i /= j + store = store[0]; + } + if (store.getKind() == kind::STORE_ALL) + { + // select(store_all(v),i) = v + ArrayStoreAll storeAll = store.getConst(); + n = storeAll.getValue(); + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << n << std::endl; + Assert(n.isConst()); + return RewriteResponse(REWRITE_DONE, n); + } + else if (store != node[0]) + { + n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index); + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << n << std::endl; + return RewriteResponse(REWRITE_DONE, n); + } + break; + } + case kind::STORE: + { + TNode store = node[0]; + TNode value = node[2]; + // store(a,i,select(a,i)) = a + if (value.getKind() == kind::SELECT && value[0] == store + && value[1] == node[1]) + { + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << store << std::endl; + return RewriteResponse(REWRITE_AGAIN, store); + } + if (store.getKind() == kind::STORE) + { + // store(store(a,i,v),j,w) + TNode index = node[1]; + bool val; + if (index == store[1]) + { + val = true; + } + else if (index.isConst() && store[1].isConst()) + { + val = false; + } + else + { + Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index)); + if (eqRewritten.getKind() != kind::CONST_BOOLEAN) + { + break; + } + val = eqRewritten.getConst(); + } + NodeManager* nm = NodeManager::currentNM(); + if (val) + { + // store(store(a,i,v),i,w) = store(a,i,w) + Node newNode = nm->mkNode(kind::STORE, store[0], index, value); + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << newNode << std::endl; + return RewriteResponse(REWRITE_DONE, newNode); + } + } + break; + } + case kind::EQUAL: + { + if (node[0] == node[1]) + { + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning true" << std::endl; + return RewriteResponse(REWRITE_DONE, + NodeManager::currentNM()->mkConst(true)); + } + break; + } + default: break; + } + + Trace("arrays-prerewrite") + << "Arrays::preRewrite returning " << node << std::endl; + return RewriteResponse(REWRITE_DONE, node); +} + +TrustNode TheoryArraysRewriter::expandDefinition(Node node) +{ + NodeManager* nm = NodeManager::currentNM(); + Kind kind = node.getKind(); + + /* Expand + * + * (eqrange a b i j) + * + * to + * + * forall k . i <= k <= j => a[k] = b[k] + * + */ + if (kind == kind::EQ_RANGE) + { + TNode a = node[0]; + TNode b = node[1]; + TNode i = node[2]; + TNode j = node[3]; + Node k = nm->mkBoundVar(i.getType()); + Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, k); + TypeNode type = k.getType(); + + Kind kle; + Node range; + if (type.isBitVector()) + { + kle = kind::BITVECTOR_ULE; + } + else if (type.isFloatingPoint()) + { + kle = kind::FLOATINGPOINT_LEQ; + } + else if (type.isInteger() || type.isReal()) + { + kle = kind::LEQ; + } + else + { + Unimplemented() << "Type " << type << " is not supported for predicate " + << kind; + } + + range = nm->mkNode(kind::AND, nm->mkNode(kle, i, k), nm->mkNode(kle, k, j)); + + Node eq = nm->mkNode(kind::EQUAL, + nm->mkNode(kind::SELECT, a, k), + nm->mkNode(kind::SELECT, b, k)); + Node implies = nm->mkNode(kind::IMPLIES, range, eq); + Node ret = nm->mkNode(kind::FORALL, bvl, implies); + return TrustNode::mkTrustRewrite(node, ret, nullptr); + } + return TrustNode::null(); +} + } // namespace arrays } // namespace theory } // namespace cvc5 diff --git a/src/theory/arrays/theory_arrays_rewriter.h b/src/theory/arrays/theory_arrays_rewriter.h index 0bbfc0846..498266ce3 100644 --- a/src/theory/arrays/theory_arrays_rewriter.h +++ b/src/theory/arrays/theory_arrays_rewriter.h @@ -43,459 +43,21 @@ static inline Node mkEqNode(Node a, Node b) { class TheoryArraysRewriter : public TheoryRewriter { - static Node normalizeConstant(TNode node) { - return normalizeConstant(node, node[1].getType().getCardinality()); - } + /** + * Puts array constant node into normal form. This is so that array constants + * that are distinct nodes are semantically disequal. + */ + static Node normalizeConstant(TNode node); public: - //this function is called by printers when using the option "--model-u-dt-enum" - static Node normalizeConstant(TNode node, Cardinality indexCard) { - TNode store = node[0]; - TNode index = node[1]; - TNode value = node[2]; + /** Normalize a constant whose index type has cardinality indexCard */ + static Node normalizeConstant(TNode node, Cardinality indexCard); - std::vector indices; - std::vector elements; + RewriteResponse postRewrite(TNode node) override; - // Normal form for nested stores is just ordering by index - but also need - // to check if we are writing to default value + RewriteResponse preRewrite(TNode node) override; - // Go through nested stores looking for where to insert index - // Also check whether we are replacing an existing store - TNode replacedValue; - unsigned depth = 1; - unsigned valCount = 1; - while (store.getKind() == kind::STORE) { - if (index == store[1]) { - replacedValue = store[2]; - store = store[0]; - break; - } - else if (!(index < store[1])) { - break; - } - if (value == store[2]) { - valCount += 1; - } - depth += 1; - indices.push_back(store[1]); - elements.push_back(store[2]); - store = store[0]; - } - Node n = store; - - // Get the default value at the bottom of the nested stores - while (store.getKind() == kind::STORE) { - if (value == store[2]) { - valCount += 1; - } - depth += 1; - store = store[0]; - } - Assert(store.getKind() == kind::STORE_ALL); - ArrayStoreAll storeAll = store.getConst(); - Node defaultValue = storeAll.getValue(); - NodeManager* nm = NodeManager::currentNM(); - - // Check if we are writing to default value - if so the store - // to index can be ignored - if (value == defaultValue) { - if (replacedValue.isNull()) { - // Quick exit - if writing to default value and nothing was - // replaced, we can just return node[0] - return node[0]; - } - // else rebuild the store without the replaced write and then exit - } - else { - n = nm->mkNode(kind::STORE, n, index, value); - } - - // Build the rest of the store after inserting/deleting - while (!indices.empty()) { - n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); - indices.pop_back(); - elements.pop_back(); - } - - // Ready to exit if write was to the default value (see previous comment) - if (value == defaultValue) { - return n; - } - - if (indexCard.isInfinite()) { - return n; - } - - // When index sort is finite, we have to check whether there is any value - // that is written to more than the default value. If so, it must become - // the new default value - - TNode mostFrequentValue; - unsigned mostFrequentValueCount = 0; - store = node[0]; - if (store.getKind() == kind::STORE) { - mostFrequentValue = getMostFrequentValue(store); - mostFrequentValueCount = getMostFrequentValueCount(store); - } - - // Compute the most frequently written value for n - if (valCount > mostFrequentValueCount || - (valCount == mostFrequentValueCount && value < mostFrequentValue)) { - mostFrequentValue = value; - mostFrequentValueCount = valCount; - } - - // Need to make sure the default value count is larger, or the same and the default value is expression-order-less-than nextValue - Cardinality::CardinalityComparison compare = indexCard.compare(mostFrequentValueCount + depth); - Assert(compare != Cardinality::UNKNOWN); - if (compare == Cardinality::GREATER || - (compare == Cardinality::EQUAL && (defaultValue < mostFrequentValue))) { - return n; - } - - // Bad case: have to recompute value counts and/or possibly switch out - // default value - store = n; - std::unordered_set indexSet; - std::unordered_map elementsMap; - std::unordered_map::iterator it; - unsigned count; - unsigned max = 0; - TNode maxValue; - while (store.getKind() == kind::STORE) { - indices.push_back(store[1]); - indexSet.insert(store[1]); - elements.push_back(store[2]); - it = elementsMap.find(store[2]); - if (it != elementsMap.end()) { - (*it).second = (*it).second + 1; - count = (*it).second; - } - else { - elementsMap[store[2]] = 1; - count = 1; - } - if (count > max || - (count == max && store[2] < maxValue)) { - max = count; - maxValue = store[2]; - } - store = store[0]; - } - - Assert(depth == indices.size()); - compare = indexCard.compare(max + depth); - Assert(compare != Cardinality::UNKNOWN); - if (compare == Cardinality::GREATER || - (compare == Cardinality::EQUAL && (defaultValue < maxValue))) { - Assert(!replacedValue.isNull() && mostFrequentValue == replacedValue); - return n; - } - - // Out of luck: have to swap out default value - - // Enumerate values from index type into newIndices and sort - std::vector newIndices; - TypeEnumerator te(index.getType()); - bool needToSort = false; - unsigned numTe = 0; - while (!te.isFinished() && (!indexCard.isFinite() || numTemkConst(ArrayStoreAll(node.getType(), maxValue)); - std::vector::iterator itNew = newIndices.begin(), it_end = newIndices.end(); - while (itNew != it_end || !indices.empty()) { - if (itNew != it_end && (indices.empty() || (*itNew) < indices.back())) { - n = nm->mkNode(kind::STORE, n, (*itNew), defaultValue); - ++itNew; - } - else if (itNew == it_end || indices.back() < (*itNew)) { - if (elements.back() != maxValue) { - n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); - } - indices.pop_back(); - elements.pop_back(); - } - } - return n; - } - - public: - RewriteResponse postRewrite(TNode node) override - { - Trace("arrays-postrewrite") << "Arrays::postRewrite start " << node << std::endl; - switch (node.getKind()) { - case kind::SELECT: { - TNode store = node[0]; - TNode index = node[1]; - Node n; - bool val; - while (store.getKind() == kind::STORE) { - if (index == store[1]) { - val = true; - } - else if (index.isConst() && store[1].isConst()) { - val = false; - } - else { - n = Rewriter::rewrite(mkEqNode(store[1], index)); - if (n.getKind() != kind::CONST_BOOLEAN) { - break; - } - val = n.getConst(); - } - if (val) { - // select(store(a,i,v),j) = v if i = j - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << store[2] << std::endl; - return RewriteResponse(REWRITE_DONE, store[2]); - } - // select(store(a,i,v),j) = select(a,j) if i /= j - store = store[0]; - } - if (store.getKind() == kind::STORE_ALL) { - // select(store_all(v),i) = v - ArrayStoreAll storeAll = store.getConst(); - n = storeAll.getValue(); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl; - Assert(n.isConst()); - return RewriteResponse(REWRITE_DONE, n); - } - else if (store != node[0]) { - n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl; - return RewriteResponse(REWRITE_DONE, n); - } - break; - } - case kind::STORE: { - TNode store = node[0]; - TNode value = node[2]; - // store(a,i,select(a,i)) = a - if (value.getKind() == kind::SELECT && - value[0] == store && - value[1] == node[1]) { - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << store << std::endl; - return RewriteResponse(REWRITE_DONE, store); - } - TNode index = node[1]; - if (store.isConst() && index.isConst() && value.isConst()) { - // normalize constant - Node n = normalizeConstant(node); - Assert(n.isConst()); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl; - return RewriteResponse(REWRITE_DONE, n); - } - if (store.getKind() == kind::STORE) { - // store(store(a,i,v),j,w) - bool val; - if (index == store[1]) { - val = true; - } - else if (index.isConst() && store[1].isConst()) { - val = false; - } - else { - Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index)); - if (eqRewritten.getKind() != kind::CONST_BOOLEAN) { - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << node << std::endl; - return RewriteResponse(REWRITE_DONE, node); - } - val = eqRewritten.getConst(); - } - NodeManager* nm = NodeManager::currentNM(); - if (val) { - // store(store(a,i,v),i,w) = store(a,i,w) - Node result = nm->mkNode(kind::STORE, store[0], index, value); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << result << std::endl; - return RewriteResponse(REWRITE_AGAIN, result); - } - else if (index < store[1]) { - // store(store(a,i,v),j,w) = store(store(a,j,w),i,v) - // IF i != j and j comes before i in the ordering - std::vector indices; - std::vector elements; - indices.push_back(store[1]); - elements.push_back(store[2]); - store = store[0]; - Node n; - while (store.getKind() == kind::STORE) { - if (index == store[1]) { - val = true; - } - else if (index.isConst() && store[1].isConst()) { - val = false; - } - else { - n = Rewriter::rewrite(mkEqNode(store[1], index)); - if (n.getKind() != kind::CONST_BOOLEAN) { - break; - } - val = n.getConst(); - } - if (val) { - store = store[0]; - break; - } - else if (!(index < store[1])) { - break; - } - indices.push_back(store[1]); - elements.push_back(store[2]); - store = store[0]; - } - if (value.getKind() == kind::SELECT && - value[0] == store && - value[1] == index) { - n = store; - } - else { - n = nm->mkNode(kind::STORE, store, index, value); - } - while (!indices.empty()) { - n = nm->mkNode(kind::STORE, n, indices.back(), elements.back()); - indices.pop_back(); - elements.pop_back(); - } - Assert(n != node); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << n << std::endl; - return RewriteResponse(REWRITE_AGAIN, n); - } - } - break; - } - case kind::EQUAL:{ - if(node[0] == node[1]) { - Trace("arrays-postrewrite") << "Arrays::postRewrite returning true" << std::endl; - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); - } - else if (node[0].isConst() && node[1].isConst()) { - Trace("arrays-postrewrite") << "Arrays::postRewrite returning false" << std::endl; - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(false)); - } - if (node[0] > node[1]) { - Node newNode = NodeManager::currentNM()->mkNode(node.getKind(), node[1], node[0]); - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << newNode << std::endl; - return RewriteResponse(REWRITE_DONE, newNode); - } - break; - } - default: - break; - } - Trace("arrays-postrewrite") << "Arrays::postRewrite returning " << node << std::endl; - return RewriteResponse(REWRITE_DONE, node); - } - - RewriteResponse preRewrite(TNode node) override - { - Trace("arrays-prerewrite") << "Arrays::preRewrite start " << node << std::endl; - switch (node.getKind()) { - case kind::SELECT: { - TNode store = node[0]; - TNode index = node[1]; - Node n; - bool val; - while (store.getKind() == kind::STORE) { - if (index == store[1]) { - val = true; - } - else if (index.isConst() && store[1].isConst()) { - val = false; - } - else { - n = Rewriter::rewrite(mkEqNode(store[1], index)); - if (n.getKind() != kind::CONST_BOOLEAN) { - break; - } - val = n.getConst(); - } - if (val) { - // select(store(a,i,v),j) = v if i = j - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store[2] << std::endl; - return RewriteResponse(REWRITE_AGAIN, store[2]); - } - // select(store(a,i,v),j) = select(a,j) if i /= j - store = store[0]; - } - if (store.getKind() == kind::STORE_ALL) { - // select(store_all(v),i) = v - ArrayStoreAll storeAll = store.getConst(); - n = storeAll.getValue(); - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl; - Assert(n.isConst()); - return RewriteResponse(REWRITE_DONE, n); - } - else if (store != node[0]) { - n = NodeManager::currentNM()->mkNode(kind::SELECT, store, index); - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << n << std::endl; - return RewriteResponse(REWRITE_DONE, n); - } - break; - } - case kind::STORE: { - TNode store = node[0]; - TNode value = node[2]; - // store(a,i,select(a,i)) = a - if (value.getKind() == kind::SELECT && - value[0] == store && - value[1] == node[1]) { - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << store << std::endl; - return RewriteResponse(REWRITE_AGAIN, store); - } - if (store.getKind() == kind::STORE) { - // store(store(a,i,v),j,w) - TNode index = node[1]; - bool val; - if (index == store[1]) { - val = true; - } - else if (index.isConst() && store[1].isConst()) { - val = false; - } - else { - Node eqRewritten = Rewriter::rewrite(mkEqNode(store[1], index)); - if (eqRewritten.getKind() != kind::CONST_BOOLEAN) { - break; - } - val = eqRewritten.getConst(); - } - NodeManager* nm = NodeManager::currentNM(); - if (val) { - // store(store(a,i,v),i,w) = store(a,i,w) - Node newNode = nm->mkNode(kind::STORE, store[0], index, value); - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << newNode << std::endl; - return RewriteResponse(REWRITE_DONE, newNode); - } - } - break; - } - case kind::EQUAL:{ - if(node[0] == node[1]) { - Trace("arrays-prerewrite") << "Arrays::preRewrite returning true" << std::endl; - return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); - } - break; - } - default: - break; - } - - Trace("arrays-prerewrite") << "Arrays::preRewrite returning " << node << std::endl; - return RewriteResponse(REWRITE_DONE, node); - } + TrustNode expandDefinition(Node node) override; static inline void init() {} static inline void shutdown() {} diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index 54c1d7253..83f364f9d 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -56,7 +56,6 @@ class BagsRewriter : public TheoryRewriter * See the rewrite rules for these kinds below. */ RewriteResponse preRewrite(TNode n) override; - private: /** * rewrites for n include: diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 89312f99a..d5917f91c 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -226,12 +226,6 @@ void TheoryBags::preRegisterTerm(TNode n) } } -TrustNode TheoryBags::expandDefinition(Node n) -{ - // TODO(projects#224): add choose and is_singleton here - return TrustNode::null(); -} - void TheoryBags::presolve() {} /**************************** eq::NotifyClass *****************************/ diff --git a/src/theory/bags/theory_bags.h b/src/theory/bags/theory_bags.h index 7b9299f54..4ed131e64 100644 --- a/src/theory/bags/theory_bags.h +++ b/src/theory/bags/theory_bags.h @@ -72,7 +72,6 @@ class TheoryBags : public Theory Node getModelValue(TNode) override; std::string identify() const override { return "THEORY_BAGS"; } void preRegisterTerm(TNode n) override; - TrustNode expandDefinition(Node n) override; void presolve() override; private: diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 06f837c7f..a0f3f28f7 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -134,29 +134,6 @@ void TheoryBV::finishInit() } } -TrustNode TheoryBV::expandDefinition(Node node) -{ - Debug("bitvector-expandDefinition") - << "TheoryBV::expandDefinition(" << node << ")" << std::endl; - - Node ret; - switch (node.getKind()) - { - case kind::BITVECTOR_SDIV: - case kind::BITVECTOR_SREM: - case kind::BITVECTOR_SMOD: - ret = TheoryBVRewriter::eliminateBVSDiv(node); - break; - - default: break; - } - if (!ret.isNull() && node != ret) - { - return TrustNode::mkTrustRewrite(node, ret, nullptr); - } - return TrustNode::null(); -} - void TheoryBV::preRegisterTerm(TNode node) { d_internal->preRegisterTerm(node); @@ -211,7 +188,7 @@ Theory::PPAssertStatus TheoryBV::ppAssert( TrustNode TheoryBV::ppRewrite(TNode t, std::vector& lems) { // first, see if we need to expand definitions - TrustNode texp = expandDefinition(t); + TrustNode texp = d_rewriter.expandDefinition(t); if (!texp.isNull()) { return texp; diff --git a/src/theory/bv/theory_bv.h b/src/theory/bv/theory_bv.h index 0546e83ac..1f14f05b0 100644 --- a/src/theory/bv/theory_bv.h +++ b/src/theory/bv/theory_bv.h @@ -63,8 +63,6 @@ class TheoryBV : public Theory void finishInit() override; - TrustNode expandDefinition(Node node) override; - void preRegisterTerm(TNode n) override; bool preCheck(Effort e) override; diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index 076ea67a9..9b3fde6fb 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -54,6 +54,26 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { return res; } +TrustNode TheoryBVRewriter::expandDefinition(Node node) +{ + Debug("bitvector-expandDefinition") + << "TheoryBV::expandDefinition(" << node << ")" << std::endl; + Node ret; + switch (node.getKind()) + { + case kind::BITVECTOR_SDIV: + case kind::BITVECTOR_SREM: + case kind::BITVECTOR_SMOD: ret = eliminateBVSDiv(node); break; + + default: break; + } + if (!ret.isNull() && node != ret) + { + return TrustNode::mkTrustRewrite(node, ret, nullptr); + } + return TrustNode::null(); +} + RewriteResponse TheoryBVRewriter::RewriteBitOf(TNode node, bool prerewrite) { Node resultNode = LinearRewriteStrategy>::apply(node); diff --git a/src/theory/bv/theory_bv_rewriter.h b/src/theory/bv/theory_bv_rewriter.h index dca6bc9b4..e35184084 100644 --- a/src/theory/bv/theory_bv_rewriter.h +++ b/src/theory/bv/theory_bv_rewriter.h @@ -47,6 +47,8 @@ class TheoryBVRewriter : public TheoryRewriter RewriteResponse postRewrite(TNode node) override; RewriteResponse preRewrite(TNode node) override; + TrustNode expandDefinition(Node node) override; + private: static RewriteResponse IdentityRewrite(TNode node, bool prerewrite = false); static RewriteResponse UndefinedRewrite(TNode node, bool prerewrite = false); diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index b85ac44bf..a6d3a45bc 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -18,6 +18,7 @@ #include "expr/dtype.h" #include "expr/dtype_cons.h" #include "expr/node_algorithm.h" +#include "expr/skolem_manager.h" #include "expr/sygus_datatype.h" #include "options/datatypes_options.h" #include "theory/datatypes/sygus_datatype_utils.h" @@ -793,6 +794,111 @@ Node DatatypesRewriter::replaceDebruijn(Node n, return n; } +TrustNode DatatypesRewriter::expandDefinition(Node n) +{ + NodeManager* nm = NodeManager::currentNM(); + TypeNode tn = n.getType(); + Node ret; + switch (n.getKind()) + { + case kind::APPLY_SELECTOR: + { + Node selector = n.getOperator(); + // APPLY_SELECTOR always applies to an external selector, cindexOf is + // legal here + size_t cindex = utils::cindexOf(selector); + const DType& dt = utils::datatypeOf(selector); + const DTypeConstructor& c = dt[cindex]; + Node selector_use; + TypeNode ndt = n[0].getType(); + if (options::dtSharedSelectors()) + { + size_t selectorIndex = utils::indexOf(selector); + Trace("dt-expand") << "...selector index = " << selectorIndex + << std::endl; + Assert(selectorIndex < c.getNumArgs()); + selector_use = c.getSelectorInternal(ndt, selectorIndex); + } + else + { + selector_use = selector; + } + Node sel = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, selector_use, n[0]); + if (options::dtRewriteErrorSel()) + { + ret = sel; + } + else + { + Node tester = c.getTester(); + Node tst = nm->mkNode(APPLY_TESTER, tester, n[0]); + SkolemManager* sm = nm->getSkolemManager(); + TypeNode tnw = nm->mkFunctionType(ndt, n.getType()); + Node f = + sm->mkSkolemFunction(SkolemFunId::SELECTOR_WRONG, tnw, selector); + Node sk = nm->mkNode(kind::APPLY_UF, f, n[0]); + ret = nm->mkNode(kind::ITE, tst, sel, sk); + Trace("dt-expand") << "Expand def : " << n << " to " << ret + << std::endl; + } + } + break; + case TUPLE_UPDATE: + case RECORD_UPDATE: + { + Assert(tn.isDatatype()); + const DType& dt = tn.getDType(); + NodeBuilder b(APPLY_CONSTRUCTOR); + b << dt[0].getConstructor(); + size_t size, updateIndex; + if (n.getKind() == TUPLE_UPDATE) + { + Assert(tn.isTuple()); + size = tn.getTupleLength(); + updateIndex = n.getOperator().getConst().getIndex(); + } + else + { + Assert(tn.isRecord()); + const DTypeConstructor& recCons = dt[0]; + size = recCons.getNumArgs(); + // get the index for the name + updateIndex = recCons.getSelectorIndexForName( + n.getOperator().getConst().getField()); + } + Debug("tuprec") << "expr is " << n << std::endl; + Debug("tuprec") << "updateIndex is " << updateIndex << std::endl; + Debug("tuprec") << "t is " << tn << std::endl; + Debug("tuprec") << "t has arity " << size << std::endl; + for (size_t i = 0; i < size; ++i) + { + if (i == updateIndex) + { + b << n[1]; + Debug("tuprec") << "arg " << i << " gets updated to " << n[1] + << std::endl; + } + else + { + b << nm->mkNode( + APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal(tn, i), n[0]); + Debug("tuprec") << "arg " << i << " copies " + << b[b.getNumChildren() - 1] << std::endl; + } + } + ret = b; + Debug("tuprec") << "return " << ret << std::endl; + } + break; + default: break; + } + if (!ret.isNull() && n != ret) + { + return TrustNode::mkTrustRewrite(n, ret, nullptr); + } + return TrustNode::null(); +} + } // namespace datatypes } // namespace theory } // namespace cvc5 diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index 3b9b14fb7..c9a40ff7b 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -48,6 +48,8 @@ class DatatypesRewriter : public TheoryRewriter * on all top-level codatatype subterms of n. */ static Node normalizeConstant(Node n); + /** expand defintions */ + TrustNode expandDefinition(Node n) override; private: /** rewrite constructor term in */ diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 01ef77172..f9d08dfc2 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -482,127 +482,11 @@ void TheoryDatatypes::preRegisterTerm(TNode n) d_im.process(); } -TrustNode TheoryDatatypes::expandDefinition(Node n) -{ - NodeManager* nm = NodeManager::currentNM(); - TypeNode tn = n.getType(); - Node ret; - switch (n.getKind()) - { - case kind::APPLY_SELECTOR: - { - Node selector = n.getOperator(); - // APPLY_SELECTOR always applies to an external selector, cindexOf is - // legal here - size_t cindex = utils::cindexOf(selector); - const DType& dt = utils::datatypeOf(selector); - const DTypeConstructor& c = dt[cindex]; - Node selector_use; - TypeNode ndt = n[0].getType(); - if (options::dtSharedSelectors()) - { - size_t selectorIndex = utils::indexOf(selector); - Trace("dt-expand") << "...selector index = " << selectorIndex - << std::endl; - Assert(selectorIndex < c.getNumArgs()); - selector_use = c.getSelectorInternal(ndt, selectorIndex); - }else{ - selector_use = selector; - } - Node sel = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, selector_use, n[0]); - if (options::dtRewriteErrorSel()) - { - ret = sel; - } - else - { - Node tester = c.getTester(); - Node tst = nm->mkNode(APPLY_TESTER, tester, n[0]); - tst = Rewriter::rewrite(tst); - if (tst == d_true) - { - ret = sel; - }else{ - SkolemManager* sm = nm->getSkolemManager(); - TypeNode tnw = nm->mkFunctionType(ndt, n.getType()); - Node f = - sm->mkSkolemFunction(SkolemFunId::SELECTOR_WRONG, tnw, selector); - Node sk = nm->mkNode(kind::APPLY_UF, f, n[0]); - if (tst == nm->mkConst(false)) - { - ret = sk; - } - else - { - ret = nm->mkNode(kind::ITE, tst, sel, sk); - } - } - Trace("dt-expand") << "Expand def : " << n << " to " << ret - << std::endl; - } - } - break; - case TUPLE_UPDATE: - case RECORD_UPDATE: - { - Assert(tn.isDatatype()); - const DType& dt = tn.getDType(); - NodeBuilder b(APPLY_CONSTRUCTOR); - b << dt[0].getConstructor(); - size_t size, updateIndex; - if (n.getKind() == TUPLE_UPDATE) - { - Assert(tn.isTuple()); - size = tn.getTupleLength(); - updateIndex = n.getOperator().getConst().getIndex(); - } - else - { - Assert(tn.isRecord()); - const DTypeConstructor& recCons = dt[0]; - size = recCons.getNumArgs(); - // get the index for the name - updateIndex = recCons.getSelectorIndexForName( - n.getOperator().getConst().getField()); - } - Debug("tuprec") << "expr is " << n << std::endl; - Debug("tuprec") << "updateIndex is " << updateIndex << std::endl; - Debug("tuprec") << "t is " << tn << std::endl; - Debug("tuprec") << "t has arity " << size << std::endl; - for (size_t i = 0; i < size; ++i) - { - if (i == updateIndex) - { - b << n[1]; - Debug("tuprec") << "arg " << i << " gets updated to " << n[1] - << std::endl; - } - else - { - b << nm->mkNode( - APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal(tn, i), n[0]); - Debug("tuprec") << "arg " << i << " copies " - << b[b.getNumChildren() - 1] << std::endl; - } - } - ret = b; - Debug("tuprec") << "return " << ret << std::endl; - } - break; - default: break; - } - if (!ret.isNull() && n != ret) - { - return TrustNode::mkTrustRewrite(n, ret, nullptr); - } - return TrustNode::null(); -} - TrustNode TheoryDatatypes::ppRewrite(TNode in, std::vector& lems) { Debug("tuprec") << "TheoryDatatypes::ppRewrite(" << in << ")" << endl; // first, see if we need to expand definitions - TrustNode texp = expandDefinition(in); + TrustNode texp = d_rewriter.expandDefinition(in); if (!texp.isNull()) { return texp; diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index eb55ce6b0..1ae122f5e 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -227,7 +227,6 @@ private: void notifyFact(TNode atom, bool pol, TNode fact, bool isInternal) override; //--------------------------------- end standard check void preRegisterTerm(TNode n) override; - TrustNode expandDefinition(Node n) override; TrustNode ppRewrite(TNode n, std::vector& lems) override; EqualityStatus getEqualityStatus(TNode a, TNode b) override; std::string identify() const override diff --git a/src/theory/fp/theory_fp.cpp b/src/theory/fp/theory_fp.cpp index 6629a839d..01dace411 100644 --- a/src/theory/fp/theory_fp.cpp +++ b/src/theory/fp/theory_fp.cpp @@ -681,11 +681,6 @@ void TheoryFp::preRegisterTerm(TNode node) return; } -TrustNode TheoryFp::expandDefinition(Node node) -{ - return d_rewriter.expandDefinition(node); -} - void TheoryFp::handleLemma(Node node, InferenceId id) { Trace("fp") << "TheoryFp::handleLemma(): asserting " << node << std::endl; diff --git a/src/theory/fp/theory_fp.h b/src/theory/fp/theory_fp.h index 78791b9b4..8cf4c4cc5 100644 --- a/src/theory/fp/theory_fp.h +++ b/src/theory/fp/theory_fp.h @@ -62,7 +62,6 @@ class TheoryFp : public Theory //--------------------------------- end initialization void preRegisterTerm(TNode node) override; - TrustNode expandDefinition(Node node) override; TrustNode ppRewrite(TNode node, std::vector& lems) override; //--------------------------------- standard check diff --git a/src/theory/fp/theory_fp_rewriter.h b/src/theory/fp/theory_fp_rewriter.h index 97c0e216b..027dd9819 100644 --- a/src/theory/fp/theory_fp_rewriter.h +++ b/src/theory/fp/theory_fp_rewriter.h @@ -46,8 +46,8 @@ class TheoryFpRewriter : public TheoryRewriter // often this will suffice return postRewrite(equality).d_node; } - /** Expand definitions in node. */ - TrustNode expandDefinition(Node node); + /** Expand definitions in node */ + TrustNode expandDefinition(Node node) override; protected: /** TODO: document (projects issue #265) */ diff --git a/src/theory/sets/theory_sets.cpp b/src/theory/sets/theory_sets.cpp index fdb744d67..8406bd14a 100644 --- a/src/theory/sets/theory_sets.cpp +++ b/src/theory/sets/theory_sets.cpp @@ -131,12 +131,6 @@ void TheorySets::preRegisterTerm(TNode node) d_internal->preRegisterTerm(node); } -TrustNode TheorySets::expandDefinition(Node n) -{ - // we currently do not expand any set operators - return TrustNode::null(); -} - TrustNode TheorySets::ppRewrite(TNode n, std::vector& lems) { Kind nk = n.getKind(); diff --git a/src/theory/sets/theory_sets.h b/src/theory/sets/theory_sets.h index bb8741e35..e99d25d36 100644 --- a/src/theory/sets/theory_sets.h +++ b/src/theory/sets/theory_sets.h @@ -78,8 +78,6 @@ class TheorySets : public Theory Node getModelValue(TNode) override; std::string identify() const override { return "THEORY_SETS"; } void preRegisterTerm(TNode node) override; - /** Expand partial operators (choose) from n. */ - TrustNode expandDefinition(Node n) override; /** * If the sets-ext option is not set and we have an extended operator, * we throw an exception. Additionally, we expand operators like choose diff --git a/src/theory/strings/sequences_rewriter.cpp b/src/theory/strings/sequences_rewriter.cpp index 431f488a5..84127e8e3 100644 --- a/src/theory/strings/sequences_rewriter.cpp +++ b/src/theory/strings/sequences_rewriter.cpp @@ -21,6 +21,7 @@ #include "theory/rewriter.h" #include "theory/strings/arith_entail.h" #include "theory/strings/regexp_entail.h" +#include "theory/strings/skolem_cache.h" #include "theory/strings/strings_rewriter.h" #include "theory/strings/theory_strings_utils.h" #include "theory/strings/word.h" @@ -1514,6 +1515,30 @@ RewriteResponse SequencesRewriter::preRewrite(TNode node) return RewriteResponse(REWRITE_DONE, node); } +TrustNode SequencesRewriter::expandDefinition(Node node) +{ + Trace("strings-exp-def") << "SequencesRewriter::expandDefinition : " << node + << std::endl; + + if (node.getKind() == kind::SEQ_NTH) + { + NodeManager* nm = NodeManager::currentNM(); + Node s = node[0]; + Node n = node[1]; + // seq.nth(s, n) --> ite(0 <= n < len(s), seq.nth_total(s,n), Uf(s, n)) + Node cond = nm->mkNode(AND, + nm->mkNode(LEQ, nm->mkConst(Rational(0)), n), + nm->mkNode(LT, n, nm->mkNode(STRING_LENGTH, s))); + Node ss = nm->mkNode(SEQ_NTH_TOTAL, s, n); + Node uf = SkolemCache::mkSkolemSeqNth(s.getType(), "Uf"); + Node u = nm->mkNode(APPLY_UF, uf, s, n); + Node ret = nm->mkNode(ITE, cond, ss, u); + Trace("strings-exp-def") << "...return " << ret << std::endl; + return TrustNode::mkTrustRewrite(node, ret, nullptr); + } + return TrustNode::null(); +} + Node SequencesRewriter::rewriteSeqNth(Node node) { Assert(node.getKind() == SEQ_NTH || node.getKind() == SEQ_NTH_TOTAL); diff --git a/src/theory/strings/sequences_rewriter.h b/src/theory/strings/sequences_rewriter.h index 97db2c7f4..7af24596a 100644 --- a/src/theory/strings/sequences_rewriter.h +++ b/src/theory/strings/sequences_rewriter.h @@ -127,6 +127,8 @@ class SequencesRewriter : public TheoryRewriter public: RewriteResponse postRewrite(TNode node) override; RewriteResponse preRewrite(TNode node) override; + /** Expand definition */ + TrustNode expandDefinition(Node n) override; /** rewrite equality * diff --git a/src/theory/strings/theory_strings.cpp b/src/theory/strings/theory_strings.cpp index 0ed003cc7..956f2148c 100644 --- a/src/theory/strings/theory_strings.cpp +++ b/src/theory/strings/theory_strings.cpp @@ -553,29 +553,6 @@ void TheoryStrings::preRegisterTerm(TNode n) d_extTheory.registerTerm(n); } -TrustNode TheoryStrings::expandDefinition(Node node) -{ - Trace("strings-exp-def") << "TheoryStrings::expandDefinition : " << node << std::endl; - - if (node.getKind() == kind::SEQ_NTH) - { - NodeManager* nm = NodeManager::currentNM(); - Node s = node[0]; - Node n = node[1]; - // seq.nth(s, n) --> ite(0 <= n < len(s), seq.nth_total(s,n), Uf(s, n)) - Node cond = nm->mkNode(AND, - nm->mkNode(LEQ, d_zero, n), - nm->mkNode(LT, n, nm->mkNode(STRING_LENGTH, s))); - Node ss = nm->mkNode(SEQ_NTH_TOTAL, s, n); - Node uf = SkolemCache::mkSkolemSeqNth(s.getType(), "Uf"); - Node u = nm->mkNode(APPLY_UF, uf, s, n); - Node ret = nm->mkNode(ITE, cond, ss, u); - Trace("strings-exp-def") << "...return " << ret << std::endl; - return TrustNode::mkTrustRewrite(node, ret, nullptr); - } - return TrustNode::null(); -} - bool TheoryStrings::preNotifyFact( TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal) { diff --git a/src/theory/strings/theory_strings.h b/src/theory/strings/theory_strings.h index fb6df80c7..01111880d 100644 --- a/src/theory/strings/theory_strings.h +++ b/src/theory/strings/theory_strings.h @@ -94,8 +94,6 @@ class TheoryStrings : public Theory { void shutdown() override {} /** preregister term */ void preRegisterTerm(TNode n) override; - /** Expand definition */ - TrustNode expandDefinition(Node n) override; //--------------------------------- standard check /** Do we need a check call at last call effort? */ bool needsCheckLastEffort() override; diff --git a/src/theory/theory.h b/src/theory/theory.h index 247ebcf46..9cf663a4f 100644 --- a/src/theory/theory.h +++ b/src/theory/theory.h @@ -497,39 +497,6 @@ class Theory { */ TheoryInferenceManager* getInferenceManager() { return d_inferManager; } - /** - * Expand definitions in the term node. This returns a term that is - * equivalent to node. It wraps this term in a TrustNode of kind - * TrustNodeKind::REWRITE. If node is unchanged by this method, the - * null TrustNode may be returned. This is an optimization to avoid - * constructing the trivial equality (= node node) internally within - * TrustNode. - * - * The purpose of this method is typically to eliminate the operators in node - * that are syntax sugar that cannot otherwise be eliminated during rewriting. - * For example, division relies on the introduction of an uninterpreted - * function for the divide-by-zero case, which we do not introduce with - * the rewriter, since this function may be cached in a non-global fashion. - * - * Some theories have kinds that are effectively definitions and should be - * expanded before they are handled. Definitions allow a much wider range of - * actions than the normal forms given by the rewriter. However no - * assumptions can be made about subterms having been expanded or rewritten. - * Where possible rewrite rules should be used, definitions should only be - * used when rewrites are not possible, for example in handling - * under-specified operations using partially defined functions. - * - * Some theories like sets use expandDefinition as a "context - * independent preRegisterTerm". This is required for cases where - * a theory wants to be notified about a term before preprocessing - * and simplification but doesn't necessarily want to rewrite it. - */ - virtual TrustNode expandDefinition(Node node) - { - // by default, do nothing - return TrustNode::null(); - } - /** * Pre-register a term. Done one time for a Node per SAT context level. */ diff --git a/src/theory/theory_rewriter.cpp b/src/theory/theory_rewriter.cpp index 75bcbff0e..42e9148c2 100644 --- a/src/theory/theory_rewriter.cpp +++ b/src/theory/theory_rewriter.cpp @@ -60,5 +60,11 @@ TrustNode TheoryRewriter::rewriteEqualityExtWithProof(Node node) return TrustNode::null(); } +TrustNode TheoryRewriter::expandDefinition(Node node) +{ + // no expansion + return TrustNode::null(); +} + } // namespace theory } // namespace cvc5 diff --git a/src/theory/theory_rewriter.h b/src/theory/theory_rewriter.h index 2477de51e..031e32db4 100644 --- a/src/theory/theory_rewriter.h +++ b/src/theory/theory_rewriter.h @@ -138,6 +138,30 @@ class TheoryRewriter * node if no rewrites are applied. */ virtual TrustNode rewriteEqualityExtWithProof(Node node); + + /** + * Expand definitions in the term node. This returns a term that is + * equivalent to node. It wraps this term in a TrustNode of kind + * TrustNodeKind::REWRITE. If node is unchanged by this method, the + * null TrustNode may be returned. This is an optimization to avoid + * constructing the trivial equality (= node node) internally within + * TrustNode. + * + * The purpose of this method is typically to eliminate the operators in node + * that are syntax sugar that cannot otherwise be eliminated during rewriting. + * For example, division relies on the introduction of an uninterpreted + * function for the divide-by-zero case, which we do not introduce with + * the standard rewrite methods. + * + * Some theories have kinds that are effectively definitions and should be + * expanded before they are handled. Definitions allow a much wider range of + * actions than the normal forms given by the rewriter. However no + * assumptions can be made about subterms having been expanded or rewritten. + * Where possible rewrite rules should be used, definitions should only be + * used when rewrites are not possible, for example in handling + * under-specified operations using partially defined functions. + */ + virtual TrustNode expandDefinition(Node node); }; } // namespace theory -- 2.30.2