From eaad5bdc7a38fcc38baa0e3b73f6c39a0ec6fb05 Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Mon, 25 Jan 2021 14:38:45 -0600 Subject: [PATCH] Refactor bags::SolverState (#5783) Couple of changes: SolverState now keep tracks of elements per bag instead of per type. bags::InferInfo now stores multiple conclusions (conjuncts). BagSolver applies upward/downward closures for bag elements --- src/theory/bags/bag_solver.cpp | 91 ++++++++++--- src/theory/bags/bag_solver.h | 19 ++- src/theory/bags/bags_rewriter.h | 16 +-- src/theory/bags/infer_info.cpp | 28 ++-- src/theory/bags/infer_info.h | 13 +- src/theory/bags/inference_generator.cpp | 122 +++++++++--------- src/theory/bags/inference_generator.h | 84 +++++++----- src/theory/bags/inference_manager.h | 2 +- src/theory/bags/solver_state.cpp | 105 ++++++++++----- src/theory/bags/solver_state.h | 50 +++++-- src/theory/bags/theory_bags.cpp | 65 +++------- test/regress/CMakeLists.txt | 2 + .../regress1/bags/difference_remove1.smt2 | 10 ++ test/regress/regress1/bags/issue5759.smt2 | 10 ++ 14 files changed, 391 insertions(+), 226 deletions(-) create mode 100644 test/regress/regress1/bags/difference_remove1.smt2 create mode 100644 test/regress/regress1/bags/issue5759.smt2 diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp index 5621a7c1c..495f73723 100644 --- a/src/theory/bags/bag_solver.cpp +++ b/src/theory/bags/bag_solver.cpp @@ -39,25 +39,63 @@ BagSolver::~BagSolver() {} void BagSolver::postCheck() { + d_state.initialize(); + + // At this point, all bag and count representatives should be in the solver + // state. + for (const Node& bag : d_state.getBags()) + { + // iterate through all bags terms in each equivalent class + eq::EqClassIterator it = + eq::EqClassIterator(bag, d_state.getEqualityEngine()); + while (!it.isFinished()) + { + Node n = (*it); + Kind k = n.getKind(); + switch (k) + { + case kind::MK_BAG: checkMkBag(n); break; + case kind::UNION_DISJOINT: checkUnionDisjoint(n); break; + case kind::UNION_MAX: checkUnionMax(n); break; + case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break; + case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break; + default: break; + } + it++; + } + } + + // add non negative constraints for all multiplicities for (const Node& n : d_state.getBags()) { - Kind k = n.getKind(); - switch (k) + for (const Node& e : d_state.getElements(n)) { - case kind::MK_BAG: checkMkBag(n); break; - case kind::UNION_DISJOINT: checkUnionDisjoint(n); break; - case kind::UNION_MAX: checkUnionMax(n); break; - case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break; - default: break; + checkNonNegativeCountTerms(n, e); } } } +set BagSolver::getElementsForBinaryOperator(const Node& n) +{ + set elements; + const set& downwards = d_state.getElements(n); + const set& upwards0 = d_state.getElements(n[0]); + const set& upwards1 = d_state.getElements(n[1]); + + set_union(downwards.begin(), + downwards.end(), + upwards0.begin(), + upwards0.end(), + inserter(elements, elements.begin())); + elements.insert(upwards1.begin(), upwards1.end()); + return elements; +} + void BagSolver::checkUnionDisjoint(const Node& n) { Assert(n.getKind() == UNION_DISJOINT); - TypeNode elementType = n.getType().getBagElementType(); - for (const Node& e : d_state.getElements(elementType)) + std::set elements = getElementsForBinaryOperator(n); + for (const Node& e : elements) { InferenceGenerator ig(&d_state); InferInfo i = ig.unionDisjoint(n, e); @@ -69,8 +107,8 @@ void BagSolver::checkUnionDisjoint(const Node& n) void BagSolver::checkUnionMax(const Node& n) { Assert(n.getKind() == UNION_MAX); - TypeNode elementType = n.getType().getBagElementType(); - for (const Node& e : d_state.getElements(elementType)) + std::set elements = getElementsForBinaryOperator(n); + for (const Node& e : elements) { InferenceGenerator ig(&d_state); InferInfo i = ig.unionMax(n, e); @@ -82,8 +120,8 @@ void BagSolver::checkUnionMax(const Node& n) void BagSolver::checkDifferenceSubtract(const Node& n) { Assert(n.getKind() == DIFFERENCE_SUBTRACT); - TypeNode elementType = n.getType().getBagElementType(); - for (const Node& e : d_state.getElements(elementType)) + std::set elements = getElementsForBinaryOperator(n); + for (const Node& e : elements) { InferenceGenerator ig(&d_state); InferInfo i = ig.differenceSubtract(n, e); @@ -91,11 +129,14 @@ void BagSolver::checkDifferenceSubtract(const Node& n) Trace("bags::BagSolver::postCheck") << i << endl; } } + void BagSolver::checkMkBag(const Node& n) { Assert(n.getKind() == MK_BAG); - TypeNode elementType = n.getType().getBagElementType(); - for (const Node& e : d_state.getElements(elementType)) + Trace("bags::BagSolver::postCheck") + << "BagSolver::checkMkBag Elements of " << n + << " are: " << d_state.getElements(n) << std::endl; + for (const Node& e : d_state.getElements(n)) { InferenceGenerator ig(&d_state); InferInfo i = ig.mkBag(n, e); @@ -103,6 +144,26 @@ void BagSolver::checkMkBag(const Node& n) Trace("bags::BagSolver::postCheck") << i << endl; } } +void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element) +{ + InferenceGenerator ig(&d_state); + InferInfo i = ig.nonNegativeCount(bag, element); + i.process(&d_im, true); + Trace("bags::BagSolver::postCheck") << i << endl; +} + +void BagSolver::checkDifferenceRemove(const Node& n) +{ + Assert(n.getKind() == DIFFERENCE_REMOVE); + std::set elements = getElementsForBinaryOperator(n); + for (const Node& e : elements) + { + InferenceGenerator ig(&d_state); + InferInfo i = ig.differenceRemove(n, e); + i.process(&d_im, true); + Trace("bags::BagSolver::postCheck") << i << endl; + } +} } // namespace bags } // namespace theory diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h index 48583d134..b4b18c00c 100644 --- a/src/theory/bags/bag_solver.h +++ b/src/theory/bags/bag_solver.h @@ -41,14 +41,31 @@ class BagSolver void postCheck(); private: - /** apply inference rules for MK_BAG operator */ + /** + * apply inference rules for MK_BAG operator. + * Example: Suppose n = (bag x c), and we have two count terms (bag.count x n) + * and (bag.count y n). + * This function will add inferences for the count terms as documented in + * InferenceGenerator::mkBag. + * Note that element y may not be in bag n. See the documentation of + * SolverState::getElements. + */ void checkMkBag(const Node& n); + /** + * @param n is a bag that has the form (op A B) + * @return the set union of known elements in (op A B) , A, and B. + */ + std::set getElementsForBinaryOperator(const Node& n); /** apply inference rules for union disjoint */ void checkUnionDisjoint(const Node& n); /** apply inference rules for union max */ void checkUnionMax(const Node& n); /** apply inference rules for difference subtract */ void checkDifferenceSubtract(const Node& n); + /** apply inference rules for difference remove */ + void checkDifferenceRemove(const Node& n); + /** apply non negative constraints for multiplicities */ + void checkNonNegativeCountTerms(const Node& bag, const Node& element); /** The solver state object */ SolverState& d_state; diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index fb76fb1c2..48cd9c419 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -70,8 +70,8 @@ class BagsRewriter : public TheoryRewriter /** * rewrites for n include: - * - (mkBag x 0) = (emptybag T) where T is the type of x - * - (mkBag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a + * - (bag x 0) = (emptybag T) where T is the type of x + * - (bag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a * constant * - otherwise = n */ @@ -87,7 +87,7 @@ class BagsRewriter : public TheoryRewriter /** * rewrites for n include: - * - (duplicate_removal (mkBag x n)) = (mkBag x 1) + * - (duplicate_removal (bag x n)) = (bag x 1) * where n is a positive constant */ BagsRewriteResponse rewriteDuplicateRemoval(const TNode& n) const; @@ -171,13 +171,13 @@ class BagsRewriter : public TheoryRewriter BagsRewriteResponse rewriteDifferenceRemove(const TNode& n) const; /** * rewrites for n include: - * - (bag.choose (mkBag x c)) = x where c is a constant > 0 + * - (bag.choose (bag x c)) = x where c is a constant > 0 * - otherwise = n */ BagsRewriteResponse rewriteChoose(const TNode& n) const; /** * rewrites for n include: - * - (bag.card (mkBag x c)) = c where c is a constant > 0 + * - (bag.card (bag x c)) = c where c is a constant > 0 * - (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B)) * - otherwise = n */ @@ -185,19 +185,19 @@ class BagsRewriter : public TheoryRewriter /** * rewrites for n include: - * - (bag.is_singleton (mkBag x c)) = (c == 1) + * - (bag.is_singleton (bag x c)) = (c == 1) */ BagsRewriteResponse rewriteIsSingleton(const TNode& n) const; /** * rewrites for n include: - * - (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1) + * - (bag.from_set (singleton (singleton_op Int) x)) = (bag x 1) */ BagsRewriteResponse rewriteFromSet(const TNode& n) const; /** * rewrites for n include: - * - (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x) + * - (bag.to_set (bag x n)) = (singleton (singleton_op T) x) * where n is a positive constant and T is the type of the bag's elements */ BagsRewriteResponse rewriteToSet(const TNode& n) const; diff --git a/src/theory/bags/infer_info.cpp b/src/theory/bags/infer_info.cpp index 1244a43ac..5b3274617 100644 --- a/src/theory/bags/infer_info.cpp +++ b/src/theory/bags/infer_info.cpp @@ -25,6 +25,8 @@ const char* toString(Inference i) switch (i) { case Inference::NONE: return "NONE"; + case Inference::BAG_NON_NEGATIVE_COUNT: return "BAG_NON_NEGATIVE_COUNT"; + case Inference::BAG_MK_BAG_SAME_ELEMENT: return "BAG_MK_BAG_SAME_ELEMENT"; case Inference::BAG_MK_BAG: return "BAG_MK_BAG"; case Inference::BAG_EQUALITY: return "BAG_EQUALITY"; case Inference::BAG_DISEQUALITY: return "BAG_DISEQUALITY"; @@ -62,9 +64,19 @@ bool InferInfo::process(TheoryInferenceManager* im, bool asLemma) if (asLemma) { TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr); - return im->trustedLemma(trustedLemma); + im->trustedLemma(trustedLemma); } - Unimplemented(); + else + { + Unimplemented(); + } + for (const auto& pair : d_skolems) + { + Node n = pair.first.eqNode(pair.second); + TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr); + im->trustedLemma(trustedLemma); + } + return true; } bool InferInfo::isTrivial() const @@ -87,21 +99,15 @@ bool InferInfo::isFact() const return !atom.isConst() && atom.getKind() != kind::OR; } -Node InferInfo::getPremises() const -{ - // d_noExplain is a subset of d_ant - NodeManager* nm = NodeManager::currentNM(); - return nm->mkAnd(d_premises); -} - std::ostream& operator<<(std::ostream& out, const InferInfo& ii) { - out << "(infer " << ii.d_id << " " << ii.d_conclusion << std::endl; + out << "(infer :id " << ii.d_id << std::endl; + out << ":conclusion " << ii.d_conclusion << std::endl; if (!ii.d_premises.empty()) { out << " :premise (" << ii.d_premises << ")" << std::endl; } - + out << ":skolems " << ii.d_skolems << std::endl; out << ")"; return out; } diff --git a/src/theory/bags/infer_info.h b/src/theory/bags/infer_info.h index 3edbef737..ecfc354d1 100644 --- a/src/theory/bags/infer_info.h +++ b/src/theory/bags/infer_info.h @@ -33,6 +33,8 @@ namespace bags { enum class Inference : uint32_t { NONE, + BAG_NON_NEGATIVE_COUNT, + BAG_MK_BAG_SAME_ELEMENT, BAG_MK_BAG, BAG_EQUALITY, BAG_DISEQUALITY, @@ -81,7 +83,7 @@ class InferInfo : public TheoryInference bool process(TheoryInferenceManager* im, bool asLemma) override; /** The inference identifier */ Inference d_id; - /** The conclusion */ + /** The conclusions */ Node d_conclusion; /** * The premise(s) of the inference, interpreted conjunctively. These are @@ -90,11 +92,10 @@ class InferInfo : public TheoryInference std::vector d_premises; /** - * A list of new skolems introduced as a result of this inference. They - * are mapped to by a length status, indicating the length constraint that - * can be assumed for them. + * A map of nodes to their skolem variables introduced as a result of this + * inference. */ - std::vector d_newSkolem; + std::map d_skolems; /** Is this infer info trivial? True if d_conc is true. */ bool isTrivial() const; /** @@ -108,8 +109,6 @@ class InferInfo : public TheoryInference * engine with no new external premises (d_noExplain). */ bool isFact() const; - /** Get premises */ - Node getPremises() const; }; /** diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp index 759ea1f0c..7ef126911 100644 --- a/src/theory/bags/inference_generator.cpp +++ b/src/theory/bags/inference_generator.cpp @@ -32,18 +32,33 @@ InferenceGenerator::InferenceGenerator(SolverState* state) : d_state(state) d_one = d_nm->mkConst(Rational(1)); } +InferInfo InferenceGenerator::nonNegativeCount(Node n, Node e) +{ + Assert(n.getType().isBag()); + Assert(e.getType() == n.getType().getBagElementType()); + + InferInfo inferInfo; + inferInfo.d_id = Inference::BAG_NON_NEGATIVE_COUNT; + Node count = d_nm->mkNode(kind::BAG_COUNT, e, n); + + Node gte = d_nm->mkNode(kind::GEQ, count, d_zero); + inferInfo.d_conclusion = gte; + return inferInfo; +} + InferInfo InferenceGenerator::mkBag(Node n, Node e) { Assert(n.getKind() == kind::MK_BAG); Assert(e.getType() == n.getType().getBagElementType()); InferInfo inferInfo; - inferInfo.d_id = Inference::BAG_MK_BAG; - Node count = getMultiplicitySkolem(e, n, inferInfo); + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); if (n[0] == e) { - // TODO: refactor this with the rewriter + // TODO issue #78: refactor this with BagRewriter // (=> true (= (bag.count e (bag e c)) c)) + inferInfo.d_id = Inference::BAG_MK_BAG_SAME_ELEMENT; inferInfo.d_conclusion = count.eqNode(n[1]); } else @@ -51,7 +66,7 @@ InferInfo InferenceGenerator::mkBag(Node n, Node e) // (=> // true // (= (bag.count e (bag x c)) (ite (= e x) c 0))) - + inferInfo.d_id = Inference::BAG_MK_BAG; Node same = d_nm->mkNode(kind::EQUAL, n[0], e); Node ite = d_nm->mkNode(kind::ITE, same, n[1], d_zero); Node equal = count.eqNode(ite); @@ -60,30 +75,12 @@ InferInfo InferenceGenerator::mkBag(Node n, Node e) return inferInfo; } -InferInfo InferenceGenerator::bagEquality(Node n, Node e) -{ - Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag()); - Assert(e.getType() == n[0].getType().getBagElementType()); - - Node A = n[0]; - Node B = n[1]; - InferInfo inferInfo; - inferInfo.d_id = Inference::BAG_EQUALITY; - inferInfo.d_premises.push_back(n); - Node countA = getMultiplicitySkolem(e, A, inferInfo); - Node countB = getMultiplicitySkolem(e, B, inferInfo); - - Node equal = countA.eqNode(countB); - inferInfo.d_conclusion = equal; - return inferInfo; -} - struct BagsDeqAttributeId { }; typedef expr::Attribute BagsDeqAttribute; -InferInfo InferenceGenerator::bagDisequality(Node n) +InferInfo InferenceGenerator::bagDisequality(Node n, Node reason) { Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL); Assert(n[0][0].getType().isBag()); @@ -93,22 +90,19 @@ InferInfo InferenceGenerator::bagDisequality(Node n) InferInfo inferInfo; inferInfo.d_id = Inference::BAG_DISEQUALITY; + inferInfo.d_premises.push_back(reason); TypeNode elementType = A.getType().getBagElementType(); - BoundVarManager* bvm = d_nm->getBoundVarManager(); Node element = bvm->mkBoundVar(n, elementType); - SkolemManager* sm = d_nm->getSkolemManager(); Node skolem = - sm->mkSkolem(element, - n, - "bag_disequal", - "an extensional lemma for disequality of two bags"); + d_sm->mkSkolem(element, + n, + "bag_disequal", + "an extensional lemma for disequality of two bags"); - inferInfo.d_newSkolem.push_back(skolem); - - Node countA = getMultiplicitySkolem(skolem, A, inferInfo); - Node countB = getMultiplicitySkolem(skolem, B, inferInfo); + Node countA = getMultiplicityTerm(skolem, A); + Node countB = getMultiplicityTerm(skolem, B); Node disEqual = countA.eqNode(countB).notNode(); @@ -117,13 +111,20 @@ InferInfo InferenceGenerator::bagDisequality(Node n) return inferInfo; } +Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo) +{ + Node skolem = d_sm->mkPurifySkolem(n, "skolem_bag", "skolem bag"); + inferInfo.d_skolems[n] = skolem; + return skolem; +} + InferInfo InferenceGenerator::bagEmpty(Node e) { EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType())); Node empty = d_nm->mkConst(emptyBag); InferInfo inferInfo; inferInfo.d_id = Inference::BAG_EMPTY; - Node count = getMultiplicitySkolem(e, empty, inferInfo); + Node count = getMultiplicityTerm(e, empty); Node equal = count.eqNode(d_zero); inferInfo.d_conclusion = equal; @@ -140,9 +141,11 @@ InferInfo InferenceGenerator::unionDisjoint(Node n, Node e) InferInfo inferInfo; inferInfo.d_id = Inference::BAG_UNION_DISJOINT; - Node countA = getMultiplicitySkolem(e, A, inferInfo); - Node countB = getMultiplicitySkolem(e, B, inferInfo); - Node count = getMultiplicitySkolem(e, n, inferInfo); + Node countA = getMultiplicityTerm(e, A); + Node countB = getMultiplicityTerm(e, B); + + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); Node sum = d_nm->mkNode(kind::PLUS, countA, countB); Node equal = count.eqNode(sum); @@ -161,9 +164,11 @@ InferInfo InferenceGenerator::unionMax(Node n, Node e) InferInfo inferInfo; inferInfo.d_id = Inference::BAG_UNION_MAX; - Node countA = getMultiplicitySkolem(e, A, inferInfo); - Node countB = getMultiplicitySkolem(e, B, inferInfo); - Node count = getMultiplicitySkolem(e, n, inferInfo); + Node countA = getMultiplicityTerm(e, A); + Node countB = getMultiplicityTerm(e, B); + + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); Node gt = d_nm->mkNode(kind::GT, countA, countB); Node max = d_nm->mkNode(kind::ITE, gt, countA, countB); @@ -183,9 +188,10 @@ InferInfo InferenceGenerator::intersection(Node n, Node e) InferInfo inferInfo; inferInfo.d_id = Inference::BAG_INTERSECTION_MIN; - Node countA = getMultiplicitySkolem(e, A, inferInfo); - Node countB = getMultiplicitySkolem(e, B, inferInfo); - Node count = getMultiplicitySkolem(e, n, inferInfo); + Node countA = getMultiplicityTerm(e, A); + Node countB = getMultiplicityTerm(e, B); + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); Node lt = d_nm->mkNode(kind::LT, countA, countB); Node min = d_nm->mkNode(kind::ITE, lt, countA, countB); @@ -204,9 +210,10 @@ InferInfo InferenceGenerator::differenceSubtract(Node n, Node e) InferInfo inferInfo; inferInfo.d_id = Inference::BAG_DIFFERENCE_SUBTRACT; - Node countA = getMultiplicitySkolem(e, A, inferInfo); - Node countB = getMultiplicitySkolem(e, B, inferInfo); - Node count = getMultiplicitySkolem(e, n, inferInfo); + Node countA = getMultiplicityTerm(e, A); + Node countB = getMultiplicityTerm(e, B); + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); Node subtract = d_nm->mkNode(kind::MINUS, countA, countB); Node gte = d_nm->mkNode(kind::GEQ, countA, countB); @@ -226,9 +233,11 @@ InferInfo InferenceGenerator::differenceRemove(Node n, Node e) InferInfo inferInfo; inferInfo.d_id = Inference::BAG_DIFFERENCE_REMOVE; - Node countA = getMultiplicitySkolem(e, A, inferInfo); - Node countB = getMultiplicitySkolem(e, B, inferInfo); - Node count = getMultiplicitySkolem(e, n, inferInfo); + Node countA = getMultiplicityTerm(e, A); + Node countB = getMultiplicityTerm(e, B); + + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); Node notInB = d_nm->mkNode(kind::EQUAL, countB, d_zero); Node difference = d_nm->mkNode(kind::ITE, notInB, countA, d_zero); @@ -246,8 +255,9 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e) InferInfo inferInfo; inferInfo.d_id = Inference::BAG_DUPLICATE_REMOVAL; - Node countA = getMultiplicitySkolem(e, A, inferInfo); - Node count = getMultiplicitySkolem(e, n, inferInfo); + Node countA = getMultiplicityTerm(e, A); + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); Node gte = d_nm->mkNode(kind::GEQ, countA, d_one); Node ite = d_nm->mkNode(kind::ITE, gte, d_one, d_zero); @@ -256,16 +266,10 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e) return inferInfo; } -Node InferenceGenerator::getMultiplicitySkolem(Node element, - Node bag, - InferInfo& inferInfo) +Node InferenceGenerator::getMultiplicityTerm(Node element, Node bag) { Node count = d_nm->mkNode(kind::BAG_COUNT, element, bag); - Node skolem = d_state->registerBagElement(count); - eq::EqualityEngine* ee = d_state->getEqualityEngine(); - ee->assertEquality(skolem.eqNode(count), true, d_nm->mkConst(true)); - inferInfo.d_newSkolem.push_back(skolem); - return skolem; + return count; } } // namespace bags diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h index b56997088..9eee46e43 100644 --- a/src/theory/bags/inference_generator.h +++ b/src/theory/bags/inference_generator.h @@ -38,33 +38,38 @@ class InferenceGenerator InferenceGenerator(SolverState* state); /** - * @param n is (bag x c) of type (Bag E) + * @param A is a bag of type (Bag E) * @param e is a node of type E * @return an inference that represents the following implication * (=> * true - * (= (bag.count e (bag x c)) (ite (= e x) c 0))) + * (>= (bag.count e A) 0) */ - InferInfo mkBag(Node n, Node e); + InferInfo nonNegativeCount(Node n, Node e); /** - * @param n is (= A B) where A, B are bags of type (Bag E) - * @param e is a node of Type E + * @param n is (bag x c) of type (Bag E) + * @param e is a node of type E * @return an inference that represents the following implication * (=> - * (= A B) - * (= (count e A) (count e B))) + * true + * (= (bag.count e skolem) c)) + * if e is exactly node x. Node skolem is a fresh variable equals (bag x c). + * Otherwise the following inference is returned + * (=> + * true + * (= (bag.count e skolem) (ite (= e x) c 0))) */ - InferInfo bagEquality(Node n, Node e); + InferInfo mkBag(Node n, Node e); /** * @param n is (not (= A B)) where A, B are bags of type (Bag E) * @return an inference that represents the following implication * (=> * (not (= A B)) * (not (= (count e A) (count e B)))) - * where e is a fresh skolem of type E + * where e is a fresh skolem of type E. */ - InferInfo bagDisequality(Node n); + InferInfo bagDisequality(Node n, Node reason); /** * @param e is a node of Type E * @return an inference that represents the following implication @@ -79,10 +84,9 @@ class InferenceGenerator * @return an inference that represents the following implication * (=> * true - * (= (count e k_{(union_disjoint A B)}) + * (= (count e skolem) * (+ (count e A) (count e B)))) - * where k_{(union_disjoint A B)} is a unique purification skolem - * for (union_disjoint A B) + * where skolem is a fresh variable equals (union_disjoint A B) */ InferInfo unionDisjoint(Node n, Node e); /** @@ -91,11 +95,13 @@ class InferenceGenerator * @return an inference that represents the following implication * (=> * true - * (= (count e (union_max A B)) + * (= + * (count e skolem) * (ite - * (> (count e A) (count e B)) - * (count e A) - * (count e B))))) + * (> (count e A) (count e B)) + * (count e A) + * (count e B))))) + * where skolem is a fresh variable equals (union_max A B) */ InferInfo unionMax(Node n, Node e); /** @@ -104,11 +110,13 @@ class InferenceGenerator * @return an inference that represents the following implication * (=> * true - * (= (count e (intersection_min A B)) + * (= + * (count e skolem) * (ite( - * (< (count e A) (count e B)) - * (count e A) - * (count e B))))) + * (< (count e A) (count e B)) + * (count e A) + * (count e B))))) + * where skolem is a fresh variable equals (intersection_min A B) */ InferInfo intersection(Node n, Node e); /** @@ -117,11 +125,13 @@ class InferenceGenerator * @return an inference that represents the following implication * (=> * true - * (= (count e (difference_subtract A B)) + * (= + * (count e skolem) * (ite - * (>= (count e A) (count e B)) - * (- (count e A) (count e B)) - * 0)))) + * (>= (count e A) (count e B)) + * (- (count e A) (count e B)) + * 0)))) + * where skolem is a fresh variable equals (difference_subtract A B) */ InferInfo differenceSubtract(Node n, Node e); /** @@ -130,11 +140,13 @@ class InferenceGenerator * @return an inference that represents the following implication * (=> * true - * (= (count e (difference_remove A B)) + * (= + * (count e skolem) * (ite - * (= (count e B) 0) - * (count e A) - * 0)))) + * (= (count e B) 0) + * (count e A) + * 0)))) + * where skolem is a fresh variable equals (difference_remove A B) */ InferInfo differenceRemove(Node n, Node e); /** @@ -143,20 +155,24 @@ class InferenceGenerator * @return an inference that represents the following implication * (=> * true - * (= (count e (duplicate_removal A)) - * (ite (>= (count e A) 1) 1 0)))) + * (= + * (count e skolem) + * (ite (>= (count e A) 1) 1 0)))) + * where skolem is a fresh variable equals (duplicate_removal A) */ InferInfo duplicateRemoval(Node n, Node e); /** * @param element of type T * @param bag of type (bag T) - * @param inferInfo to store new skolem - * @return a skolem for (bag.count element bag) + * @return a count term (bag.count element bag) */ - Node getMultiplicitySkolem(Node element, Node bag, InferInfo& inferInfo); + Node getMultiplicityTerm(Node element, Node bag); private: + /** generate skolem variable for node n and add it to inferInfo */ + Node getSkolem(Node& n, InferInfo& inferInfo); + NodeManager* d_nm; SkolemManager* d_sm; SolverState* d_state; diff --git a/src/theory/bags/inference_manager.h b/src/theory/bags/inference_manager.h index 67025548c..71a014582 100644 --- a/src/theory/bags/inference_manager.h +++ b/src/theory/bags/inference_manager.h @@ -45,7 +45,7 @@ class InferenceManager : public InferenceManagerBuffered * process the pending lemmas and then the pending phase requirements. * Notice that we process the pending lemmas even if there were facts. */ - // TODO: refactor this before merge with theory of strings + // TODO issue #78: refactor this with theory of strings void doPending(); private: diff --git a/src/theory/bags/solver_state.cpp b/src/theory/bags/solver_state.cpp index 744f6de9f..9bcb6ae3c 100644 --- a/src/theory/bags/solver_state.cpp +++ b/src/theory/bags/solver_state.cpp @@ -33,52 +33,89 @@ SolverState::SolverState(context::Context* c, { d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); + d_nm = NodeManager::currentNM(); } -struct BagsCountAttributeId +void SolverState::registerBag(TNode n) { -}; -typedef expr::Attribute BagsCountAttribute; - -void SolverState::registerClass(TNode n) -{ - TypeNode t = n.getType(); - if (!t.isBag()) - { - return; - } + Assert(n.getType().isBag()); d_bags.insert(n); } -Node SolverState::registerBagElement(TNode n) +void SolverState::registerCountTerm(TNode n) { Assert(n.getKind() == BAG_COUNT); - Node element = n[0]; - TypeNode elementType = element.getType(); - Node bag = n[1]; - d_elements[elementType].insert(element); - NodeManager* nm = NodeManager::currentNM(); - BoundVarManager* bvm = nm->getBoundVarManager(); - Node multiplicity = bvm->mkBoundVar(n, nm->integerType()); - Node equal = n.eqNode(multiplicity); - SkolemManager* sm = nm->getSkolemManager(); - Node skolem = sm->mkSkolem( - multiplicity, - equal, - "bag_multiplicity", - "an extensional lemma for multiplicity of an element in a bag"); - d_count[bag][element] = skolem; - Trace("bags::SolverState::registerBagElement") - << "New skolem: " << skolem << " for " << n << std::endl; - - return skolem; + Node element = getRepresentative(n[0]); + Node bag = getRepresentative(n[1]); + d_bagElements[bag].insert(element); +} + +const std::set& SolverState::getBags() { return d_bags; } + +const std::set& SolverState::getElements(Node B) +{ + Node bag = getRepresentative(B); + return d_bagElements[B]; +} + +void SolverState::reset() +{ + d_bagElements.clear(); + d_bags.clear(); } -std::set& SolverState::getBags() { return d_bags; } +void SolverState::initialize() +{ + reset(); + collectBagsAndCountTerms(); +} + +void SolverState::collectBagsAndCountTerms() +{ + Trace("SolverState::collectBagsAndCountTerms") + << "SolverState::collectBagsAndCountTerms start" << endl; + eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee); + while (!repIt.isFinished()) + { + Node eqc = (*repIt); + Trace("SolverState::collectBagsAndCountTerms") + << "[" << eqc << "]: " << endl; + + if (eqc.getType().isBag()) + { + registerBag(eqc); + } -std::set& SolverState::getElements(TypeNode t) { return d_elements[t]; } + eq::EqClassIterator it = eq::EqClassIterator(eqc, d_ee); + while (!it.isFinished()) + { + Node n = (*it); + Kind k = n.getKind(); + if (k == MK_BAG) + { + // for terms (bag x c) we need to store x by registering the count term + // (bag.count x (bag x c)) + Node count = d_nm->mkNode(BAG_COUNT, n[0], n); + registerCountTerm(count); + Trace("SolverState::collectBagsAndCountTerms") + << "registered " << count << endl; + } + if (k == BAG_COUNT) + { + // this takes care of all count terms in each equivalent class + registerCountTerm(n); + Trace("SolverState::collectBagsAndCountTerms") + << "registered " << n << endl; + } + ++it; + } -std::map& SolverState::getBagElements(Node B) { return d_count[B]; } + ++repIt; + } + + Trace("SolverState::collectBagsAndCountTerms") + << "SolverState::collectBagsAndCountTerms end" << endl; +} } // namespace bags } // namespace theory diff --git a/src/theory/bags/solver_state.h b/src/theory/bags/solver_state.h index 8d70ee8f7..175317529 100644 --- a/src/theory/bags/solver_state.h +++ b/src/theory/bags/solver_state.h @@ -31,24 +31,52 @@ class SolverState : public TheoryState public: SolverState(context::Context* c, context::UserContext* u, Valuation val); - void registerClass(TNode n); + /** + * This function adds the bag representative n to the set d_bags if it is not + * already there. This function is called during postCheck to collect bag + * terms in the equality engine. See the documentation of + * @link SolverState::collectBagsAndCountTerms + */ + void registerBag(TNode n); - Node registerBagElement(TNode n); - - std::set& getBags(); - - std::set& getElements(TypeNode t); - - std::map& getBagElements(Node B); + /** + * @param n has the form (bag.count e A) + * @pre bag A needs is already registered using registerBag(A) + * @return a unique skolem for (bag.count e A) + */ + void registerCountTerm(TNode n); + /** get all bag terms that are representatives in the equality engine. + * This function is valid after the current solver is initialized during + * postCheck. See SolverState::initialize and BagSolver::postCheck + */ + const std::set& getBags(); + /** + * @pre B is a registered bag + * @return all elements associated with bag B so far + * Note that associated elements are not necessarily elements in B + * Example: + * (assert (= 0 (bag.count x B))) + * element x is associated with bag B, albeit x is definitely not in B. + */ + const std::set& getElements(Node B); + /** initialize bag and count terms */ + void initialize(); private: + /** clear all bags data structures */ + void reset(); + /** collect bags' representatives and all count terms. + * This function is called during postCheck */ + void collectBagsAndCountTerms(); /** constants */ Node d_true; Node d_false; + /** node manager for this solver state */ + NodeManager* d_nm; + /** collection of bag representatives */ std::set d_bags; - std::map> d_elements; - /** bag -> element -> multiplicity */ - std::map> d_count; + /** bag -> associated elements */ + std::map> d_bagElements; }; /* class SolverState */ } // namespace bags diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 21a9d0e53..153e9017d 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -78,22 +78,22 @@ void TheoryBags::finishInit() void TheoryBags::postCheck(Effort effort) { d_im.doPendingFacts(); - // TODO: clean this before merge Assert(d_strat.isStrategyInit()); + // TODO issue #78: add Assert(d_strat.isStrategyInit()); if (!d_state.isInConflict() && !d_valuation.needCheck()) - // TODO: clean this before merge && d_strat.hasStrategyEffort(e)) + // TODO issue #78: add && d_strat.hasStrategyEffort(e)) { Trace("bags::TheoryBags::postCheck") << "effort: " << std::endl; - // TODO: clean this before merge ++(d_statistics.d_checkRuns); + // TODO issue #78: add ++(d_statistics.d_checkRuns); bool sentLemma = false; bool hadPending = false; Trace("bags-check") << "Full effort check..." << std::endl; do { d_im.reset(); - // TODO: clean this before merge ++(d_statistics.d_strategyRuns); + // TODO issue #78: add ++(d_statistics.d_strategyRuns); Trace("bags-check") << " * Run strategy..." << std::endl; - // TODO: clean this before merge runStrategy(e); + // TODO issue #78: add runStrategy(e); d_solver.postCheck(); @@ -153,14 +153,22 @@ bool TheoryBags::collectModelValues(TheoryModel* m, continue; } Node r = d_state.getRepresentative(n); - std::map elements = d_state.getBagElements(r); + std::set solverElements = d_state.getElements(r); + std::set elements; + // only consider terms in termSet and ignore other elements in the solver + std::set_intersection(termSet.begin(), + termSet.end(), + solverElements.begin(), + solverElements.end(), + std::inserter(elements, elements.begin())); Trace("bags-model") << "Elements of bag " << n << " are: " << std::endl << elements << std::endl; std::map elementReps; - for (std::pair pair : elements) + for (const Node& e : elements) { - Node key = d_state.getRepresentative(pair.first); - Node value = d_state.getRepresentative(pair.second); + Node key = d_state.getRepresentative(e); + Node countTerm = NodeManager::currentNM()->mkNode(BAG_COUNT, e, r); + Node value = d_state.getRepresentative(countTerm); elementReps[key] = value; } Node rep = NormalForm::constructBagFromElements(tn, elementReps); @@ -211,38 +219,7 @@ void TheoryBags::presolve() {} /**************************** eq::NotifyClass *****************************/ -void TheoryBags::eqNotifyNewClass(TNode n) -{ - Kind k = n.getKind(); - d_state.registerClass(n); - if (n.getKind() == MK_BAG) - { - // TODO: refactor this before merge - /* - * (bag x m) generates the lemma (and (= s (count x (bag x m))) (= s m)) - * where s is a fresh skolem variable - */ - NodeManager* nm = NodeManager::currentNM(); - Node count = nm->mkNode(BAG_COUNT, n[0], n); - Node skolem = d_state.registerBagElement(count); - Node countSkolem = count.eqNode(skolem); - Node skolemMultiplicity = n[1].eqNode(skolem); - Node lemma = countSkolem.andNode(skolemMultiplicity); - TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr); - d_im.trustedLemma(trustedLemma); - } - if (k == BAG_COUNT) - { - /* - * (count x A) generates the lemma (= s (count x A)) - * where s is a fresh skolem variable - */ - Node skolem = d_state.registerBagElement(n); - Node lemma = n.eqNode(skolem); - TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr); - d_im.trustedLemma(trustedLemma); - } -} +void TheoryBags::eqNotifyNewClass(TNode n) {} void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {} @@ -251,10 +228,8 @@ void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason) TypeNode t1 = n1.getType(); if (t1.isBag()) { - InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode()); - Node lemma = reason.impNode(info.d_conclusion); - TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr); - d_im.trustedLemma(trustedLemma); + InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode(), reason); + info.process(d_inferManager, true); } } diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index cf4b0386d..810ed8128 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -1427,7 +1427,9 @@ set(regress_1_tests regress1/bug681.smt2 regress1/bug694-Unapply1.scala-0.smt2 regress1/bug800.smt2 + regress1/bags/difference_remove1.smt2 regress1/bags/disequality.smt2 + regress1/bags/issue5759.smt2 regress1/bags/subbag1.smt2 regress1/bags/subbag2.smt2 regress1/bags/union_disjoint.smt2 diff --git a/test/regress/regress1/bags/difference_remove1.smt2 b/test/regress/regress1/bags/difference_remove1.smt2 new file mode 100644 index 000000000..f5a87c149 --- /dev/null +++ b/test/regress/regress1/bags/difference_remove1.smt2 @@ -0,0 +1,10 @@ +(set-logic ALL) +(set-info :status sat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun x () Int) +(declare-fun y () Int) +(assert (= A (union_max (bag x 1) (bag y 2)))) +(assert (= A (union_disjoint B (bag y 2)))) +(assert (= x y)) +(check-sat) diff --git a/test/regress/regress1/bags/issue5759.smt2 b/test/regress/regress1/bags/issue5759.smt2 new file mode 100644 index 000000000..ba3752e09 --- /dev/null +++ b/test/regress/regress1/bags/issue5759.smt2 @@ -0,0 +1,10 @@ +(set-logic ALL) +(set-info :status sat) +(set-option :produce-models true) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun x () Int) +(assert (not (= A (union_max (bag x 1) (bag 0 1))))) +(assert (= A (union_disjoint B (bag 0 1)))) +(assert (= x 1)) +(check-sat) \ No newline at end of file -- 2.30.2