From 2d64f408f416c601b3b545984ca1b6c31c151f16 Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Tue, 1 Feb 2022 08:58:04 -0600 Subject: [PATCH] Add bag.filter operator (#8006) --- src/CMakeLists.txt | 4 +- src/api/cpp/cvc5.cpp | 2 + src/api/cpp/cvc5_kind.h | 17 +++ src/parser/smt2/smt2.cpp | 1 + src/printer/smt2/smt2_printer.cpp | 1 + src/theory/bags/bag_solver.cpp | 25 +++- src/theory/bags/bag_solver.h | 2 + src/theory/bags/bags_rewriter.cpp | 58 ++++++++- src/theory/bags/bags_rewriter.h | 10 ++ .../bags/{normal_form.cpp => bags_utils.cpp} | 122 +++++++++++++----- .../bags/{normal_form.h => bags_utils.h} | 17 ++- src/theory/bags/card_solver.cpp | 2 +- src/theory/bags/inference_generator.cpp | 46 +++++++ src/theory/bags/inference_generator.h | 28 ++++ src/theory/bags/kinds | 5 + src/theory/bags/rewrites.cpp | 3 + src/theory/bags/rewrites.h | 3 + src/theory/bags/theory_bags.cpp | 7 +- .../bags/theory_bags_type_enumerator.cpp | 7 +- src/theory/bags/theory_bags_type_rules.cpp | 46 ++++++- src/theory/bags/theory_bags_type_rules.h | 9 ++ src/theory/inference_id.cpp | 2 + src/theory/inference_id.h | 2 + test/regress/CMakeLists.txt | 5 + test/regress/regress1/bags/filter1.smt2 | 11 ++ test/regress/regress1/bags/filter2.smt2 | 9 ++ test/regress/regress1/bags/filter3.smt2 | 10 ++ test/regress/regress1/bags/filter4.smt2 | 11 ++ test/regress/regress1/bags/filter5.smt2 | 11 ++ test/regress/regress1/bags/map1.smt2 | 1 - .../theory/theory_bags_normal_form_white.cpp | 70 +++++----- 31 files changed, 456 insertions(+), 91 deletions(-) rename src/theory/bags/{normal_form.cpp => bags_utils.cpp} (87%) rename src/theory/bags/{normal_form.h => bags_utils.h} (94%) create mode 100644 test/regress/regress1/bags/filter1.smt2 create mode 100644 test/regress/regress1/bags/filter2.smt2 create mode 100644 test/regress/regress1/bags/filter3.smt2 create mode 100644 test/regress/regress1/bags/filter4.smt2 create mode 100644 test/regress/regress1/bags/filter5.smt2 diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a1ad056f1..1715257f7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -549,6 +549,8 @@ libcvc5_add_sources( theory/bags/bag_reduction.h theory/bags/bags_statistics.cpp theory/bags/bags_statistics.h + theory/bags/bags_utils.cpp + theory/bags/bags_utils.h theory/bags/card_solver.cpp theory/bags/card_solver.h theory/bags/infer_info.cpp @@ -557,8 +559,6 @@ libcvc5_add_sources( theory/bags/inference_generator.h theory/bags/inference_manager.cpp theory/bags/inference_manager.h - theory/bags/normal_form.cpp - theory/bags/normal_form.h theory/bags/rewrites.cpp theory/bags/rewrites.h theory/bags/solver_state.cpp diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 458dd359a..df9a8b8ae 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -312,6 +312,7 @@ const static std::unordered_map s_kinds{ {BAG_FROM_SET, cvc5::Kind::BAG_FROM_SET}, {BAG_TO_SET, cvc5::Kind::BAG_TO_SET}, {BAG_MAP, cvc5::Kind::BAG_MAP}, + {BAG_FILTER, cvc5::Kind::BAG_FILTER}, {BAG_FOLD, cvc5::Kind::BAG_FOLD}, /* Strings ------------------------------------------------------------- */ {STRING_CONCAT, cvc5::Kind::STRING_CONCAT}, @@ -624,6 +625,7 @@ const static std::unordered_map {cvc5::Kind::BAG_FROM_SET, BAG_FROM_SET}, {cvc5::Kind::BAG_TO_SET, BAG_TO_SET}, {cvc5::Kind::BAG_MAP, BAG_MAP}, + {cvc5::Kind::BAG_FILTER, BAG_FILTER}, {cvc5::Kind::BAG_FOLD, BAG_FOLD}, /* Strings --------------------------------------------------------- */ {cvc5::Kind::STRING_CONCAT, STRING_CONCAT}, diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index 3bd896814..dba4df07f 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -2539,6 +2539,23 @@ enum Kind : int32_t * - `Solver::mkTerm(Kind kind, const std::vector& children) const` */ BAG_MAP, + /** + * bag.filter operator filters the elements of a bag. + * (bag.filter p B) takes a predicate p of type (-> T Bool) as a first + * argument, and a bag B of type (Bag T) as a second argument, and returns a + * subbag of type (Bag T) that includes all elements of B that satisfy p + * with the same multiplicity. + * + * Parameters: + * - 1: a function of type (-> T Bool) + * - 2: a bag of type (Bag T) + * + * Create with: + * - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2) + * const` + * - `Solver::mkTerm(Kind kind, const std::vector& children) const` + */ + BAG_FILTER, /** * bag.fold operator combines elements of a bag into a single value. * (bag.fold f t B) folds the elements of bag B starting with term t and using diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 8eed51baa..a93596633 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -625,6 +625,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand) addOperator(api::BAG_FROM_SET, "bag.from_set"); addOperator(api::BAG_TO_SET, "bag.to_set"); addOperator(api::BAG_MAP, "bag.map"); + addOperator(api::BAG_FILTER, "bag.filter"); addOperator(api::BAG_FOLD, "bag.fold"); } if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) { diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 067cf27fe..cf85e6b0e 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1123,6 +1123,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::BAG_FROM_SET: return "bag.from_set"; case kind::BAG_TO_SET: return "bag.to_set"; case kind::BAG_MAP: return "bag.map"; + case kind::BAG_FILTER: return "bag.filter"; case kind::BAG_FOLD: return "bag.fold"; // fp theory diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp index 55367bb89..ed4b501f3 100644 --- a/src/theory/bags/bag_solver.cpp +++ b/src/theory/bags/bag_solver.cpp @@ -16,9 +16,9 @@ #include "theory/bags/bag_solver.h" #include "expr/emptybag.h" +#include "theory/bags/bags_utils.h" #include "theory/bags/inference_generator.h" #include "theory/bags/inference_manager.h" -#include "theory/bags/normal_form.h" #include "theory/bags/solver_state.h" #include "theory/bags/term_registry.h" #include "theory/uf/equality_engine_iterator.h" @@ -76,6 +76,7 @@ void BagSolver::checkBasicOperations() case kind::BAG_DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break; case kind::BAG_DIFFERENCE_REMOVE: checkDifferenceRemove(n); break; case kind::BAG_DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break; + case kind::BAG_FILTER: checkFilter(n); break; case kind::BAG_MAP: checkMap(n); break; default: break; } @@ -280,6 +281,28 @@ void BagSolver::checkMap(Node n) } } +void BagSolver::checkFilter(Node n) +{ + Assert(n.getKind() == BAG_FILTER); + + set elements; + const set& downwards = d_state.getElements(n); + const set& upwards = d_state.getElements(n[0]); + elements.insert(downwards.begin(), downwards.end()); + elements.insert(upwards.begin(), upwards.end()); + + for (const Node& e : elements) + { + InferInfo i = d_ig.filterDownwards(n, d_state.getRepresentative(e)); + d_im.lemmaTheoryInference(&i); + } + for (const Node& e : elements) + { + InferInfo i = d_ig.filterUpwards(n, d_state.getRepresentative(e)); + d_im.lemmaTheoryInference(&i); + } +} + } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h index 499b7998d..fca72b22e 100644 --- a/src/theory/bags/bag_solver.h +++ b/src/theory/bags/bag_solver.h @@ -96,6 +96,8 @@ class BagSolver : protected EnvObj void checkDisequalBagTerms(); /** apply inference rules for map operator */ void checkMap(Node n); + /** apply inference rules for filter operator */ + void checkFilter(Node n); /** The solver state object */ SolverState& d_state; diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index 40f8d6c95..24f313ad6 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -16,7 +16,7 @@ #include "theory/bags/bags_rewriter.h" #include "expr/emptybag.h" -#include "theory/bags/normal_form.h" +#include "theory/bags/bags_utils.h" #include "util/rational.h" #include "util/statistics_registry.h" @@ -65,9 +65,9 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) { response = rewriteChoose(n); } - else if (NormalForm::areChildrenConstants(n)) + else if (BagsUtils::areChildrenConstants(n)) { - Node value = NormalForm::evaluate(n); + Node value = BagsUtils::evaluate(n); response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION); } else @@ -90,6 +90,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) case BAG_FROM_SET: response = rewriteFromSet(n); break; case BAG_TO_SET: response = rewriteToSet(n); break; case BAG_MAP: response = postRewriteMap(n); break; + case BAG_FILTER: response = postRewriteFilter(n); break; case BAG_FOLD: response = postRewriteFold(n); break; default: response = BagsRewriteResponse(n, Rewrite::NONE); break; } @@ -533,7 +534,7 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const { // (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2)) // (bag.map f (bag "a" 3)) = (bag (f "a") 3) - std::map elements = NormalForm::getBagElements(n[1]); + std::map elements = BagsUtils::getBagElements(n[1]); std::map mappedElements; std::map::iterator it = elements.begin(); while (it != elements.end()) @@ -543,7 +544,7 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const ++it; } TypeNode t = d_nm->mkBagType(n[0].getType().getRangeType()); - Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements); + Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements); return BagsRewriteResponse(ret, Rewrite::MAP_CONST); } Kind k = n[1].getKind(); @@ -572,6 +573,49 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const } } +BagsRewriteResponse BagsRewriter::postRewriteFilter(const TNode& n) const +{ + Assert(n.getKind() == kind::BAG_FILTER); + Node P = n[0]; + Node A = n[1]; + TypeNode t = A.getType(); + if (A.isConst()) + { + // (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T)) + // (bag.filter p (bag "a" 3) ((bag "b" 2))) = + // (bag.union_disjoint + // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T))) + // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T))) + + Node ret = BagsUtils::evaluateBagFilter(n); + return BagsRewriteResponse(ret, Rewrite::FILTER_CONST); + } + Kind k = A.getKind(); + switch (k) + { + case BAG_MAKE: + { + // (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T))) + Node empty = d_nm->mkConst(EmptyBag(t)); + Node pOfe = d_nm->mkNode(APPLY_UF, P, A[0]); + Node ret = d_nm->mkNode(ITE, pOfe, A, empty); + return BagsRewriteResponse(ret, Rewrite::FILTER_BAG_MAKE); + } + + case BAG_UNION_DISJOINT: + { + // (bag.filter p (bag.union_disjoint A B)) = + // (bag.union_disjoint (bag.filter p A) (bag.filter p B)) + Node a = d_nm->mkNode(BAG_FILTER, n[0], n[1][0]); + Node b = d_nm->mkNode(BAG_FILTER, n[0], n[1][1]); + Node ret = d_nm->mkNode(BAG_UNION_DISJOINT, a, b); + return BagsRewriteResponse(ret, Rewrite::FILTER_UNION_DISJOINT); + } + + default: return BagsRewriteResponse(n, Rewrite::NONE); + } +} + BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const { Assert(n.getKind() == kind::BAG_FOLD); @@ -580,7 +624,7 @@ BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const Node bag = n[2]; if (bag.isConst()) { - Node value = NormalForm::evaluateBagFold(n); + Node value = BagsUtils::evaluateBagFold(n); return BagsRewriteResponse(value, Rewrite::FOLD_CONST); } Kind k = bag.getKind(); @@ -591,7 +635,7 @@ BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const if (bag[1].isConst() && bag[1].getConst() > Rational(0)) { // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, n > 0 - Node value = NormalForm::evaluateBagFold(n); + Node value = BagsUtils::evaluateBagFold(n); return BagsRewriteResponse(value, Rewrite::FOLD_BAG); } break; diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index b4b1e9043..3e5b69a1c 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -228,6 +228,16 @@ class BagsRewriter : public TheoryRewriter */ BagsRewriteResponse postRewriteMap(const TNode& n) const; + /** + * rewrites for n include: + * - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T)) + * - (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T))) + * - (bag.filter p (bag.union_disjoint A B)) = + * (bag.union_disjoint (bag.filter p A) (bag.filter p B)) + * where p: T -> Bool + */ + BagsRewriteResponse postRewriteFilter(const TNode& n) const; + /** * rewrites for n include: * - (bag.fold f t (as bag.empty (Bag T1))) = t diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/bags_utils.cpp similarity index 87% rename from src/theory/bags/normal_form.cpp rename to src/theory/bags/bags_utils.cpp index 6cf26d357..39987ce9d 100644 --- a/src/theory/bags/normal_form.cpp +++ b/src/theory/bags/bags_utils.cpp @@ -10,9 +10,9 @@ * directory for licensing information. * **************************************************************************** * - * Normal form for bag constants. + * Utility functions for bags. */ -#include "normal_form.h" +#include "bags_utils.h" #include "expr/emptybag.h" #include "smt/logic_exception.h" @@ -26,7 +26,31 @@ namespace cvc5 { namespace theory { namespace bags { -bool NormalForm::isConstant(TNode n) +Node BagsUtils::computeDisjointUnion(TypeNode bagType, + const std::vector& bags) +{ + NodeManager* nm = NodeManager::currentNM(); + if (bags.empty()) + { + return nm->mkConst(EmptyBag(bagType)); + } + if (bags.size() == 1) + { + return bags[0]; + } + Node unionDisjoint = bags[0]; + for (size_t i = 1; i < bags.size(); i++) + { + if (bags[i].getKind() == BAG_EMPTY) + { + continue; + } + unionDisjoint = nm->mkNode(BAG_UNION_DISJOINT, unionDisjoint, bags[i]); + } + return unionDisjoint; +} + +bool BagsUtils::isConstant(TNode n) { if (n.getKind() == BAG_EMPTY) { @@ -82,12 +106,12 @@ bool NormalForm::isConstant(TNode n) return false; } -bool NormalForm::areChildrenConstants(TNode n) +bool BagsUtils::areChildrenConstants(TNode n) { return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); }); } -Node NormalForm::evaluate(TNode n) +Node BagsUtils::evaluate(TNode n) { Assert(areChildrenConstants(n)); if (n.isConst()) @@ -110,6 +134,7 @@ Node NormalForm::evaluate(TNode n) case BAG_FROM_SET: return evaluateFromSet(n); case BAG_TO_SET: return evaluateToSet(n); case BAG_MAP: return evaluateBagMap(n); + case BAG_FILTER: return evaluateBagFilter(n); case BAG_FOLD: return evaluateBagFold(n); default: break; } @@ -118,12 +143,12 @@ Node NormalForm::evaluate(TNode n) } template -Node NormalForm::evaluateBinaryOperation(const TNode& n, - T1&& equal, - T2&& less, - T3&& greaterOrEqual, - T4&& remainderOfA, - T5&& remainderOfB) +Node BagsUtils::evaluateBinaryOperation(const TNode& n, + T1&& equal, + T2&& less, + T3&& greaterOrEqual, + T4&& remainderOfA, + T5&& remainderOfB) { std::map elementsA = getBagElements(n[0]); std::map elementsB = getBagElements(n[1]); @@ -168,7 +193,7 @@ Node NormalForm::evaluateBinaryOperation(const TNode& n, return bag; } -std::map NormalForm::getBagElements(TNode n) +std::map BagsUtils::getBagElements(TNode n) { std::map elements; if (n.getKind() == BAG_EMPTY) @@ -190,7 +215,7 @@ std::map NormalForm::getBagElements(TNode n) return elements; } -Node NormalForm::constructConstantBagFromElements( +Node BagsUtils::constructConstantBagFromElements( TypeNode t, const std::map& elements) { Assert(t.isBag()); @@ -210,8 +235,8 @@ Node NormalForm::constructConstantBagFromElements( return bag; } -Node NormalForm::constructBagFromElements(TypeNode t, - const std::map& elements) +Node BagsUtils::constructBagFromElements(TypeNode t, + const std::map& elements) { Assert(t.isBag()); NodeManager* nm = NodeManager::currentNM(); @@ -230,7 +255,7 @@ Node NormalForm::constructBagFromElements(TypeNode t, return bag; } -Node NormalForm::evaluateMakeBag(TNode n) +Node BagsUtils::evaluateMakeBag(TNode n) { // the case where n is const should be handled earlier. // here we handle the case where the multiplicity is zero or negative @@ -240,7 +265,7 @@ Node NormalForm::evaluateMakeBag(TNode n) return emptybag; } -Node NormalForm::evaluateBagCount(TNode n) +Node BagsUtils::evaluateBagCount(TNode n) { Assert(n.getKind() == BAG_COUNT); // Examples @@ -263,7 +288,7 @@ Node NormalForm::evaluateBagCount(TNode n) return nm->mkConstInt(Rational(0)); } -Node NormalForm::evaluateDuplicateRemoval(TNode n) +Node BagsUtils::evaluateDuplicateRemoval(TNode n) { Assert(n.getKind() == BAG_DUPLICATE_REMOVAL); @@ -288,7 +313,7 @@ Node NormalForm::evaluateDuplicateRemoval(TNode n) return bag; } -Node NormalForm::evaluateUnionDisjoint(TNode n) +Node BagsUtils::evaluateUnionDisjoint(TNode n) { Assert(n.getKind() == BAG_UNION_DISJOINT); // Example @@ -348,7 +373,7 @@ Node NormalForm::evaluateUnionDisjoint(TNode n) n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); } -Node NormalForm::evaluateUnionMax(TNode n) +Node BagsUtils::evaluateUnionMax(TNode n) { Assert(n.getKind() == BAG_UNION_MAX); // Example @@ -408,7 +433,7 @@ Node NormalForm::evaluateUnionMax(TNode n) n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); } -Node NormalForm::evaluateIntersectionMin(TNode n) +Node BagsUtils::evaluateIntersectionMin(TNode n) { Assert(n.getKind() == BAG_INTER_MIN); // Example @@ -454,7 +479,7 @@ Node NormalForm::evaluateIntersectionMin(TNode n) n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); } -Node NormalForm::evaluateDifferenceSubtract(TNode n) +Node BagsUtils::evaluateDifferenceSubtract(TNode n) { Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT); // Example @@ -506,7 +531,7 @@ Node NormalForm::evaluateDifferenceSubtract(TNode n) n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); } -Node NormalForm::evaluateDifferenceRemove(TNode n) +Node BagsUtils::evaluateDifferenceRemove(TNode n) { Assert(n.getKind() == BAG_DIFFERENCE_REMOVE); // Example @@ -557,7 +582,7 @@ Node NormalForm::evaluateDifferenceRemove(TNode n) n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); } -Node NormalForm::evaluateChoose(TNode n) +Node BagsUtils::evaluateChoose(TNode n) { Assert(n.getKind() == BAG_CHOOSE); // Examples @@ -571,7 +596,7 @@ Node NormalForm::evaluateChoose(TNode n) throw LogicException("BAG_CHOOSE_TOTAL is not supported yet"); } -Node NormalForm::evaluateCard(TNode n) +Node BagsUtils::evaluateCard(TNode n) { Assert(n.getKind() == BAG_CARD); // Examples @@ -592,7 +617,7 @@ Node NormalForm::evaluateCard(TNode n) return sumNode; } -Node NormalForm::evaluateIsSingleton(TNode n) +Node BagsUtils::evaluateIsSingleton(TNode n) { Assert(n.getKind() == BAG_IS_SINGLETON); // Examples @@ -610,7 +635,7 @@ Node NormalForm::evaluateIsSingleton(TNode n) return NodeManager::currentNM()->mkConst(false); } -Node NormalForm::evaluateFromSet(TNode n) +Node BagsUtils::evaluateFromSet(TNode n) { Assert(n.getKind() == BAG_FROM_SET); @@ -635,7 +660,7 @@ Node NormalForm::evaluateFromSet(TNode n) return bag; } -Node NormalForm::evaluateToSet(TNode n) +Node BagsUtils::evaluateToSet(TNode n) { Assert(n.getKind() == BAG_TO_SET); @@ -659,7 +684,7 @@ Node NormalForm::evaluateToSet(TNode n) return set; } -Node NormalForm::evaluateBagMap(TNode n) +Node BagsUtils::evaluateBagMap(TNode n) { Assert(n.getKind() == BAG_MAP); @@ -672,7 +697,7 @@ Node NormalForm::evaluateBagMap(TNode n) // (bag ((lambda ((x String)) "z") "b") 3)) = // (bag "z" 5) - std::map elements = NormalForm::getBagElements(n[1]); + std::map elements = BagsUtils::getBagElements(n[1]); std::map mappedElements; std::map::iterator it = elements.begin(); NodeManager* nm = NodeManager::currentNM(); @@ -683,11 +708,42 @@ Node NormalForm::evaluateBagMap(TNode n) ++it; } TypeNode t = nm->mkBagType(n[0].getType().getRangeType()); - Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements); + Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements); + return ret; +} + +Node BagsUtils::evaluateBagFilter(TNode n) +{ + Assert(n.getKind() == BAG_FILTER); + + // - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T)) + // - (bag.filter p (bag.union_disjoint (bag "a" 3) (bag "b" 2))) = + // (bag.union_disjoint + // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T))) + // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T))) + + Node P = n[0]; + Node A = n[1]; + TypeNode bagType = A.getType(); + NodeManager* nm = NodeManager::currentNM(); + Node empty = nm->mkConst(EmptyBag(bagType)); + + std::map elements = getBagElements(n[1]); + std::vector bags; + + for (const auto& [e, count] : elements) + { + Node multiplicity = nm->mkConst(CONST_RATIONAL, count); + Node bag = nm->mkBag(bagType.getBagElementType(), e, multiplicity); + Node pOfe = nm->mkNode(APPLY_UF, P, e); + Node ite = nm->mkNode(ITE, pOfe, bag, empty); + bags.push_back(ite); + } + Node ret = computeDisjointUnion(bagType, bags); return ret; } -Node NormalForm::evaluateBagFold(TNode n) +Node BagsUtils::evaluateBagFold(TNode n) { Assert(n.getKind() == BAG_FOLD); @@ -703,7 +759,7 @@ Node NormalForm::evaluateBagFold(TNode n) Node f = n[0]; // combining function Node ret = n[1]; // initial value Node A = n[2]; // bag - std::map elements = NormalForm::getBagElements(A); + std::map elements = BagsUtils::getBagElements(A); std::map::iterator it = elements.begin(); NodeManager* nm = NodeManager::currentNM(); diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/bags_utils.h similarity index 94% rename from src/theory/bags/normal_form.h rename to src/theory/bags/bags_utils.h index 5275678ff..61473a023 100644 --- a/src/theory/bags/normal_form.h +++ b/src/theory/bags/bags_utils.h @@ -10,7 +10,7 @@ * directory for licensing information. * **************************************************************************** * - * Normal form for bag constants. + * Utility functions for bags. */ #include @@ -24,9 +24,16 @@ namespace cvc5 { namespace theory { namespace bags { -class NormalForm +class BagsUtils { public: + /** + * @param bagType type of bags + * @param bags a vector of bag nodes + * @return disjoint union of these bags + */ + static Node computeDisjointUnion(TypeNode bagType, + const std::vector& bags); /** * Returns true if n is considered a to be a (canonical) constant bag value. * A canonical bag value is one whose AST is: @@ -81,6 +88,12 @@ class NormalForm */ static Node evaluateBagFold(TNode n); + /** + * @param n has the form (bag.filter p A) where A is a constant bag + * @return A filtered with predicate p + */ + static Node evaluateBagFilter(TNode n); + private: /** * a high order helper function that return a constant bag that is the result diff --git a/src/theory/bags/card_solver.cpp b/src/theory/bags/card_solver.cpp index 4ec009c7d..2a35fb2bd 100644 --- a/src/theory/bags/card_solver.cpp +++ b/src/theory/bags/card_solver.cpp @@ -17,9 +17,9 @@ #include "expr/emptybag.h" #include "smt/logic_exception.h" +#include "theory/bags/bags_utils.h" #include "theory/bags/inference_generator.h" #include "theory/bags/inference_manager.h" -#include "theory/bags/normal_form.h" #include "theory/bags/solver_state.h" #include "theory/bags/term_registry.h" #include "theory/uf/equality_engine_iterator.h" diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp index 92aa1a0ea..3247548c5 100644 --- a/src/theory/bags/inference_generator.cpp +++ b/src/theory/bags/inference_generator.cpp @@ -517,6 +517,52 @@ InferInfo InferenceGenerator::mapUpwards( return inferInfo; } +InferInfo InferenceGenerator::filterDownwards(Node n, Node e) +{ + Assert(n.getKind() == BAG_FILTER && n[1].getType().isBag()); + Assert(e.getType().isSubtypeOf(n[1].getType().getBagElementType())); + + Node P = n[0]; + Node A = n[1]; + InferInfo inferInfo(d_im, InferenceId::BAGS_FILTER_DOWN); + + Node countA = getMultiplicityTerm(e, A); + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); + + Node member = d_nm->mkNode(GEQ, count, d_one); + Node pOfe = d_nm->mkNode(APPLY_UF, P, e); + Node equal = count.eqNode(countA); + + inferInfo.d_conclusion = pOfe.andNode(equal); + inferInfo.d_premises.push_back(member); + return inferInfo; +} + +InferInfo InferenceGenerator::filterUpwards(Node n, Node e) +{ + Assert(n.getKind() == BAG_FILTER && n[1].getType().isBag()); + Assert(e.getType().isSubtypeOf(n[1].getType().getBagElementType())); + + Node P = n[0]; + Node A = n[1]; + InferInfo inferInfo(d_im, InferenceId::BAGS_FILTER_UP); + + Node countA = getMultiplicityTerm(e, A); + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); + + Node member = d_nm->mkNode(GEQ, countA, d_one); + Node pOfe = d_nm->mkNode(APPLY_UF, P, e); + Node equal = count.eqNode(countA); + Node included = pOfe.andNode(equal); + Node equalZero = count.eqNode(d_zero); + Node excluded = pOfe.notNode().andNode(equalZero); + inferInfo.d_conclusion = included.orNode(excluded); + inferInfo.d_premises.push_back(member); + return inferInfo; +} + } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h index 2815058b2..3d74dbaa2 100644 --- a/src/theory/bags/inference_generator.h +++ b/src/theory/bags/inference_generator.h @@ -262,6 +262,34 @@ class InferenceGenerator */ InferInfo mapUpwards(Node n, Node uf, Node preImageSize, Node y, Node x); + /** + * @param n is (bag.filter p A) where p is a function (-> E Bool), + * A a bag of type (Bag E) + * @param e is an element of type E + * @return an inference that represents the following implication + * (=> + * (bag.member e skolem) + * (and + * (p e) + * (= (bag.count e skolem) (bag.count A))) + * where skolem is a variable equals (bag.filter p A) + */ + InferInfo filterDownwards(Node n, Node e); + + /** + * @param n is (bag.filter p A) where p is a function (-> E Bool), + * A a bag of type (Bag E) + * @param e is an element of type E + * @return an inference that represents the following implication + * (=> + * (bag.member e A) + * (or + * (and (p e) (= (bag.count e skolem) (bag.count A))) + * (and (not (p e)) (= (bag.count e skolem) 0))) + * where skolem is a variable equals (bag.filter p A) + */ + InferInfo filterUpwards(Node n, Node e); + /** * @param element of type T * @param bag of type (bag T) diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index d83be5e21..7d995dd7b 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -77,6 +77,10 @@ operator BAG_CHOOSE 1 "return an element in the bag given as a parameter # of the second argument, a bag of type (Bag T1), and returns a bag of type (Bag T2). operator BAG_MAP 2 "bag map function" +# The bag.filter operator takes a predicate of type (-> T Bool) and a bag of type (Bag T) +# and return the same bag excluding those elements that do not satisfy the predicate +operator BAG_FILTER 2 "bag filter operator" + # bag.fold operator combines elements of a bag into a single value. # (bag.fold f t B) folds the elements of bag B starting with term t and using # the combining function f. @@ -103,6 +107,7 @@ typerule BAG_IS_SINGLETON ::cvc5::theory::bags::IsSingletonTypeRule typerule BAG_FROM_SET ::cvc5::theory::bags::FromSetTypeRule typerule BAG_TO_SET ::cvc5::theory::bags::ToSetTypeRule typerule BAG_MAP ::cvc5::theory::bags::BagMapTypeRule +typerule BAG_FILTER ::cvc5::theory::bags::BagFilterTypeRule typerule BAG_FOLD ::cvc5::theory::bags::BagFoldTypeRule construle BAG_UNION_DISJOINT ::cvc5::theory::bags::BinaryOperatorTypeRule diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp index d8ed9fb95..9bd0c3a86 100644 --- a/src/theory/bags/rewrites.cpp +++ b/src/theory/bags/rewrites.cpp @@ -38,6 +38,9 @@ const char* toString(Rewrite r) case Rewrite::EQ_CONST_FALSE: return "EQ_CONST_FALSE"; case Rewrite::EQ_REFL: return "EQ_REFL"; case Rewrite::EQ_SYM: return "EQ_SYM"; + case Rewrite::FILTER_CONST: return "FILTER_CONST"; + case Rewrite::FILTER_BAG_MAKE: return "FILTER_BAG_MAKE"; + case Rewrite::FILTER_UNION_DISJOINT: return "FILTER_UNION_DISJOINT"; case Rewrite::FROM_SINGLETON: return "FROM_SINGLETON"; case Rewrite::FOLD_BAG: return "FOLD_BAG"; case Rewrite::FOLD_CONST: return "FOLD_CONST"; diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h index 57f106211..e1ef38c4b 100644 --- a/src/theory/bags/rewrites.h +++ b/src/theory/bags/rewrites.h @@ -42,6 +42,9 @@ enum class Rewrite : uint32_t EQ_CONST_FALSE, EQ_REFL, EQ_SYM, + FILTER_CONST, + FILTER_BAG_MAKE, + FILTER_UNION_DISJOINT, FROM_SINGLETON, FOLD_BAG, FOLD_CONST, diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 720e97c25..37b6415e0 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -19,7 +19,7 @@ #include "expr/skolem_manager.h" #include "proof/proof_checker.h" #include "smt/logic_exception.h" -#include "theory/bags/normal_form.h" +#include "theory/bags/bags_utils.h" #include "theory/quantifiers/fmf/bounded_integers.h" #include "theory/rewriter.h" #include "theory/theory_model.h" @@ -321,7 +321,7 @@ bool TheoryBags::collectModelValues(TheoryModel* m, Node value = m->getRepresentative(countSkolem); elementReps[key] = value; } - Node constructedBag = NormalForm::constructBagFromElements(tn, elementReps); + Node constructedBag = BagsUtils::constructBagFromElements(tn, elementReps); constructedBag = rewrite(constructedBag); Trace("bags-model") << "constructed bag for " << n << " is: " << constructedBag << std::endl; @@ -352,7 +352,8 @@ bool TheoryBags::collectModelValues(TheoryModel* m, if (constructedRational < rCardRational && !d_env.isFiniteType(elementType)) { - Node newElement = nm->getSkolemManager()->mkDummySkolem("slack", elementType); + Node newElement = + nm->getSkolemManager()->mkDummySkolem("slack", elementType); Trace("bags-model") << "newElement is " << newElement << std::endl; Rational difference = rCardRational - constructedRational; Node multiplicity = nm->mkConst(CONST_RATIONAL, difference); diff --git a/src/theory/bags/theory_bags_type_enumerator.cpp b/src/theory/bags/theory_bags_type_enumerator.cpp index 14fca3297..a24981934 100644 --- a/src/theory/bags/theory_bags_type_enumerator.cpp +++ b/src/theory/bags/theory_bags_type_enumerator.cpp @@ -16,7 +16,7 @@ #include "theory/bags/theory_bags_type_enumerator.h" #include "expr/emptybag.h" -#include "theory/bags/normal_form.h" +#include "theory/bags/bags_utils.h" #include "theory_bags_type_enumerator.h" #include "util/rational.h" @@ -67,11 +67,10 @@ BagEnumerator& BagEnumerator::operator++() else { // increase the multiplicity of one of the elements in the current bag - std::map elements = - NormalForm::getBagElements(d_currentBag); + std::map elements = BagsUtils::getBagElements(d_currentBag); Node element = elements.begin()->first; elements[element] = elements[element] + Rational(1); - d_currentBag = NormalForm::constructConstantBagFromElements( + d_currentBag = BagsUtils::constructConstantBagFromElements( d_currentBag.getType(), elements); } diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp index 2d218f821..689b0e208 100644 --- a/src/theory/bags/theory_bags_type_rules.cpp +++ b/src/theory/bags/theory_bags_type_rules.cpp @@ -20,7 +20,7 @@ #include "base/check.h" #include "expr/emptybag.h" #include "theory/bags/bag_make_op.h" -#include "theory/bags/normal_form.h" +#include "theory/bags/bags_utils.h" #include "util/cardinality.h" #include "util/rational.h" @@ -63,7 +63,7 @@ bool BinaryOperatorTypeRule::computeIsConst(NodeManager* nodeManager, TNode n) // only UNION_DISJOINT has a const rule in kinds. // Other binary operators do not have const rules in kinds Assert(n.getKind() == kind::BAG_UNION_DISJOINT); - return NormalForm::isConstant(n); + return BagsUtils::isConstant(n); } TypeNode SubBagTypeRule::computeType(NodeManager* nodeManager, @@ -356,6 +356,48 @@ TypeNode BagMapTypeRule::computeType(NodeManager* nodeManager, return retType; } +TypeNode BagFilterTypeRule::computeType(NodeManager* nodeManager, + TNode n, + bool check) +{ + Assert(n.getKind() == kind::BAG_FILTER); + TypeNode functionType = n[0].getType(check); + TypeNode bagType = n[1].getType(check); + if (check) + { + if (!bagType.isBag()) + { + throw TypeCheckingExceptionPrivate( + n, + "bag.filter operator expects a bag in the second argument, " + "a non-bag is found"); + } + + TypeNode elementType = bagType.getBagElementType(); + + if (!(functionType.isFunction())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " Bool) as a first argument. " + << "Found a term of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + std::vector argTypes = functionType.getArgTypes(); + NodeManager* nm = NodeManager::currentNM(); + if (!(argTypes.size() == 1 && argTypes[0] == elementType + && functionType.getRangeType() == nm->booleanType())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " Bool). " + << "Found a function of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + return bagType; +} + TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager, TNode n, bool check) diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index da9ea75bf..76c179a62 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -141,6 +141,15 @@ struct BagMapTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct BagMapTypeRule */ +/** + * Type rule for (bag.filter p B) to make sure p is a unary predicate of type + * (-> T Bool) where B is a bag of type (Bag T) + */ +struct BagFilterTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct BagFilterTypeRule */ + /** * Type rule for (bag.fold f t A) to make sure f is a binary operation of type * (-> T1 T2 T2), t of type T2, and B is a bag of type (Bag T1) diff --git a/src/theory/inference_id.cpp b/src/theory/inference_id.cpp index d7e5fccbe..240d6e293 100644 --- a/src/theory/inference_id.cpp +++ b/src/theory/inference_id.cpp @@ -120,6 +120,8 @@ const char* toString(InferenceId i) case InferenceId::BAGS_DIFFERENCE_REMOVE: return "BAGS_DIFFERENCE_REMOVE"; case InferenceId::BAGS_DUPLICATE_REMOVAL: return "BAGS_DUPLICATE_REMOVAL"; case InferenceId::BAGS_MAP: return "BAGS_MAP"; + case InferenceId::BAGS_FILTER_DOWN: return "BAGS_FILTER_DOWN"; + case InferenceId::BAGS_FILTER_UP: return "BAGS_FILTER_UP"; case InferenceId::BAGS_FOLD: return "BAGS_FOLD"; case InferenceId::BAGS_CARD: return "BAGS_CARD"; diff --git a/src/theory/inference_id.h b/src/theory/inference_id.h index 4970c1cee..2fb3ae003 100644 --- a/src/theory/inference_id.h +++ b/src/theory/inference_id.h @@ -182,6 +182,8 @@ enum class InferenceId BAGS_DIFFERENCE_REMOVE, BAGS_DUPLICATE_REMOVAL, BAGS_MAP, + BAGS_FILTER_DOWN, + BAGS_FILTER_UP, BAGS_FOLD, BAGS_CARD, // ---------------------------------- end bags theory diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index c86be3e76..ec3b13caa 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -1653,6 +1653,11 @@ set(regress_1_tests regress1/bags/duplicate_removal1.smt2 regress1/bags/duplicate_removal2.smt2 regress1/bags/emptybag1.smt2 + regress1/bags/filter1.smt2 + regress1/bags/filter2.smt2 + regress1/bags/filter3.smt2 + regress1/bags/filter4.smt2 + regress1/bags/filter5.smt2 regress1/bags/fol_0000119.smt2 regress1/bags/fold1.smt2 regress1/bags/fuzzy1.smt2 diff --git a/test/regress/regress1/bags/filter1.smt2 b/test/regress/regress1/bags/filter1.smt2 new file mode 100644 index 000000000..65e87c10f --- /dev/null +++ b/test/regress/regress1/bags/filter1.smt2 @@ -0,0 +1,11 @@ +(set-logic ALL) +(set-info :status sat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun x () Int) +(declare-fun y () Int) +(declare-fun p (Int) Bool) +(assert (= A (bag.union_max (bag x 1) (bag y 2)))) +(assert (= B (bag.filter p A))) +(assert (distinct (p x) (p y))) +(check-sat) diff --git a/test/regress/regress1/bags/filter2.smt2 b/test/regress/regress1/bags/filter2.smt2 new file mode 100644 index 000000000..62b6403fd --- /dev/null +++ b/test/regress/regress1/bags/filter2.smt2 @@ -0,0 +1,9 @@ +(set-logic HO_ALL) +(set-info :status sat) +(set-option :fmf-bound true) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun p (Int) Bool) +(assert (= B (bag.filter p A))) +(assert (= (bag.count (- 2) B) 57)) +(check-sat) diff --git a/test/regress/regress1/bags/filter3.smt2 b/test/regress/regress1/bags/filter3.smt2 new file mode 100644 index 000000000..10f637015 --- /dev/null +++ b/test/regress/regress1/bags/filter3.smt2 @@ -0,0 +1,10 @@ +(set-logic HO_ALL) +(set-info :status unsat) +(set-option :fmf-bound true) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(define-fun p ((x Int)) Bool (> x 1)) +(assert (= B (bag.filter p A))) +(assert (= (bag.count 3 B) 57)) +(assert (= (bag.count 3 B) 58)) +(check-sat) diff --git a/test/regress/regress1/bags/filter4.smt2 b/test/regress/regress1/bags/filter4.smt2 new file mode 100644 index 000000000..9be695210 --- /dev/null +++ b/test/regress/regress1/bags/filter4.smt2 @@ -0,0 +1,11 @@ +(set-logic HO_ALL) +(set-info :status unsat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun element () Int) +(declare-fun p (Int) Bool) +(assert (= B (bag.filter p A))) +(assert (p element)) +(assert (not (bag.member element B))) +(assert (bag.member element A)) +(check-sat) diff --git a/test/regress/regress1/bags/filter5.smt2 b/test/regress/regress1/bags/filter5.smt2 new file mode 100644 index 000000000..74ca05429 --- /dev/null +++ b/test/regress/regress1/bags/filter5.smt2 @@ -0,0 +1,11 @@ +(set-logic HO_ALL) +(set-info :status unsat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun element () Int) +(declare-fun p (Int) Bool) +(assert (= B (bag.filter p A))) +(assert (p element)) +(assert (not (bag.member element A))) +(assert (bag.member element B)) +(check-sat) diff --git a/test/regress/regress1/bags/map1.smt2 b/test/regress/regress1/bags/map1.smt2 index c7dc3d636..748d327dd 100644 --- a/test/regress/regress1/bags/map1.smt2 +++ b/test/regress/regress1/bags/map1.smt2 @@ -6,7 +6,6 @@ (declare-fun y () Int) (declare-fun f (Int) Int) (assert (= A (bag.union_max (bag x 1) (bag y 2)))) -(assert (= A (bag.union_max (bag x 1) (bag y 2)))) (assert (= B (bag.map f A))) (assert (distinct (f x) (f y) x y)) (check-sat) diff --git a/test/unit/theory/theory_bags_normal_form_white.cpp b/test/unit/theory/theory_bags_normal_form_white.cpp index 5f3abfcee..4c8c41f0b 100644 --- a/test/unit/theory/theory_bags_normal_form_white.cpp +++ b/test/unit/theory/theory_bags_normal_form_white.cpp @@ -18,7 +18,7 @@ #include "expr/emptyset.h" #include "test_smt.h" #include "theory/bags/bags_rewriter.h" -#include "theory/bags/normal_form.h" +#include "theory/bags/bags_utils.h" #include "theory/strings/type_enumerator.h" #include "util/rational.h" #include "util/string.h" @@ -65,7 +65,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, empty_bag_normal_form) Node emptybag = d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType())); // empty bags are in normal form ASSERT_TRUE(emptybag.isConst()); - Node n = NormalForm::evaluate(emptybag); + Node n = BagsUtils::evaluate(emptybag); ASSERT_EQ(emptybag, n); } @@ -89,9 +89,9 @@ TEST_F(TestTheoryWhiteBagsNormalForm, mkBag_constant_element) ASSERT_FALSE(negative.isConst()); ASSERT_FALSE(zero.isConst()); - ASSERT_EQ(emptybag, NormalForm::evaluate(negative)); - ASSERT_EQ(emptybag, NormalForm::evaluate(zero)); - ASSERT_EQ(positive, NormalForm::evaluate(positive)); + ASSERT_EQ(emptybag, BagsUtils::evaluate(negative)); + ASSERT_EQ(emptybag, BagsUtils::evaluate(zero)); + ASSERT_EQ(positive, BagsUtils::evaluate(positive)); } TEST_F(TestTheoryWhiteBagsNormalForm, bag_count) @@ -126,25 +126,25 @@ TEST_F(TestTheoryWhiteBagsNormalForm, bag_count) Node input1 = d_nodeManager->mkNode(BAG_COUNT, x, empty); Node output1 = zero; - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node input2 = d_nodeManager->mkNode(BAG_COUNT, x, y_5); Node output2 = zero; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node input3 = d_nodeManager->mkNode(BAG_COUNT, x, x_4); Node output3 = four; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node unionDisjointXY = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5); Node input4 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointXY); Node output4 = four; - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); Node unionDisjointYZ = d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_5, z_5); Node input5 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointYZ); Node output5 = zero; - ASSERT_EQ(output4, NormalForm::evaluate(input4)); + ASSERT_EQ(output4, BagsUtils::evaluate(input4)); } TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal) @@ -161,7 +161,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal) EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType()))); Node input1 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, emptybag); Node output1 = emptybag; - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node x = d_nodeManager->mkConst(String("x")); Node y = d_nodeManager->mkConst(String("y")); @@ -186,12 +186,12 @@ TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal) Node input2 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, x_4); Node output2 = x_1; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5); Node input3 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, normalBag); Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1); - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, union_max) @@ -241,7 +241,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_max) d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2)); ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, NormalForm::evaluate(input)); + ASSERT_EQ(output, BagsUtils::evaluate(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) @@ -265,12 +265,12 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) Node unionDisjointAB = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, B); // unionDisjointAB is already in a normal form ASSERT_TRUE(unionDisjointAB.isConst()); - ASSERT_EQ(unionDisjointAB, NormalForm::evaluate(unionDisjointAB)); + ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointAB)); Node unionDisjointBA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, B, A); // unionDisjointAB is the normal form of unionDisjointBA ASSERT_FALSE(unionDisjointBA.isConst()); - ASSERT_EQ(unionDisjointAB, NormalForm::evaluate(unionDisjointBA)); + ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointBA)); Node unionDisjointAB_C = d_nodeManager->mkNode(BAG_UNION_DISJOINT, unionDisjointAB, C); @@ -280,7 +280,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) // unionDisjointA_BC is the normal form of unionDisjointAB_C ASSERT_FALSE(unionDisjointAB_C.isConst()); ASSERT_TRUE(unionDisjointA_BC.isConst()); - ASSERT_EQ(unionDisjointA_BC, NormalForm::evaluate(unionDisjointAB_C)); + ASSERT_EQ(unionDisjointA_BC, BagsUtils::evaluate(unionDisjointAB_C)); Node unionDisjointAA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, A); Node AA = @@ -289,7 +289,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) d_nodeManager->mkConst(CONST_RATIONAL, Rational(4))); ASSERT_FALSE(unionDisjointAA.isConst()); ASSERT_TRUE(AA.isConst()); - ASSERT_EQ(AA, NormalForm::evaluate(unionDisjointAA)); + ASSERT_EQ(AA, BagsUtils::evaluate(unionDisjointAA)); } TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2) @@ -339,7 +339,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2) d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2)); ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, NormalForm::evaluate(input)); + ASSERT_EQ(output, BagsUtils::evaluate(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min) @@ -384,7 +384,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min) Node output = x_3; ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, NormalForm::evaluate(input)); + ASSERT_EQ(output, BagsUtils::evaluate(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract) @@ -433,7 +433,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract) Node output = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, z_2); ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, NormalForm::evaluate(input)); + ASSERT_EQ(output, BagsUtils::evaluate(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove) @@ -482,7 +482,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove) Node output = z_2; ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, NormalForm::evaluate(input)); + ASSERT_EQ(output, BagsUtils::evaluate(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, bag_card) @@ -509,16 +509,16 @@ TEST_F(TestTheoryWhiteBagsNormalForm, bag_card) Node input1 = d_nodeManager->mkNode(BAG_CARD, empty); Node output1 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(0)); - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node input2 = d_nodeManager->mkNode(BAG_CARD, x_4); Node output2 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(4)); - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_1); Node input3 = d_nodeManager->mkNode(BAG_CARD, union_disjoint); Node output3 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(5)); - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton) @@ -552,20 +552,20 @@ TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton) Node input1 = d_nodeManager->mkNode(BAG_IS_SINGLETON, empty); Node output1 = falseNode; - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node input2 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_1); Node output2 = trueNode; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node input3 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_4); Node output3 = falseNode; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1); Node input4 = d_nodeManager->mkNode(BAG_IS_SINGLETON, union_disjoint); Node output4 = falseNode; - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, from_set) @@ -583,7 +583,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, from_set) EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType()))); Node input1 = d_nodeManager->mkNode(BAG_FROM_SET, emptyset); Node output1 = emptybag; - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node x = d_nodeManager->mkConst(String("x")); Node y = d_nodeManager->mkConst(String("y")); @@ -602,13 +602,13 @@ TEST_F(TestTheoryWhiteBagsNormalForm, from_set) Node input2 = d_nodeManager->mkNode(BAG_FROM_SET, xSingleton); Node output2 = x_1; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); // for normal sets, the first node is the largest, not smallest Node normalSet = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton); Node input3 = d_nodeManager->mkNode(BAG_FROM_SET, normalSet); Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1); - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, to_set) @@ -626,7 +626,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, to_set) EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType()))); Node input1 = d_nodeManager->mkNode(BAG_TO_SET, emptybag); Node output1 = emptyset; - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node x = d_nodeManager->mkConst(String("x")); Node y = d_nodeManager->mkConst(String("y")); @@ -645,13 +645,13 @@ TEST_F(TestTheoryWhiteBagsNormalForm, to_set) Node input2 = d_nodeManager->mkNode(BAG_TO_SET, x_4); Node output2 = xSingleton; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); // for normal sets, the first node is the largest, not smallest Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5); Node input3 = d_nodeManager->mkNode(BAG_TO_SET, normalBag); Node output3 = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton); - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); } } // namespace test } // namespace cvc5 -- 2.30.2