From: mudathirmahgoub Date: Tue, 1 Feb 2022 14:58:04 +0000 (-0600) Subject: Add bag.filter operator (#8006) X-Git-Tag: cvc5-1.0.0~488 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=2d64f408f416c601b3b545984ca1b6c31c151f16;p=cvc5.git Add bag.filter operator (#8006) --- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a1ad056f1..1715257f7 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -549,6 +549,8 @@ libcvc5_add_sources( theory/bags/bag_reduction.h theory/bags/bags_statistics.cpp theory/bags/bags_statistics.h + theory/bags/bags_utils.cpp + theory/bags/bags_utils.h theory/bags/card_solver.cpp theory/bags/card_solver.h theory/bags/infer_info.cpp @@ -557,8 +559,6 @@ libcvc5_add_sources( theory/bags/inference_generator.h theory/bags/inference_manager.cpp theory/bags/inference_manager.h - theory/bags/normal_form.cpp - theory/bags/normal_form.h theory/bags/rewrites.cpp theory/bags/rewrites.h theory/bags/solver_state.cpp diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index 458dd359a..df9a8b8ae 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -312,6 +312,7 @@ const static std::unordered_map s_kinds{ {BAG_FROM_SET, cvc5::Kind::BAG_FROM_SET}, {BAG_TO_SET, cvc5::Kind::BAG_TO_SET}, {BAG_MAP, cvc5::Kind::BAG_MAP}, + {BAG_FILTER, cvc5::Kind::BAG_FILTER}, {BAG_FOLD, cvc5::Kind::BAG_FOLD}, /* Strings ------------------------------------------------------------- */ {STRING_CONCAT, cvc5::Kind::STRING_CONCAT}, @@ -624,6 +625,7 @@ const static std::unordered_map {cvc5::Kind::BAG_FROM_SET, BAG_FROM_SET}, {cvc5::Kind::BAG_TO_SET, BAG_TO_SET}, {cvc5::Kind::BAG_MAP, BAG_MAP}, + {cvc5::Kind::BAG_FILTER, BAG_FILTER}, {cvc5::Kind::BAG_FOLD, BAG_FOLD}, /* Strings --------------------------------------------------------- */ {cvc5::Kind::STRING_CONCAT, STRING_CONCAT}, diff --git a/src/api/cpp/cvc5_kind.h b/src/api/cpp/cvc5_kind.h index 3bd896814..dba4df07f 100644 --- a/src/api/cpp/cvc5_kind.h +++ b/src/api/cpp/cvc5_kind.h @@ -2539,6 +2539,23 @@ enum Kind : int32_t * - `Solver::mkTerm(Kind kind, const std::vector& children) const` */ BAG_MAP, + /** + * bag.filter operator filters the elements of a bag. + * (bag.filter p B) takes a predicate p of type (-> T Bool) as a first + * argument, and a bag B of type (Bag T) as a second argument, and returns a + * subbag of type (Bag T) that includes all elements of B that satisfy p + * with the same multiplicity. + * + * Parameters: + * - 1: a function of type (-> T Bool) + * - 2: a bag of type (Bag T) + * + * Create with: + * - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2) + * const` + * - `Solver::mkTerm(Kind kind, const std::vector& children) const` + */ + BAG_FILTER, /** * bag.fold operator combines elements of a bag into a single value. * (bag.fold f t B) folds the elements of bag B starting with term t and using diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 8eed51baa..a93596633 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -625,6 +625,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand) addOperator(api::BAG_FROM_SET, "bag.from_set"); addOperator(api::BAG_TO_SET, "bag.to_set"); addOperator(api::BAG_MAP, "bag.map"); + addOperator(api::BAG_FILTER, "bag.filter"); addOperator(api::BAG_FOLD, "bag.fold"); } if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) { diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 067cf27fe..cf85e6b0e 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1123,6 +1123,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::BAG_FROM_SET: return "bag.from_set"; case kind::BAG_TO_SET: return "bag.to_set"; case kind::BAG_MAP: return "bag.map"; + case kind::BAG_FILTER: return "bag.filter"; case kind::BAG_FOLD: return "bag.fold"; // fp theory diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp index 55367bb89..ed4b501f3 100644 --- a/src/theory/bags/bag_solver.cpp +++ b/src/theory/bags/bag_solver.cpp @@ -16,9 +16,9 @@ #include "theory/bags/bag_solver.h" #include "expr/emptybag.h" +#include "theory/bags/bags_utils.h" #include "theory/bags/inference_generator.h" #include "theory/bags/inference_manager.h" -#include "theory/bags/normal_form.h" #include "theory/bags/solver_state.h" #include "theory/bags/term_registry.h" #include "theory/uf/equality_engine_iterator.h" @@ -76,6 +76,7 @@ void BagSolver::checkBasicOperations() case kind::BAG_DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break; case kind::BAG_DIFFERENCE_REMOVE: checkDifferenceRemove(n); break; case kind::BAG_DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break; + case kind::BAG_FILTER: checkFilter(n); break; case kind::BAG_MAP: checkMap(n); break; default: break; } @@ -280,6 +281,28 @@ void BagSolver::checkMap(Node n) } } +void BagSolver::checkFilter(Node n) +{ + Assert(n.getKind() == BAG_FILTER); + + set elements; + const set& downwards = d_state.getElements(n); + const set& upwards = d_state.getElements(n[0]); + elements.insert(downwards.begin(), downwards.end()); + elements.insert(upwards.begin(), upwards.end()); + + for (const Node& e : elements) + { + InferInfo i = d_ig.filterDownwards(n, d_state.getRepresentative(e)); + d_im.lemmaTheoryInference(&i); + } + for (const Node& e : elements) + { + InferInfo i = d_ig.filterUpwards(n, d_state.getRepresentative(e)); + d_im.lemmaTheoryInference(&i); + } +} + } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h index 499b7998d..fca72b22e 100644 --- a/src/theory/bags/bag_solver.h +++ b/src/theory/bags/bag_solver.h @@ -96,6 +96,8 @@ class BagSolver : protected EnvObj void checkDisequalBagTerms(); /** apply inference rules for map operator */ void checkMap(Node n); + /** apply inference rules for filter operator */ + void checkFilter(Node n); /** The solver state object */ SolverState& d_state; diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index 40f8d6c95..24f313ad6 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -16,7 +16,7 @@ #include "theory/bags/bags_rewriter.h" #include "expr/emptybag.h" -#include "theory/bags/normal_form.h" +#include "theory/bags/bags_utils.h" #include "util/rational.h" #include "util/statistics_registry.h" @@ -65,9 +65,9 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) { response = rewriteChoose(n); } - else if (NormalForm::areChildrenConstants(n)) + else if (BagsUtils::areChildrenConstants(n)) { - Node value = NormalForm::evaluate(n); + Node value = BagsUtils::evaluate(n); response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION); } else @@ -90,6 +90,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) case BAG_FROM_SET: response = rewriteFromSet(n); break; case BAG_TO_SET: response = rewriteToSet(n); break; case BAG_MAP: response = postRewriteMap(n); break; + case BAG_FILTER: response = postRewriteFilter(n); break; case BAG_FOLD: response = postRewriteFold(n); break; default: response = BagsRewriteResponse(n, Rewrite::NONE); break; } @@ -533,7 +534,7 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const { // (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2)) // (bag.map f (bag "a" 3)) = (bag (f "a") 3) - std::map elements = NormalForm::getBagElements(n[1]); + std::map elements = BagsUtils::getBagElements(n[1]); std::map mappedElements; std::map::iterator it = elements.begin(); while (it != elements.end()) @@ -543,7 +544,7 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const ++it; } TypeNode t = d_nm->mkBagType(n[0].getType().getRangeType()); - Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements); + Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements); return BagsRewriteResponse(ret, Rewrite::MAP_CONST); } Kind k = n[1].getKind(); @@ -572,6 +573,49 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const } } +BagsRewriteResponse BagsRewriter::postRewriteFilter(const TNode& n) const +{ + Assert(n.getKind() == kind::BAG_FILTER); + Node P = n[0]; + Node A = n[1]; + TypeNode t = A.getType(); + if (A.isConst()) + { + // (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T)) + // (bag.filter p (bag "a" 3) ((bag "b" 2))) = + // (bag.union_disjoint + // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T))) + // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T))) + + Node ret = BagsUtils::evaluateBagFilter(n); + return BagsRewriteResponse(ret, Rewrite::FILTER_CONST); + } + Kind k = A.getKind(); + switch (k) + { + case BAG_MAKE: + { + // (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T))) + Node empty = d_nm->mkConst(EmptyBag(t)); + Node pOfe = d_nm->mkNode(APPLY_UF, P, A[0]); + Node ret = d_nm->mkNode(ITE, pOfe, A, empty); + return BagsRewriteResponse(ret, Rewrite::FILTER_BAG_MAKE); + } + + case BAG_UNION_DISJOINT: + { + // (bag.filter p (bag.union_disjoint A B)) = + // (bag.union_disjoint (bag.filter p A) (bag.filter p B)) + Node a = d_nm->mkNode(BAG_FILTER, n[0], n[1][0]); + Node b = d_nm->mkNode(BAG_FILTER, n[0], n[1][1]); + Node ret = d_nm->mkNode(BAG_UNION_DISJOINT, a, b); + return BagsRewriteResponse(ret, Rewrite::FILTER_UNION_DISJOINT); + } + + default: return BagsRewriteResponse(n, Rewrite::NONE); + } +} + BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const { Assert(n.getKind() == kind::BAG_FOLD); @@ -580,7 +624,7 @@ BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const Node bag = n[2]; if (bag.isConst()) { - Node value = NormalForm::evaluateBagFold(n); + Node value = BagsUtils::evaluateBagFold(n); return BagsRewriteResponse(value, Rewrite::FOLD_CONST); } Kind k = bag.getKind(); @@ -591,7 +635,7 @@ BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const if (bag[1].isConst() && bag[1].getConst() > Rational(0)) { // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, n > 0 - Node value = NormalForm::evaluateBagFold(n); + Node value = BagsUtils::evaluateBagFold(n); return BagsRewriteResponse(value, Rewrite::FOLD_BAG); } break; diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index b4b1e9043..3e5b69a1c 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -228,6 +228,16 @@ class BagsRewriter : public TheoryRewriter */ BagsRewriteResponse postRewriteMap(const TNode& n) const; + /** + * rewrites for n include: + * - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T)) + * - (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T))) + * - (bag.filter p (bag.union_disjoint A B)) = + * (bag.union_disjoint (bag.filter p A) (bag.filter p B)) + * where p: T -> Bool + */ + BagsRewriteResponse postRewriteFilter(const TNode& n) const; + /** * rewrites for n include: * - (bag.fold f t (as bag.empty (Bag T1))) = t diff --git a/src/theory/bags/bags_utils.cpp b/src/theory/bags/bags_utils.cpp new file mode 100644 index 000000000..39987ce9d --- /dev/null +++ b/src/theory/bags/bags_utils.cpp @@ -0,0 +1,783 @@ +/****************************************************************************** + * Top contributors (to current version): + * Mudathir Mohamed, Aina Niemetz + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Utility functions for bags. + */ +#include "bags_utils.h" + +#include "expr/emptybag.h" +#include "smt/logic_exception.h" +#include "theory/sets/normal_form.h" +#include "theory/type_enumerator.h" +#include "util/rational.h" + +using namespace cvc5::kind; + +namespace cvc5 { +namespace theory { +namespace bags { + +Node BagsUtils::computeDisjointUnion(TypeNode bagType, + const std::vector& bags) +{ + NodeManager* nm = NodeManager::currentNM(); + if (bags.empty()) + { + return nm->mkConst(EmptyBag(bagType)); + } + if (bags.size() == 1) + { + return bags[0]; + } + Node unionDisjoint = bags[0]; + for (size_t i = 1; i < bags.size(); i++) + { + if (bags[i].getKind() == BAG_EMPTY) + { + continue; + } + unionDisjoint = nm->mkNode(BAG_UNION_DISJOINT, unionDisjoint, bags[i]); + } + return unionDisjoint; +} + +bool BagsUtils::isConstant(TNode n) +{ + if (n.getKind() == BAG_EMPTY) + { + // empty bags are already normalized + return true; + } + if (n.getKind() == BAG_MAKE) + { + // see the implementation in MkBagTypeRule::computeIsConst + return n.isConst(); + } + if (n.getKind() == BAG_UNION_DISJOINT) + { + if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst())) + { + // the first child is not a constant + return false; + } + // store the previous element to check the ordering of elements + Node previousElement = n[0][0]; + Node current = n[1]; + while (current.getKind() == BAG_UNION_DISJOINT) + { + if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst())) + { + // the current element is not a constant + return false; + } + if (previousElement >= current[0][0]) + { + // the ordering is violated + return false; + } + previousElement = current[0][0]; + current = current[1]; + } + // check last element + if (!(current.getKind() == kind::BAG_MAKE && current.isConst())) + { + // the last element is not a constant + return false; + } + if (previousElement >= current[0]) + { + // the ordering is violated + return false; + } + return true; + } + + // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be + // constants + return false; +} + +bool BagsUtils::areChildrenConstants(TNode n) +{ + return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); }); +} + +Node BagsUtils::evaluate(TNode n) +{ + Assert(areChildrenConstants(n)); + if (n.isConst()) + { + // a constant node is already in a normal form + return n; + } + switch (n.getKind()) + { + case BAG_MAKE: return evaluateMakeBag(n); + case BAG_COUNT: return evaluateBagCount(n); + case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n); + case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n); + case BAG_UNION_MAX: return evaluateUnionMax(n); + case BAG_INTER_MIN: return evaluateIntersectionMin(n); + case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n); + case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n); + case BAG_CARD: return evaluateCard(n); + case BAG_IS_SINGLETON: return evaluateIsSingleton(n); + case BAG_FROM_SET: return evaluateFromSet(n); + case BAG_TO_SET: return evaluateToSet(n); + case BAG_MAP: return evaluateBagMap(n); + case BAG_FILTER: return evaluateBagFilter(n); + case BAG_FOLD: return evaluateBagFold(n); + default: break; + } + Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n + << std::endl; +} + +template +Node BagsUtils::evaluateBinaryOperation(const TNode& n, + T1&& equal, + T2&& less, + T3&& greaterOrEqual, + T4&& remainderOfA, + T5&& remainderOfB) +{ + std::map elementsA = getBagElements(n[0]); + std::map elementsB = getBagElements(n[1]); + std::map elements; + + std::map::const_iterator itA = elementsA.begin(); + std::map::const_iterator itB = elementsB.begin(); + + Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation " + << n.getKind() << "] " << std::endl + << "elements A: " << elementsA << std::endl + << "elements B: " << elementsB << std::endl; + + while (itA != elementsA.end() && itB != elementsB.end()) + { + if (itA->first == itB->first) + { + equal(elements, itA, itB); + itA++; + itB++; + } + else if (itA->first < itB->first) + { + less(elements, itA, itB); + itA++; + } + else + { + greaterOrEqual(elements, itA, itB); + itB++; + } + } + + // handle the remaining elements from A + remainderOfA(elements, elementsA, itA); + // handle the remaining elements from B + remainderOfB(elements, elementsB, itB); + + Trace("bags-evaluate") << "elements: " << elements << std::endl; + Node bag = constructConstantBagFromElements(n.getType(), elements); + Trace("bags-evaluate") << "bag: " << bag << std::endl; + return bag; +} + +std::map BagsUtils::getBagElements(TNode n) +{ + std::map elements; + if (n.getKind() == BAG_EMPTY) + { + return elements; + } + while (n.getKind() == kind::BAG_UNION_DISJOINT) + { + Assert(n[0].getKind() == kind::BAG_MAKE); + Node element = n[0][0]; + Rational count = n[0][1].getConst(); + elements[element] = count; + n = n[1]; + } + Assert(n.getKind() == kind::BAG_MAKE); + Node lastElement = n[0]; + Rational lastCount = n[1].getConst(); + elements[lastElement] = lastCount; + return elements; +} + +Node BagsUtils::constructConstantBagFromElements( + TypeNode t, const std::map& elements) +{ + Assert(t.isBag()); + NodeManager* nm = NodeManager::currentNM(); + if (elements.empty()) + { + return nm->mkConst(EmptyBag(t)); + } + TypeNode elementType = t.getBagElementType(); + std::map::const_reverse_iterator it = elements.rbegin(); + Node bag = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second)); + while (++it != elements.rend()) + { + Node n = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second)); + bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag); + } + return bag; +} + +Node BagsUtils::constructBagFromElements(TypeNode t, + const std::map& elements) +{ + Assert(t.isBag()); + NodeManager* nm = NodeManager::currentNM(); + if (elements.empty()) + { + return nm->mkConst(EmptyBag(t)); + } + TypeNode elementType = t.getBagElementType(); + std::map::const_reverse_iterator it = elements.rbegin(); + Node bag = nm->mkBag(elementType, it->first, it->second); + while (++it != elements.rend()) + { + Node n = nm->mkBag(elementType, it->first, it->second); + bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag); + } + return bag; +} + +Node BagsUtils::evaluateMakeBag(TNode n) +{ + // the case where n is const should be handled earlier. + // here we handle the case where the multiplicity is zero or negative + Assert(n.getKind() == BAG_MAKE && !n.isConst() + && n[1].getConst().sgn() < 1); + Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType())); + return emptybag; +} + +Node BagsUtils::evaluateBagCount(TNode n) +{ + Assert(n.getKind() == BAG_COUNT); + // Examples + // -------- + // - (bag.count "x" (as bag.empty (Bag String))) = 0 + // - (bag.count "x" (bag "y" 5)) = 0 + // - (bag.count "x" (bag "x" 4)) = 4 + // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4 + // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0 + + std::map elements = getBagElements(n[1]); + std::map::iterator it = elements.find(n[0]); + + NodeManager* nm = NodeManager::currentNM(); + if (it != elements.end()) + { + Node count = nm->mkConstInt(it->second); + return count; + } + return nm->mkConstInt(Rational(0)); +} + +Node BagsUtils::evaluateDuplicateRemoval(TNode n) +{ + Assert(n.getKind() == BAG_DUPLICATE_REMOVAL); + + // Examples + // -------- + // - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag + // String)) + // - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1) + // - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) = + // (bag.disjoint_union (bag "x" 1) (bag "y" 1) + + std::map oldElements = getBagElements(n[0]); + // copy elements from the old bag + std::map newElements(oldElements); + Rational one = Rational(1); + std::map::iterator it; + for (it = newElements.begin(); it != newElements.end(); it++) + { + it->second = one; + } + Node bag = constructConstantBagFromElements(n[0].getType(), newElements); + return bag; +} + +Node BagsUtils::evaluateUnionDisjoint(TNode n) +{ + Assert(n.getKind() == BAG_UNION_DISJOINT); + // Example + // ------- + // input: (bag.union_disjoint A B) + // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) + // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) + // output: + // (bag.union_disjoint A B) + // where A = (bag "x" 7) + // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2))) + + auto equal = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // compute the sum of the multiplicities + elements[itA->first] = itA->second + itB->second; + }; + + auto less = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // add the element to the result + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // add the element to the result + elements[itB->first] = itB->second; + }; + + auto remainderOfA = [](std::map& elements, + std::map& elementsA, + std::map::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map& elements, + std::map& elementsB, + std::map::const_iterator& itB) { + // append the remainder of B + while (itB != elementsB.end()) + { + elements[itB->first] = itB->second; + itB++; + } + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node BagsUtils::evaluateUnionMax(TNode n) +{ + Assert(n.getKind() == BAG_UNION_MAX); + // Example + // ------- + // input: (bag.union_max A B) + // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) + // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) + // output: + // (bag.union_disjoint A B) + // where A = (bag "x" 4) + // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2))) + + auto equal = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // compute the maximum multiplicity + elements[itA->first] = std::max(itA->second, itB->second); + }; + + auto less = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // add to the result + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // add to the result + elements[itB->first] = itB->second; + }; + + auto remainderOfA = [](std::map& elements, + std::map& elementsA, + std::map::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map& elements, + std::map& elementsB, + std::map::const_iterator& itB) { + // append the remainder of B + while (itB != elementsB.end()) + { + elements[itB->first] = itB->second; + itB++; + } + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node BagsUtils::evaluateIntersectionMin(TNode n) +{ + Assert(n.getKind() == BAG_INTER_MIN); + // Example + // ------- + // input: (bag.inter_min A B) + // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) + // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) + // output: + // (bag "x" 3) + + auto equal = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // compute the minimum multiplicity + elements[itA->first] = std::min(itA->second, itB->second); + }; + + auto less = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // do nothing + }; + + auto greaterOrEqual = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // do nothing + }; + + auto remainderOfA = [](std::map& elements, + std::map& elementsA, + std::map::const_iterator& itA) { + // do nothing + }; + + auto remainderOfB = [](std::map& elements, + std::map& elementsB, + std::map::const_iterator& itB) { + // do nothing + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node BagsUtils::evaluateDifferenceSubtract(TNode n) +{ + Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT); + // Example + // ------- + // input: (bag.difference_subtract A B) + // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) + // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) + // output: + // (bag.union_disjoint (bag "x" 1) (bag "z" 2)) + + auto equal = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // subtract the multiplicities + elements[itA->first] = itA->second - itB->second; + }; + + auto less = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // itA->first is not in B, so we add it to the difference subtract + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // itB->first is not in A, so we just skip it + }; + + auto remainderOfA = [](std::map& elements, + std::map& elementsA, + std::map::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map& elements, + std::map& elementsB, + std::map::const_iterator& itB) { + // do nothing + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node BagsUtils::evaluateDifferenceRemove(TNode n) +{ + Assert(n.getKind() == BAG_DIFFERENCE_REMOVE); + // Example + // ------- + // input: (bag.difference_remove A B) + // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) + // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) + // output: + // (bag "z" 2) + + auto equal = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // skip the shared element by doing nothing + }; + + auto less = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // itA->first is not in B, so we add it to the difference remove + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // itB->first is not in A, so we just skip it + }; + + auto remainderOfA = [](std::map& elements, + std::map& elementsA, + std::map::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map& elements, + std::map& elementsB, + std::map::const_iterator& itB) { + // do nothing + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node BagsUtils::evaluateChoose(TNode n) +{ + Assert(n.getKind() == BAG_CHOOSE); + // Examples + // -------- + // - (bag.choose (bag "x" 4)) = "x" + + if (n[0].getKind() == BAG_MAKE) + { + return n[0][0]; + } + throw LogicException("BAG_CHOOSE_TOTAL is not supported yet"); +} + +Node BagsUtils::evaluateCard(TNode n) +{ + Assert(n.getKind() == BAG_CARD); + // Examples + // -------- + // - (card (as bag.empty (Bag String))) = 0 + // - (bag.choose (bag "x" 4)) = 4 + // - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5 + + std::map elements = getBagElements(n[0]); + Rational sum(0); + for (std::pair element : elements) + { + sum += element.second; + } + + NodeManager* nm = NodeManager::currentNM(); + Node sumNode = nm->mkConstInt(sum); + return sumNode; +} + +Node BagsUtils::evaluateIsSingleton(TNode n) +{ + Assert(n.getKind() == BAG_IS_SINGLETON); + // Examples + // -------- + // - (bag.is_singleton (as bag.empty (Bag String))) = false + // - (bag.is_singleton (bag "x" 1)) = true + // - (bag.is_singleton (bag "x" 4)) = false + // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1))) + // = false + + if (n[0].getKind() == BAG_MAKE && n[0][1].getConst().isOne()) + { + return NodeManager::currentNM()->mkConst(true); + } + return NodeManager::currentNM()->mkConst(false); +} + +Node BagsUtils::evaluateFromSet(TNode n) +{ + Assert(n.getKind() == BAG_FROM_SET); + + // Examples + // -------- + // - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String)) + // - (bag.from_set (set.singleton "x")) = (bag "x" 1) + // - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) = + // (bag.disjoint_union (bag "x" 1) (bag "y" 1)) + + NodeManager* nm = NodeManager::currentNM(); + std::set setElements = + sets::NormalForm::getElementsFromNormalConstant(n[0]); + Rational one = Rational(1); + std::map bagElements; + for (const Node& element : setElements) + { + bagElements[element] = one; + } + TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType()); + Node bag = constructConstantBagFromElements(bagType, bagElements); + return bag; +} + +Node BagsUtils::evaluateToSet(TNode n) +{ + Assert(n.getKind() == BAG_TO_SET); + + // Examples + // -------- + // - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String)) + // - (bag.to_set (bag "x" 4)) = (set.singleton "x") + // - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) = + // (set.union (set.singleton "x") (set.singleton "y"))) + + NodeManager* nm = NodeManager::currentNM(); + std::map bagElements = getBagElements(n[0]); + std::set setElements; + std::map::const_reverse_iterator it; + for (it = bagElements.rbegin(); it != bagElements.rend(); it++) + { + setElements.insert(it->first); + } + TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType()); + Node set = sets::NormalForm::elementsToSet(setElements, setType); + return set; +} + +Node BagsUtils::evaluateBagMap(TNode n) +{ + Assert(n.getKind() == BAG_MAP); + + // Examples + // -------- + // - (bag.map ((lambda ((x String)) "z") + // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) = + // (bag.union_disjoint + // (bag ((lambda ((x String)) "z") "a") 2) + // (bag ((lambda ((x String)) "z") "b") 3)) = + // (bag "z" 5) + + std::map elements = BagsUtils::getBagElements(n[1]); + std::map mappedElements; + std::map::iterator it = elements.begin(); + NodeManager* nm = NodeManager::currentNM(); + while (it != elements.end()) + { + Node mappedElement = nm->mkNode(APPLY_UF, n[0], it->first); + mappedElements[mappedElement] = it->second; + ++it; + } + TypeNode t = nm->mkBagType(n[0].getType().getRangeType()); + Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements); + return ret; +} + +Node BagsUtils::evaluateBagFilter(TNode n) +{ + Assert(n.getKind() == BAG_FILTER); + + // - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T)) + // - (bag.filter p (bag.union_disjoint (bag "a" 3) (bag "b" 2))) = + // (bag.union_disjoint + // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T))) + // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T))) + + Node P = n[0]; + Node A = n[1]; + TypeNode bagType = A.getType(); + NodeManager* nm = NodeManager::currentNM(); + Node empty = nm->mkConst(EmptyBag(bagType)); + + std::map elements = getBagElements(n[1]); + std::vector bags; + + for (const auto& [e, count] : elements) + { + Node multiplicity = nm->mkConst(CONST_RATIONAL, count); + Node bag = nm->mkBag(bagType.getBagElementType(), e, multiplicity); + Node pOfe = nm->mkNode(APPLY_UF, P, e); + Node ite = nm->mkNode(ITE, pOfe, bag, empty); + bags.push_back(ite); + } + Node ret = computeDisjointUnion(bagType, bags); + return ret; +} + +Node BagsUtils::evaluateBagFold(TNode n) +{ + Assert(n.getKind() == BAG_FOLD); + + // Examples + // -------- + // minimum string + // - (bag.fold + // ((lambda ((x String) (y String)) (ite (str.< x y) x y)) + // "" + // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) + // = "a" + + Node f = n[0]; // combining function + Node ret = n[1]; // initial value + Node A = n[2]; // bag + std::map elements = BagsUtils::getBagElements(A); + + std::map::iterator it = elements.begin(); + NodeManager* nm = NodeManager::currentNM(); + while (it != elements.end()) + { + // apply the combination function n times, where n is the multiplicity + Rational count = it->second; + Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl; + while (!count.isZero()) + { + ret = nm->mkNode(APPLY_UF, f, it->first, ret); + count = count - 1; + } + ++it; + } + return ret; +} + +} // namespace bags +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/bags/bags_utils.h b/src/theory/bags/bags_utils.h new file mode 100644 index 000000000..61473a023 --- /dev/null +++ b/src/theory/bags/bags_utils.h @@ -0,0 +1,223 @@ +/****************************************************************************** + * Top contributors (to current version): + * Mudathir Mohamed + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Utility functions for bags. + */ + +#include + +#include "cvc5_private.h" + +#ifndef CVC5__THEORY__BAGS__NORMAL_FORM_H +#define CVC5__THEORY__BAGS__NORMAL_FORM_H + +namespace cvc5 { +namespace theory { +namespace bags { + +class BagsUtils +{ + public: + /** + * @param bagType type of bags + * @param bags a vector of bag nodes + * @return disjoint union of these bags + */ + static Node computeDisjointUnion(TypeNode bagType, + const std::vector& bags); + /** + * Returns true if n is considered a to be a (canonical) constant bag value. + * A canonical bag value is one whose AST is: + * (bag.union_disjoint (bag e1 c1) ... + * (bag.union_disjoint (bag e_{n-1} c_{n-1}) (bag e_n c_n)))) + * where c1 ... cn are positive integers, e1 ... en are constants, and the + * node identifier of these constants are such that: e1 < ... < en. + * Also handles the corner cases of empty bag and bag constructed by bag + */ + static bool isConstant(TNode n); + /** + * check whether all children of the given node are constants + */ + static bool areChildrenConstants(TNode n); + /** + * evaluate the node n to a constant value. + * As a precondition, children of n should be constants. + */ + static Node evaluate(TNode n); + + /** + * get the elements along with their multiplicities in a given bag + * @param n a constant node whose type is a bag + * @return a map whose keys are constant elements and values are + * multiplicities + */ + static std::map getBagElements(TNode n); + + /** + * construct a constant bag from constant elements + * @param t the type of the returned bag + * @param elements a map whose keys are constant elements and values are + * multiplicities + * @return a constant bag that contains + */ + static Node constructConstantBagFromElements( + TypeNode t, const std::map& elements); + + /** + * construct a constant bag from node elements + * @param t the type of the returned bag + * @param elements a map whose keys are constant elements and values are + * multiplicities + * @return a constant bag that contains + */ + static Node constructBagFromElements(TypeNode t, + const std::map& elements); + + /** + * @param n has the form (bag.fold f t A) where A is a constant bag + * @return a single value which is the result of the fold + */ + static Node evaluateBagFold(TNode n); + + /** + * @param n has the form (bag.filter p A) where A is a constant bag + * @return A filtered with predicate p + */ + static Node evaluateBagFilter(TNode n); + + private: + /** + * a high order helper function that return a constant bag that is the result + * of (op A B) where op is a binary operator and A, B are constant bags. + * The result is computed from the elements of A (elementsA with iterator itA) + * and elements of B (elementsB with iterator itB). + * The arguments below specify how these iterators are used to generate the + * elements of the result (elements). + * @param n a node whose kind is a binary operator (bag.union_disjoint, + * union_max, intersection_min, difference_subtract, difference_remove) and + * whose children are constant bags. + * @param equal a lambda expression that receives (elements, itA, itB) and + * specify the action that needs to be taken when the elements of itA, itB are + * equal. + * @param less a lambda expression that receives (elements, itA, itB) and + * specify the action that needs to be taken when the element itA is less than + * the element of itB. + * @param greaterOrEqual less a lambda expression that receives (elements, + * itA, itB) and specify the action that needs to be taken when the element + * itA is greater than or equal than the element of itB. + * @param remainderOfA a lambda expression that receives (elements, elementsA, + * itA) and specify the action that needs to be taken to the remaining + * elements of A when all elements of B are visited. + * @param remainderOfB a lambda expression that receives (elements, elementsB, + * itB) and specify the action that needs to be taken to the remaining + * elements of B when all elements of A are visited. + * @return a constant bag that the result of (op n[0] n[1]) + */ + template + static Node evaluateBinaryOperation(const TNode& n, + T1&& equal, + T2&& less, + T3&& greaterOrEqual, + T4&& remainderOfA, + T5&& remainderOfB); + /** + * evaluate n as follows: + * - (bag a 0) = (as bag.empty T) where T is the type of the original bag + * - (bag a (-c)) = (as bag.empty T) where T is the type the original bag, + * and c > 0 is a constant + */ + static Node evaluateMakeBag(TNode n); + + /** + * returns the multiplicity in a constant bag + * @param n has the form (bag.count x A) where x, A are constants + * @return the multiplicity of element x in bag A. + */ + static Node evaluateBagCount(TNode n); + + /** + * @param n has the form (bag.duplicate_removal A) where A is a constant bag + * @return a constant bag constructed from the elements in A where each + * element has multiplicity one + */ + static Node evaluateDuplicateRemoval(TNode n); + + /** + * evaluates union disjoint node such that the returned node is a canonical + * bag that has the form + * (bag.union_disjoint (bag e1 c1) ... + * (bag.union_disjoint * (bag e_{n-1} c_{n-1}) (bag e_n c_n)))) where + * c1... cn are positive integers, e1 ... en are constants, and the node + * identifier of these constants are such that: e1 < ... < en. + * @param n has the form (bag.union_disjoint A B) where A, B are constant bags + * @return the union disjoint of A and B + */ + static Node evaluateUnionDisjoint(TNode n); + /** + * @param n has the form (bag.union_max A B) where A, B are constant bags + * @return the union max of A and B + */ + static Node evaluateUnionMax(TNode n); + /** + * @param n has the form (bag.inter_min A B) where A, B are constant bags + * @return the intersection min of A and B + */ + static Node evaluateIntersectionMin(TNode n); + /** + * @param n has the form (bag.difference_subtract A B) where A, B are constant + * bags + * @return the difference subtract of A and B + */ + static Node evaluateDifferenceSubtract(TNode n); + /** + * @param n has the form (bag.difference_remove A B) where A, B are constant + * bags + * @return the difference remove of A and B + */ + static Node evaluateDifferenceRemove(TNode n); + /** + * @param n has the form (bag.choose A) where A is a constant bag + * @return x if n has the form (bag.choose (bag x c)). Otherwise an error is + * thrown. + */ + static Node evaluateChoose(TNode n); + /** + * @param n has the form (bag.card A) where A is a constant bag + * @return the number of elements in bag A + */ + static Node evaluateCard(TNode n); + /** + * @param n has the form (bag.is_singleton A) where A is a constant bag + * @return whether the bag A has cardinality one. + */ + static Node evaluateIsSingleton(TNode n); + /** + * @param n has the form (bag.from_set A) where A is a constant set + * @return a constant bag that contains exactly the elements in A. + */ + static Node evaluateFromSet(TNode n); + /** + * @param n has the form (bag.to_set A) where A is a constant bag + * @return a constant set constructed from the elements in A. + */ + static Node evaluateToSet(TNode n); + /** + * @param n has the form (bag.map f A) where A is a constant bag + * @return a constant bag constructed from the images of elements in A. + */ + static Node evaluateBagMap(TNode n); +}; +} // namespace bags +} // namespace theory +} // namespace cvc5 + +#endif /* CVC5__THEORY__BAGS__NORMAL_FORM_H */ diff --git a/src/theory/bags/card_solver.cpp b/src/theory/bags/card_solver.cpp index 4ec009c7d..2a35fb2bd 100644 --- a/src/theory/bags/card_solver.cpp +++ b/src/theory/bags/card_solver.cpp @@ -17,9 +17,9 @@ #include "expr/emptybag.h" #include "smt/logic_exception.h" +#include "theory/bags/bags_utils.h" #include "theory/bags/inference_generator.h" #include "theory/bags/inference_manager.h" -#include "theory/bags/normal_form.h" #include "theory/bags/solver_state.h" #include "theory/bags/term_registry.h" #include "theory/uf/equality_engine_iterator.h" diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp index 92aa1a0ea..3247548c5 100644 --- a/src/theory/bags/inference_generator.cpp +++ b/src/theory/bags/inference_generator.cpp @@ -517,6 +517,52 @@ InferInfo InferenceGenerator::mapUpwards( return inferInfo; } +InferInfo InferenceGenerator::filterDownwards(Node n, Node e) +{ + Assert(n.getKind() == BAG_FILTER && n[1].getType().isBag()); + Assert(e.getType().isSubtypeOf(n[1].getType().getBagElementType())); + + Node P = n[0]; + Node A = n[1]; + InferInfo inferInfo(d_im, InferenceId::BAGS_FILTER_DOWN); + + Node countA = getMultiplicityTerm(e, A); + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); + + Node member = d_nm->mkNode(GEQ, count, d_one); + Node pOfe = d_nm->mkNode(APPLY_UF, P, e); + Node equal = count.eqNode(countA); + + inferInfo.d_conclusion = pOfe.andNode(equal); + inferInfo.d_premises.push_back(member); + return inferInfo; +} + +InferInfo InferenceGenerator::filterUpwards(Node n, Node e) +{ + Assert(n.getKind() == BAG_FILTER && n[1].getType().isBag()); + Assert(e.getType().isSubtypeOf(n[1].getType().getBagElementType())); + + Node P = n[0]; + Node A = n[1]; + InferInfo inferInfo(d_im, InferenceId::BAGS_FILTER_UP); + + Node countA = getMultiplicityTerm(e, A); + Node skolem = getSkolem(n, inferInfo); + Node count = getMultiplicityTerm(e, skolem); + + Node member = d_nm->mkNode(GEQ, countA, d_one); + Node pOfe = d_nm->mkNode(APPLY_UF, P, e); + Node equal = count.eqNode(countA); + Node included = pOfe.andNode(equal); + Node equalZero = count.eqNode(d_zero); + Node excluded = pOfe.notNode().andNode(equalZero); + inferInfo.d_conclusion = included.orNode(excluded); + inferInfo.d_premises.push_back(member); + return inferInfo; +} + } // namespace bags } // namespace theory } // namespace cvc5 diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h index 2815058b2..3d74dbaa2 100644 --- a/src/theory/bags/inference_generator.h +++ b/src/theory/bags/inference_generator.h @@ -262,6 +262,34 @@ class InferenceGenerator */ InferInfo mapUpwards(Node n, Node uf, Node preImageSize, Node y, Node x); + /** + * @param n is (bag.filter p A) where p is a function (-> E Bool), + * A a bag of type (Bag E) + * @param e is an element of type E + * @return an inference that represents the following implication + * (=> + * (bag.member e skolem) + * (and + * (p e) + * (= (bag.count e skolem) (bag.count A))) + * where skolem is a variable equals (bag.filter p A) + */ + InferInfo filterDownwards(Node n, Node e); + + /** + * @param n is (bag.filter p A) where p is a function (-> E Bool), + * A a bag of type (Bag E) + * @param e is an element of type E + * @return an inference that represents the following implication + * (=> + * (bag.member e A) + * (or + * (and (p e) (= (bag.count e skolem) (bag.count A))) + * (and (not (p e)) (= (bag.count e skolem) 0))) + * where skolem is a variable equals (bag.filter p A) + */ + InferInfo filterUpwards(Node n, Node e); + /** * @param element of type T * @param bag of type (bag T) diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index d83be5e21..7d995dd7b 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -77,6 +77,10 @@ operator BAG_CHOOSE 1 "return an element in the bag given as a parameter # of the second argument, a bag of type (Bag T1), and returns a bag of type (Bag T2). operator BAG_MAP 2 "bag map function" +# The bag.filter operator takes a predicate of type (-> T Bool) and a bag of type (Bag T) +# and return the same bag excluding those elements that do not satisfy the predicate +operator BAG_FILTER 2 "bag filter operator" + # bag.fold operator combines elements of a bag into a single value. # (bag.fold f t B) folds the elements of bag B starting with term t and using # the combining function f. @@ -103,6 +107,7 @@ typerule BAG_IS_SINGLETON ::cvc5::theory::bags::IsSingletonTypeRule typerule BAG_FROM_SET ::cvc5::theory::bags::FromSetTypeRule typerule BAG_TO_SET ::cvc5::theory::bags::ToSetTypeRule typerule BAG_MAP ::cvc5::theory::bags::BagMapTypeRule +typerule BAG_FILTER ::cvc5::theory::bags::BagFilterTypeRule typerule BAG_FOLD ::cvc5::theory::bags::BagFoldTypeRule construle BAG_UNION_DISJOINT ::cvc5::theory::bags::BinaryOperatorTypeRule diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp deleted file mode 100644 index 6cf26d357..000000000 --- a/src/theory/bags/normal_form.cpp +++ /dev/null @@ -1,727 +0,0 @@ -/****************************************************************************** - * Top contributors (to current version): - * Mudathir Mohamed, Aina Niemetz - * - * This file is part of the cvc5 project. - * - * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS - * in the top-level source directory and their institutional affiliations. - * All rights reserved. See the file COPYING in the top-level source - * directory for licensing information. - * **************************************************************************** - * - * Normal form for bag constants. - */ -#include "normal_form.h" - -#include "expr/emptybag.h" -#include "smt/logic_exception.h" -#include "theory/sets/normal_form.h" -#include "theory/type_enumerator.h" -#include "util/rational.h" - -using namespace cvc5::kind; - -namespace cvc5 { -namespace theory { -namespace bags { - -bool NormalForm::isConstant(TNode n) -{ - if (n.getKind() == BAG_EMPTY) - { - // empty bags are already normalized - return true; - } - if (n.getKind() == BAG_MAKE) - { - // see the implementation in MkBagTypeRule::computeIsConst - return n.isConst(); - } - if (n.getKind() == BAG_UNION_DISJOINT) - { - if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst())) - { - // the first child is not a constant - return false; - } - // store the previous element to check the ordering of elements - Node previousElement = n[0][0]; - Node current = n[1]; - while (current.getKind() == BAG_UNION_DISJOINT) - { - if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst())) - { - // the current element is not a constant - return false; - } - if (previousElement >= current[0][0]) - { - // the ordering is violated - return false; - } - previousElement = current[0][0]; - current = current[1]; - } - // check last element - if (!(current.getKind() == kind::BAG_MAKE && current.isConst())) - { - // the last element is not a constant - return false; - } - if (previousElement >= current[0]) - { - // the ordering is violated - return false; - } - return true; - } - - // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be - // constants - return false; -} - -bool NormalForm::areChildrenConstants(TNode n) -{ - return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); }); -} - -Node NormalForm::evaluate(TNode n) -{ - Assert(areChildrenConstants(n)); - if (n.isConst()) - { - // a constant node is already in a normal form - return n; - } - switch (n.getKind()) - { - case BAG_MAKE: return evaluateMakeBag(n); - case BAG_COUNT: return evaluateBagCount(n); - case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n); - case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n); - case BAG_UNION_MAX: return evaluateUnionMax(n); - case BAG_INTER_MIN: return evaluateIntersectionMin(n); - case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n); - case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n); - case BAG_CARD: return evaluateCard(n); - case BAG_IS_SINGLETON: return evaluateIsSingleton(n); - case BAG_FROM_SET: return evaluateFromSet(n); - case BAG_TO_SET: return evaluateToSet(n); - case BAG_MAP: return evaluateBagMap(n); - case BAG_FOLD: return evaluateBagFold(n); - default: break; - } - Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n - << std::endl; -} - -template -Node NormalForm::evaluateBinaryOperation(const TNode& n, - T1&& equal, - T2&& less, - T3&& greaterOrEqual, - T4&& remainderOfA, - T5&& remainderOfB) -{ - std::map elementsA = getBagElements(n[0]); - std::map elementsB = getBagElements(n[1]); - std::map elements; - - std::map::const_iterator itA = elementsA.begin(); - std::map::const_iterator itB = elementsB.begin(); - - Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation " - << n.getKind() << "] " << std::endl - << "elements A: " << elementsA << std::endl - << "elements B: " << elementsB << std::endl; - - while (itA != elementsA.end() && itB != elementsB.end()) - { - if (itA->first == itB->first) - { - equal(elements, itA, itB); - itA++; - itB++; - } - else if (itA->first < itB->first) - { - less(elements, itA, itB); - itA++; - } - else - { - greaterOrEqual(elements, itA, itB); - itB++; - } - } - - // handle the remaining elements from A - remainderOfA(elements, elementsA, itA); - // handle the remaining elements from B - remainderOfB(elements, elementsB, itB); - - Trace("bags-evaluate") << "elements: " << elements << std::endl; - Node bag = constructConstantBagFromElements(n.getType(), elements); - Trace("bags-evaluate") << "bag: " << bag << std::endl; - return bag; -} - -std::map NormalForm::getBagElements(TNode n) -{ - std::map elements; - if (n.getKind() == BAG_EMPTY) - { - return elements; - } - while (n.getKind() == kind::BAG_UNION_DISJOINT) - { - Assert(n[0].getKind() == kind::BAG_MAKE); - Node element = n[0][0]; - Rational count = n[0][1].getConst(); - elements[element] = count; - n = n[1]; - } - Assert(n.getKind() == kind::BAG_MAKE); - Node lastElement = n[0]; - Rational lastCount = n[1].getConst(); - elements[lastElement] = lastCount; - return elements; -} - -Node NormalForm::constructConstantBagFromElements( - TypeNode t, const std::map& elements) -{ - Assert(t.isBag()); - NodeManager* nm = NodeManager::currentNM(); - if (elements.empty()) - { - return nm->mkConst(EmptyBag(t)); - } - TypeNode elementType = t.getBagElementType(); - std::map::const_reverse_iterator it = elements.rbegin(); - Node bag = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second)); - while (++it != elements.rend()) - { - Node n = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second)); - bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag); - } - return bag; -} - -Node NormalForm::constructBagFromElements(TypeNode t, - const std::map& elements) -{ - Assert(t.isBag()); - NodeManager* nm = NodeManager::currentNM(); - if (elements.empty()) - { - return nm->mkConst(EmptyBag(t)); - } - TypeNode elementType = t.getBagElementType(); - std::map::const_reverse_iterator it = elements.rbegin(); - Node bag = nm->mkBag(elementType, it->first, it->second); - while (++it != elements.rend()) - { - Node n = nm->mkBag(elementType, it->first, it->second); - bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag); - } - return bag; -} - -Node NormalForm::evaluateMakeBag(TNode n) -{ - // the case where n is const should be handled earlier. - // here we handle the case where the multiplicity is zero or negative - Assert(n.getKind() == BAG_MAKE && !n.isConst() - && n[1].getConst().sgn() < 1); - Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType())); - return emptybag; -} - -Node NormalForm::evaluateBagCount(TNode n) -{ - Assert(n.getKind() == BAG_COUNT); - // Examples - // -------- - // - (bag.count "x" (as bag.empty (Bag String))) = 0 - // - (bag.count "x" (bag "y" 5)) = 0 - // - (bag.count "x" (bag "x" 4)) = 4 - // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4 - // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0 - - std::map elements = getBagElements(n[1]); - std::map::iterator it = elements.find(n[0]); - - NodeManager* nm = NodeManager::currentNM(); - if (it != elements.end()) - { - Node count = nm->mkConstInt(it->second); - return count; - } - return nm->mkConstInt(Rational(0)); -} - -Node NormalForm::evaluateDuplicateRemoval(TNode n) -{ - Assert(n.getKind() == BAG_DUPLICATE_REMOVAL); - - // Examples - // -------- - // - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag - // String)) - // - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1) - // - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) = - // (bag.disjoint_union (bag "x" 1) (bag "y" 1) - - std::map oldElements = getBagElements(n[0]); - // copy elements from the old bag - std::map newElements(oldElements); - Rational one = Rational(1); - std::map::iterator it; - for (it = newElements.begin(); it != newElements.end(); it++) - { - it->second = one; - } - Node bag = constructConstantBagFromElements(n[0].getType(), newElements); - return bag; -} - -Node NormalForm::evaluateUnionDisjoint(TNode n) -{ - Assert(n.getKind() == BAG_UNION_DISJOINT); - // Example - // ------- - // input: (bag.union_disjoint A B) - // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) - // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) - // output: - // (bag.union_disjoint A B) - // where A = (bag "x" 7) - // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2))) - - auto equal = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // compute the sum of the multiplicities - elements[itA->first] = itA->second + itB->second; - }; - - auto less = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // add the element to the result - elements[itA->first] = itA->second; - }; - - auto greaterOrEqual = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // add the element to the result - elements[itB->first] = itB->second; - }; - - auto remainderOfA = [](std::map& elements, - std::map& elementsA, - std::map::const_iterator& itA) { - // append the remainder of A - while (itA != elementsA.end()) - { - elements[itA->first] = itA->second; - itA++; - } - }; - - auto remainderOfB = [](std::map& elements, - std::map& elementsB, - std::map::const_iterator& itB) { - // append the remainder of B - while (itB != elementsB.end()) - { - elements[itB->first] = itB->second; - itB++; - } - }; - - return evaluateBinaryOperation( - n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); -} - -Node NormalForm::evaluateUnionMax(TNode n) -{ - Assert(n.getKind() == BAG_UNION_MAX); - // Example - // ------- - // input: (bag.union_max A B) - // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) - // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) - // output: - // (bag.union_disjoint A B) - // where A = (bag "x" 4) - // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2))) - - auto equal = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // compute the maximum multiplicity - elements[itA->first] = std::max(itA->second, itB->second); - }; - - auto less = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // add to the result - elements[itA->first] = itA->second; - }; - - auto greaterOrEqual = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // add to the result - elements[itB->first] = itB->second; - }; - - auto remainderOfA = [](std::map& elements, - std::map& elementsA, - std::map::const_iterator& itA) { - // append the remainder of A - while (itA != elementsA.end()) - { - elements[itA->first] = itA->second; - itA++; - } - }; - - auto remainderOfB = [](std::map& elements, - std::map& elementsB, - std::map::const_iterator& itB) { - // append the remainder of B - while (itB != elementsB.end()) - { - elements[itB->first] = itB->second; - itB++; - } - }; - - return evaluateBinaryOperation( - n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); -} - -Node NormalForm::evaluateIntersectionMin(TNode n) -{ - Assert(n.getKind() == BAG_INTER_MIN); - // Example - // ------- - // input: (bag.inter_min A B) - // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) - // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) - // output: - // (bag "x" 3) - - auto equal = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // compute the minimum multiplicity - elements[itA->first] = std::min(itA->second, itB->second); - }; - - auto less = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // do nothing - }; - - auto greaterOrEqual = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // do nothing - }; - - auto remainderOfA = [](std::map& elements, - std::map& elementsA, - std::map::const_iterator& itA) { - // do nothing - }; - - auto remainderOfB = [](std::map& elements, - std::map& elementsB, - std::map::const_iterator& itB) { - // do nothing - }; - - return evaluateBinaryOperation( - n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); -} - -Node NormalForm::evaluateDifferenceSubtract(TNode n) -{ - Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT); - // Example - // ------- - // input: (bag.difference_subtract A B) - // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) - // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) - // output: - // (bag.union_disjoint (bag "x" 1) (bag "z" 2)) - - auto equal = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // subtract the multiplicities - elements[itA->first] = itA->second - itB->second; - }; - - auto less = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // itA->first is not in B, so we add it to the difference subtract - elements[itA->first] = itA->second; - }; - - auto greaterOrEqual = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // itB->first is not in A, so we just skip it - }; - - auto remainderOfA = [](std::map& elements, - std::map& elementsA, - std::map::const_iterator& itA) { - // append the remainder of A - while (itA != elementsA.end()) - { - elements[itA->first] = itA->second; - itA++; - } - }; - - auto remainderOfB = [](std::map& elements, - std::map& elementsB, - std::map::const_iterator& itB) { - // do nothing - }; - - return evaluateBinaryOperation( - n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); -} - -Node NormalForm::evaluateDifferenceRemove(TNode n) -{ - Assert(n.getKind() == BAG_DIFFERENCE_REMOVE); - // Example - // ------- - // input: (bag.difference_remove A B) - // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2))) - // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1))) - // output: - // (bag "z" 2) - - auto equal = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // skip the shared element by doing nothing - }; - - auto less = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // itA->first is not in B, so we add it to the difference remove - elements[itA->first] = itA->second; - }; - - auto greaterOrEqual = [](std::map& elements, - std::map::const_iterator& itA, - std::map::const_iterator& itB) { - // itB->first is not in A, so we just skip it - }; - - auto remainderOfA = [](std::map& elements, - std::map& elementsA, - std::map::const_iterator& itA) { - // append the remainder of A - while (itA != elementsA.end()) - { - elements[itA->first] = itA->second; - itA++; - } - }; - - auto remainderOfB = [](std::map& elements, - std::map& elementsB, - std::map::const_iterator& itB) { - // do nothing - }; - - return evaluateBinaryOperation( - n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); -} - -Node NormalForm::evaluateChoose(TNode n) -{ - Assert(n.getKind() == BAG_CHOOSE); - // Examples - // -------- - // - (bag.choose (bag "x" 4)) = "x" - - if (n[0].getKind() == BAG_MAKE) - { - return n[0][0]; - } - throw LogicException("BAG_CHOOSE_TOTAL is not supported yet"); -} - -Node NormalForm::evaluateCard(TNode n) -{ - Assert(n.getKind() == BAG_CARD); - // Examples - // -------- - // - (card (as bag.empty (Bag String))) = 0 - // - (bag.choose (bag "x" 4)) = 4 - // - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5 - - std::map elements = getBagElements(n[0]); - Rational sum(0); - for (std::pair element : elements) - { - sum += element.second; - } - - NodeManager* nm = NodeManager::currentNM(); - Node sumNode = nm->mkConstInt(sum); - return sumNode; -} - -Node NormalForm::evaluateIsSingleton(TNode n) -{ - Assert(n.getKind() == BAG_IS_SINGLETON); - // Examples - // -------- - // - (bag.is_singleton (as bag.empty (Bag String))) = false - // - (bag.is_singleton (bag "x" 1)) = true - // - (bag.is_singleton (bag "x" 4)) = false - // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1))) - // = false - - if (n[0].getKind() == BAG_MAKE && n[0][1].getConst().isOne()) - { - return NodeManager::currentNM()->mkConst(true); - } - return NodeManager::currentNM()->mkConst(false); -} - -Node NormalForm::evaluateFromSet(TNode n) -{ - Assert(n.getKind() == BAG_FROM_SET); - - // Examples - // -------- - // - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String)) - // - (bag.from_set (set.singleton "x")) = (bag "x" 1) - // - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) = - // (bag.disjoint_union (bag "x" 1) (bag "y" 1)) - - NodeManager* nm = NodeManager::currentNM(); - std::set setElements = - sets::NormalForm::getElementsFromNormalConstant(n[0]); - Rational one = Rational(1); - std::map bagElements; - for (const Node& element : setElements) - { - bagElements[element] = one; - } - TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType()); - Node bag = constructConstantBagFromElements(bagType, bagElements); - return bag; -} - -Node NormalForm::evaluateToSet(TNode n) -{ - Assert(n.getKind() == BAG_TO_SET); - - // Examples - // -------- - // - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String)) - // - (bag.to_set (bag "x" 4)) = (set.singleton "x") - // - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) = - // (set.union (set.singleton "x") (set.singleton "y"))) - - NodeManager* nm = NodeManager::currentNM(); - std::map bagElements = getBagElements(n[0]); - std::set setElements; - std::map::const_reverse_iterator it; - for (it = bagElements.rbegin(); it != bagElements.rend(); it++) - { - setElements.insert(it->first); - } - TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType()); - Node set = sets::NormalForm::elementsToSet(setElements, setType); - return set; -} - -Node NormalForm::evaluateBagMap(TNode n) -{ - Assert(n.getKind() == BAG_MAP); - - // Examples - // -------- - // - (bag.map ((lambda ((x String)) "z") - // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) = - // (bag.union_disjoint - // (bag ((lambda ((x String)) "z") "a") 2) - // (bag ((lambda ((x String)) "z") "b") 3)) = - // (bag "z" 5) - - std::map elements = NormalForm::getBagElements(n[1]); - std::map mappedElements; - std::map::iterator it = elements.begin(); - NodeManager* nm = NodeManager::currentNM(); - while (it != elements.end()) - { - Node mappedElement = nm->mkNode(APPLY_UF, n[0], it->first); - mappedElements[mappedElement] = it->second; - ++it; - } - TypeNode t = nm->mkBagType(n[0].getType().getRangeType()); - Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements); - return ret; -} - -Node NormalForm::evaluateBagFold(TNode n) -{ - Assert(n.getKind() == BAG_FOLD); - - // Examples - // -------- - // minimum string - // - (bag.fold - // ((lambda ((x String) (y String)) (ite (str.< x y) x y)) - // "" - // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) - // = "a" - - Node f = n[0]; // combining function - Node ret = n[1]; // initial value - Node A = n[2]; // bag - std::map elements = NormalForm::getBagElements(A); - - std::map::iterator it = elements.begin(); - NodeManager* nm = NodeManager::currentNM(); - while (it != elements.end()) - { - // apply the combination function n times, where n is the multiplicity - Rational count = it->second; - Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl; - while (!count.isZero()) - { - ret = nm->mkNode(APPLY_UF, f, it->first, ret); - count = count - 1; - } - ++it; - } - return ret; -} - -} // namespace bags -} // namespace theory -} // namespace cvc5 diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/normal_form.h deleted file mode 100644 index 5275678ff..000000000 --- a/src/theory/bags/normal_form.h +++ /dev/null @@ -1,210 +0,0 @@ -/****************************************************************************** - * Top contributors (to current version): - * Mudathir Mohamed - * - * This file is part of the cvc5 project. - * - * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS - * in the top-level source directory and their institutional affiliations. - * All rights reserved. See the file COPYING in the top-level source - * directory for licensing information. - * **************************************************************************** - * - * Normal form for bag constants. - */ - -#include - -#include "cvc5_private.h" - -#ifndef CVC5__THEORY__BAGS__NORMAL_FORM_H -#define CVC5__THEORY__BAGS__NORMAL_FORM_H - -namespace cvc5 { -namespace theory { -namespace bags { - -class NormalForm -{ - public: - /** - * Returns true if n is considered a to be a (canonical) constant bag value. - * A canonical bag value is one whose AST is: - * (bag.union_disjoint (bag e1 c1) ... - * (bag.union_disjoint (bag e_{n-1} c_{n-1}) (bag e_n c_n)))) - * where c1 ... cn are positive integers, e1 ... en are constants, and the - * node identifier of these constants are such that: e1 < ... < en. - * Also handles the corner cases of empty bag and bag constructed by bag - */ - static bool isConstant(TNode n); - /** - * check whether all children of the given node are constants - */ - static bool areChildrenConstants(TNode n); - /** - * evaluate the node n to a constant value. - * As a precondition, children of n should be constants. - */ - static Node evaluate(TNode n); - - /** - * get the elements along with their multiplicities in a given bag - * @param n a constant node whose type is a bag - * @return a map whose keys are constant elements and values are - * multiplicities - */ - static std::map getBagElements(TNode n); - - /** - * construct a constant bag from constant elements - * @param t the type of the returned bag - * @param elements a map whose keys are constant elements and values are - * multiplicities - * @return a constant bag that contains - */ - static Node constructConstantBagFromElements( - TypeNode t, const std::map& elements); - - /** - * construct a constant bag from node elements - * @param t the type of the returned bag - * @param elements a map whose keys are constant elements and values are - * multiplicities - * @return a constant bag that contains - */ - static Node constructBagFromElements(TypeNode t, - const std::map& elements); - - /** - * @param n has the form (bag.fold f t A) where A is a constant bag - * @return a single value which is the result of the fold - */ - static Node evaluateBagFold(TNode n); - - private: - /** - * a high order helper function that return a constant bag that is the result - * of (op A B) where op is a binary operator and A, B are constant bags. - * The result is computed from the elements of A (elementsA with iterator itA) - * and elements of B (elementsB with iterator itB). - * The arguments below specify how these iterators are used to generate the - * elements of the result (elements). - * @param n a node whose kind is a binary operator (bag.union_disjoint, - * union_max, intersection_min, difference_subtract, difference_remove) and - * whose children are constant bags. - * @param equal a lambda expression that receives (elements, itA, itB) and - * specify the action that needs to be taken when the elements of itA, itB are - * equal. - * @param less a lambda expression that receives (elements, itA, itB) and - * specify the action that needs to be taken when the element itA is less than - * the element of itB. - * @param greaterOrEqual less a lambda expression that receives (elements, - * itA, itB) and specify the action that needs to be taken when the element - * itA is greater than or equal than the element of itB. - * @param remainderOfA a lambda expression that receives (elements, elementsA, - * itA) and specify the action that needs to be taken to the remaining - * elements of A when all elements of B are visited. - * @param remainderOfB a lambda expression that receives (elements, elementsB, - * itB) and specify the action that needs to be taken to the remaining - * elements of B when all elements of A are visited. - * @return a constant bag that the result of (op n[0] n[1]) - */ - template - static Node evaluateBinaryOperation(const TNode& n, - T1&& equal, - T2&& less, - T3&& greaterOrEqual, - T4&& remainderOfA, - T5&& remainderOfB); - /** - * evaluate n as follows: - * - (bag a 0) = (as bag.empty T) where T is the type of the original bag - * - (bag a (-c)) = (as bag.empty T) where T is the type the original bag, - * and c > 0 is a constant - */ - static Node evaluateMakeBag(TNode n); - - /** - * returns the multiplicity in a constant bag - * @param n has the form (bag.count x A) where x, A are constants - * @return the multiplicity of element x in bag A. - */ - static Node evaluateBagCount(TNode n); - - /** - * @param n has the form (bag.duplicate_removal A) where A is a constant bag - * @return a constant bag constructed from the elements in A where each - * element has multiplicity one - */ - static Node evaluateDuplicateRemoval(TNode n); - - /** - * evaluates union disjoint node such that the returned node is a canonical - * bag that has the form - * (bag.union_disjoint (bag e1 c1) ... - * (bag.union_disjoint * (bag e_{n-1} c_{n-1}) (bag e_n c_n)))) where - * c1... cn are positive integers, e1 ... en are constants, and the node - * identifier of these constants are such that: e1 < ... < en. - * @param n has the form (bag.union_disjoint A B) where A, B are constant bags - * @return the union disjoint of A and B - */ - static Node evaluateUnionDisjoint(TNode n); - /** - * @param n has the form (bag.union_max A B) where A, B are constant bags - * @return the union max of A and B - */ - static Node evaluateUnionMax(TNode n); - /** - * @param n has the form (bag.inter_min A B) where A, B are constant bags - * @return the intersection min of A and B - */ - static Node evaluateIntersectionMin(TNode n); - /** - * @param n has the form (bag.difference_subtract A B) where A, B are constant - * bags - * @return the difference subtract of A and B - */ - static Node evaluateDifferenceSubtract(TNode n); - /** - * @param n has the form (bag.difference_remove A B) where A, B are constant - * bags - * @return the difference remove of A and B - */ - static Node evaluateDifferenceRemove(TNode n); - /** - * @param n has the form (bag.choose A) where A is a constant bag - * @return x if n has the form (bag.choose (bag x c)). Otherwise an error is - * thrown. - */ - static Node evaluateChoose(TNode n); - /** - * @param n has the form (bag.card A) where A is a constant bag - * @return the number of elements in bag A - */ - static Node evaluateCard(TNode n); - /** - * @param n has the form (bag.is_singleton A) where A is a constant bag - * @return whether the bag A has cardinality one. - */ - static Node evaluateIsSingleton(TNode n); - /** - * @param n has the form (bag.from_set A) where A is a constant set - * @return a constant bag that contains exactly the elements in A. - */ - static Node evaluateFromSet(TNode n); - /** - * @param n has the form (bag.to_set A) where A is a constant bag - * @return a constant set constructed from the elements in A. - */ - static Node evaluateToSet(TNode n); - /** - * @param n has the form (bag.map f A) where A is a constant bag - * @return a constant bag constructed from the images of elements in A. - */ - static Node evaluateBagMap(TNode n); -}; -} // namespace bags -} // namespace theory -} // namespace cvc5 - -#endif /* CVC5__THEORY__BAGS__NORMAL_FORM_H */ diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp index d8ed9fb95..9bd0c3a86 100644 --- a/src/theory/bags/rewrites.cpp +++ b/src/theory/bags/rewrites.cpp @@ -38,6 +38,9 @@ const char* toString(Rewrite r) case Rewrite::EQ_CONST_FALSE: return "EQ_CONST_FALSE"; case Rewrite::EQ_REFL: return "EQ_REFL"; case Rewrite::EQ_SYM: return "EQ_SYM"; + case Rewrite::FILTER_CONST: return "FILTER_CONST"; + case Rewrite::FILTER_BAG_MAKE: return "FILTER_BAG_MAKE"; + case Rewrite::FILTER_UNION_DISJOINT: return "FILTER_UNION_DISJOINT"; case Rewrite::FROM_SINGLETON: return "FROM_SINGLETON"; case Rewrite::FOLD_BAG: return "FOLD_BAG"; case Rewrite::FOLD_CONST: return "FOLD_CONST"; diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h index 57f106211..e1ef38c4b 100644 --- a/src/theory/bags/rewrites.h +++ b/src/theory/bags/rewrites.h @@ -42,6 +42,9 @@ enum class Rewrite : uint32_t EQ_CONST_FALSE, EQ_REFL, EQ_SYM, + FILTER_CONST, + FILTER_BAG_MAKE, + FILTER_UNION_DISJOINT, FROM_SINGLETON, FOLD_BAG, FOLD_CONST, diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 720e97c25..37b6415e0 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -19,7 +19,7 @@ #include "expr/skolem_manager.h" #include "proof/proof_checker.h" #include "smt/logic_exception.h" -#include "theory/bags/normal_form.h" +#include "theory/bags/bags_utils.h" #include "theory/quantifiers/fmf/bounded_integers.h" #include "theory/rewriter.h" #include "theory/theory_model.h" @@ -321,7 +321,7 @@ bool TheoryBags::collectModelValues(TheoryModel* m, Node value = m->getRepresentative(countSkolem); elementReps[key] = value; } - Node constructedBag = NormalForm::constructBagFromElements(tn, elementReps); + Node constructedBag = BagsUtils::constructBagFromElements(tn, elementReps); constructedBag = rewrite(constructedBag); Trace("bags-model") << "constructed bag for " << n << " is: " << constructedBag << std::endl; @@ -352,7 +352,8 @@ bool TheoryBags::collectModelValues(TheoryModel* m, if (constructedRational < rCardRational && !d_env.isFiniteType(elementType)) { - Node newElement = nm->getSkolemManager()->mkDummySkolem("slack", elementType); + Node newElement = + nm->getSkolemManager()->mkDummySkolem("slack", elementType); Trace("bags-model") << "newElement is " << newElement << std::endl; Rational difference = rCardRational - constructedRational; Node multiplicity = nm->mkConst(CONST_RATIONAL, difference); diff --git a/src/theory/bags/theory_bags_type_enumerator.cpp b/src/theory/bags/theory_bags_type_enumerator.cpp index 14fca3297..a24981934 100644 --- a/src/theory/bags/theory_bags_type_enumerator.cpp +++ b/src/theory/bags/theory_bags_type_enumerator.cpp @@ -16,7 +16,7 @@ #include "theory/bags/theory_bags_type_enumerator.h" #include "expr/emptybag.h" -#include "theory/bags/normal_form.h" +#include "theory/bags/bags_utils.h" #include "theory_bags_type_enumerator.h" #include "util/rational.h" @@ -67,11 +67,10 @@ BagEnumerator& BagEnumerator::operator++() else { // increase the multiplicity of one of the elements in the current bag - std::map elements = - NormalForm::getBagElements(d_currentBag); + std::map elements = BagsUtils::getBagElements(d_currentBag); Node element = elements.begin()->first; elements[element] = elements[element] + Rational(1); - d_currentBag = NormalForm::constructConstantBagFromElements( + d_currentBag = BagsUtils::constructConstantBagFromElements( d_currentBag.getType(), elements); } diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp index 2d218f821..689b0e208 100644 --- a/src/theory/bags/theory_bags_type_rules.cpp +++ b/src/theory/bags/theory_bags_type_rules.cpp @@ -20,7 +20,7 @@ #include "base/check.h" #include "expr/emptybag.h" #include "theory/bags/bag_make_op.h" -#include "theory/bags/normal_form.h" +#include "theory/bags/bags_utils.h" #include "util/cardinality.h" #include "util/rational.h" @@ -63,7 +63,7 @@ bool BinaryOperatorTypeRule::computeIsConst(NodeManager* nodeManager, TNode n) // only UNION_DISJOINT has a const rule in kinds. // Other binary operators do not have const rules in kinds Assert(n.getKind() == kind::BAG_UNION_DISJOINT); - return NormalForm::isConstant(n); + return BagsUtils::isConstant(n); } TypeNode SubBagTypeRule::computeType(NodeManager* nodeManager, @@ -356,6 +356,48 @@ TypeNode BagMapTypeRule::computeType(NodeManager* nodeManager, return retType; } +TypeNode BagFilterTypeRule::computeType(NodeManager* nodeManager, + TNode n, + bool check) +{ + Assert(n.getKind() == kind::BAG_FILTER); + TypeNode functionType = n[0].getType(check); + TypeNode bagType = n[1].getType(check); + if (check) + { + if (!bagType.isBag()) + { + throw TypeCheckingExceptionPrivate( + n, + "bag.filter operator expects a bag in the second argument, " + "a non-bag is found"); + } + + TypeNode elementType = bagType.getBagElementType(); + + if (!(functionType.isFunction())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " Bool) as a first argument. " + << "Found a term of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + std::vector argTypes = functionType.getArgTypes(); + NodeManager* nm = NodeManager::currentNM(); + if (!(argTypes.size() == 1 && argTypes[0] == elementType + && functionType.getRangeType() == nm->booleanType())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " Bool). " + << "Found a function of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + return bagType; +} + TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager, TNode n, bool check) diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index da9ea75bf..76c179a62 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -141,6 +141,15 @@ struct BagMapTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct BagMapTypeRule */ +/** + * Type rule for (bag.filter p B) to make sure p is a unary predicate of type + * (-> T Bool) where B is a bag of type (Bag T) + */ +struct BagFilterTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct BagFilterTypeRule */ + /** * Type rule for (bag.fold f t A) to make sure f is a binary operation of type * (-> T1 T2 T2), t of type T2, and B is a bag of type (Bag T1) diff --git a/src/theory/inference_id.cpp b/src/theory/inference_id.cpp index d7e5fccbe..240d6e293 100644 --- a/src/theory/inference_id.cpp +++ b/src/theory/inference_id.cpp @@ -120,6 +120,8 @@ const char* toString(InferenceId i) case InferenceId::BAGS_DIFFERENCE_REMOVE: return "BAGS_DIFFERENCE_REMOVE"; case InferenceId::BAGS_DUPLICATE_REMOVAL: return "BAGS_DUPLICATE_REMOVAL"; case InferenceId::BAGS_MAP: return "BAGS_MAP"; + case InferenceId::BAGS_FILTER_DOWN: return "BAGS_FILTER_DOWN"; + case InferenceId::BAGS_FILTER_UP: return "BAGS_FILTER_UP"; case InferenceId::BAGS_FOLD: return "BAGS_FOLD"; case InferenceId::BAGS_CARD: return "BAGS_CARD"; diff --git a/src/theory/inference_id.h b/src/theory/inference_id.h index 4970c1cee..2fb3ae003 100644 --- a/src/theory/inference_id.h +++ b/src/theory/inference_id.h @@ -182,6 +182,8 @@ enum class InferenceId BAGS_DIFFERENCE_REMOVE, BAGS_DUPLICATE_REMOVAL, BAGS_MAP, + BAGS_FILTER_DOWN, + BAGS_FILTER_UP, BAGS_FOLD, BAGS_CARD, // ---------------------------------- end bags theory diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index c86be3e76..ec3b13caa 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -1653,6 +1653,11 @@ set(regress_1_tests regress1/bags/duplicate_removal1.smt2 regress1/bags/duplicate_removal2.smt2 regress1/bags/emptybag1.smt2 + regress1/bags/filter1.smt2 + regress1/bags/filter2.smt2 + regress1/bags/filter3.smt2 + regress1/bags/filter4.smt2 + regress1/bags/filter5.smt2 regress1/bags/fol_0000119.smt2 regress1/bags/fold1.smt2 regress1/bags/fuzzy1.smt2 diff --git a/test/regress/regress1/bags/filter1.smt2 b/test/regress/regress1/bags/filter1.smt2 new file mode 100644 index 000000000..65e87c10f --- /dev/null +++ b/test/regress/regress1/bags/filter1.smt2 @@ -0,0 +1,11 @@ +(set-logic ALL) +(set-info :status sat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun x () Int) +(declare-fun y () Int) +(declare-fun p (Int) Bool) +(assert (= A (bag.union_max (bag x 1) (bag y 2)))) +(assert (= B (bag.filter p A))) +(assert (distinct (p x) (p y))) +(check-sat) diff --git a/test/regress/regress1/bags/filter2.smt2 b/test/regress/regress1/bags/filter2.smt2 new file mode 100644 index 000000000..62b6403fd --- /dev/null +++ b/test/regress/regress1/bags/filter2.smt2 @@ -0,0 +1,9 @@ +(set-logic HO_ALL) +(set-info :status sat) +(set-option :fmf-bound true) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun p (Int) Bool) +(assert (= B (bag.filter p A))) +(assert (= (bag.count (- 2) B) 57)) +(check-sat) diff --git a/test/regress/regress1/bags/filter3.smt2 b/test/regress/regress1/bags/filter3.smt2 new file mode 100644 index 000000000..10f637015 --- /dev/null +++ b/test/regress/regress1/bags/filter3.smt2 @@ -0,0 +1,10 @@ +(set-logic HO_ALL) +(set-info :status unsat) +(set-option :fmf-bound true) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(define-fun p ((x Int)) Bool (> x 1)) +(assert (= B (bag.filter p A))) +(assert (= (bag.count 3 B) 57)) +(assert (= (bag.count 3 B) 58)) +(check-sat) diff --git a/test/regress/regress1/bags/filter4.smt2 b/test/regress/regress1/bags/filter4.smt2 new file mode 100644 index 000000000..9be695210 --- /dev/null +++ b/test/regress/regress1/bags/filter4.smt2 @@ -0,0 +1,11 @@ +(set-logic HO_ALL) +(set-info :status unsat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun element () Int) +(declare-fun p (Int) Bool) +(assert (= B (bag.filter p A))) +(assert (p element)) +(assert (not (bag.member element B))) +(assert (bag.member element A)) +(check-sat) diff --git a/test/regress/regress1/bags/filter5.smt2 b/test/regress/regress1/bags/filter5.smt2 new file mode 100644 index 000000000..74ca05429 --- /dev/null +++ b/test/regress/regress1/bags/filter5.smt2 @@ -0,0 +1,11 @@ +(set-logic HO_ALL) +(set-info :status unsat) +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag Int)) +(declare-fun element () Int) +(declare-fun p (Int) Bool) +(assert (= B (bag.filter p A))) +(assert (p element)) +(assert (not (bag.member element A))) +(assert (bag.member element B)) +(check-sat) diff --git a/test/regress/regress1/bags/map1.smt2 b/test/regress/regress1/bags/map1.smt2 index c7dc3d636..748d327dd 100644 --- a/test/regress/regress1/bags/map1.smt2 +++ b/test/regress/regress1/bags/map1.smt2 @@ -6,7 +6,6 @@ (declare-fun y () Int) (declare-fun f (Int) Int) (assert (= A (bag.union_max (bag x 1) (bag y 2)))) -(assert (= A (bag.union_max (bag x 1) (bag y 2)))) (assert (= B (bag.map f A))) (assert (distinct (f x) (f y) x y)) (check-sat) diff --git a/test/unit/theory/theory_bags_normal_form_white.cpp b/test/unit/theory/theory_bags_normal_form_white.cpp index 5f3abfcee..4c8c41f0b 100644 --- a/test/unit/theory/theory_bags_normal_form_white.cpp +++ b/test/unit/theory/theory_bags_normal_form_white.cpp @@ -18,7 +18,7 @@ #include "expr/emptyset.h" #include "test_smt.h" #include "theory/bags/bags_rewriter.h" -#include "theory/bags/normal_form.h" +#include "theory/bags/bags_utils.h" #include "theory/strings/type_enumerator.h" #include "util/rational.h" #include "util/string.h" @@ -65,7 +65,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, empty_bag_normal_form) Node emptybag = d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType())); // empty bags are in normal form ASSERT_TRUE(emptybag.isConst()); - Node n = NormalForm::evaluate(emptybag); + Node n = BagsUtils::evaluate(emptybag); ASSERT_EQ(emptybag, n); } @@ -89,9 +89,9 @@ TEST_F(TestTheoryWhiteBagsNormalForm, mkBag_constant_element) ASSERT_FALSE(negative.isConst()); ASSERT_FALSE(zero.isConst()); - ASSERT_EQ(emptybag, NormalForm::evaluate(negative)); - ASSERT_EQ(emptybag, NormalForm::evaluate(zero)); - ASSERT_EQ(positive, NormalForm::evaluate(positive)); + ASSERT_EQ(emptybag, BagsUtils::evaluate(negative)); + ASSERT_EQ(emptybag, BagsUtils::evaluate(zero)); + ASSERT_EQ(positive, BagsUtils::evaluate(positive)); } TEST_F(TestTheoryWhiteBagsNormalForm, bag_count) @@ -126,25 +126,25 @@ TEST_F(TestTheoryWhiteBagsNormalForm, bag_count) Node input1 = d_nodeManager->mkNode(BAG_COUNT, x, empty); Node output1 = zero; - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node input2 = d_nodeManager->mkNode(BAG_COUNT, x, y_5); Node output2 = zero; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node input3 = d_nodeManager->mkNode(BAG_COUNT, x, x_4); Node output3 = four; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node unionDisjointXY = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5); Node input4 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointXY); Node output4 = four; - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); Node unionDisjointYZ = d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_5, z_5); Node input5 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointYZ); Node output5 = zero; - ASSERT_EQ(output4, NormalForm::evaluate(input4)); + ASSERT_EQ(output4, BagsUtils::evaluate(input4)); } TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal) @@ -161,7 +161,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal) EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType()))); Node input1 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, emptybag); Node output1 = emptybag; - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node x = d_nodeManager->mkConst(String("x")); Node y = d_nodeManager->mkConst(String("y")); @@ -186,12 +186,12 @@ TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal) Node input2 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, x_4); Node output2 = x_1; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5); Node input3 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, normalBag); Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1); - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, union_max) @@ -241,7 +241,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_max) d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2)); ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, NormalForm::evaluate(input)); + ASSERT_EQ(output, BagsUtils::evaluate(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) @@ -265,12 +265,12 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) Node unionDisjointAB = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, B); // unionDisjointAB is already in a normal form ASSERT_TRUE(unionDisjointAB.isConst()); - ASSERT_EQ(unionDisjointAB, NormalForm::evaluate(unionDisjointAB)); + ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointAB)); Node unionDisjointBA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, B, A); // unionDisjointAB is the normal form of unionDisjointBA ASSERT_FALSE(unionDisjointBA.isConst()); - ASSERT_EQ(unionDisjointAB, NormalForm::evaluate(unionDisjointBA)); + ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointBA)); Node unionDisjointAB_C = d_nodeManager->mkNode(BAG_UNION_DISJOINT, unionDisjointAB, C); @@ -280,7 +280,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) // unionDisjointA_BC is the normal form of unionDisjointAB_C ASSERT_FALSE(unionDisjointAB_C.isConst()); ASSERT_TRUE(unionDisjointA_BC.isConst()); - ASSERT_EQ(unionDisjointA_BC, NormalForm::evaluate(unionDisjointAB_C)); + ASSERT_EQ(unionDisjointA_BC, BagsUtils::evaluate(unionDisjointAB_C)); Node unionDisjointAA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, A); Node AA = @@ -289,7 +289,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1) d_nodeManager->mkConst(CONST_RATIONAL, Rational(4))); ASSERT_FALSE(unionDisjointAA.isConst()); ASSERT_TRUE(AA.isConst()); - ASSERT_EQ(AA, NormalForm::evaluate(unionDisjointAA)); + ASSERT_EQ(AA, BagsUtils::evaluate(unionDisjointAA)); } TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2) @@ -339,7 +339,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2) d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2)); ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, NormalForm::evaluate(input)); + ASSERT_EQ(output, BagsUtils::evaluate(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min) @@ -384,7 +384,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min) Node output = x_3; ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, NormalForm::evaluate(input)); + ASSERT_EQ(output, BagsUtils::evaluate(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract) @@ -433,7 +433,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract) Node output = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, z_2); ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, NormalForm::evaluate(input)); + ASSERT_EQ(output, BagsUtils::evaluate(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove) @@ -482,7 +482,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove) Node output = z_2; ASSERT_TRUE(output.isConst()); - ASSERT_EQ(output, NormalForm::evaluate(input)); + ASSERT_EQ(output, BagsUtils::evaluate(input)); } TEST_F(TestTheoryWhiteBagsNormalForm, bag_card) @@ -509,16 +509,16 @@ TEST_F(TestTheoryWhiteBagsNormalForm, bag_card) Node input1 = d_nodeManager->mkNode(BAG_CARD, empty); Node output1 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(0)); - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node input2 = d_nodeManager->mkNode(BAG_CARD, x_4); Node output2 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(4)); - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_1); Node input3 = d_nodeManager->mkNode(BAG_CARD, union_disjoint); Node output3 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(5)); - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton) @@ -552,20 +552,20 @@ TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton) Node input1 = d_nodeManager->mkNode(BAG_IS_SINGLETON, empty); Node output1 = falseNode; - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node input2 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_1); Node output2 = trueNode; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node input3 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_4); Node output3 = falseNode; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1); Node input4 = d_nodeManager->mkNode(BAG_IS_SINGLETON, union_disjoint); Node output4 = falseNode; - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, from_set) @@ -583,7 +583,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, from_set) EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType()))); Node input1 = d_nodeManager->mkNode(BAG_FROM_SET, emptyset); Node output1 = emptybag; - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node x = d_nodeManager->mkConst(String("x")); Node y = d_nodeManager->mkConst(String("y")); @@ -602,13 +602,13 @@ TEST_F(TestTheoryWhiteBagsNormalForm, from_set) Node input2 = d_nodeManager->mkNode(BAG_FROM_SET, xSingleton); Node output2 = x_1; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); // for normal sets, the first node is the largest, not smallest Node normalSet = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton); Node input3 = d_nodeManager->mkNode(BAG_FROM_SET, normalSet); Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1); - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); } TEST_F(TestTheoryWhiteBagsNormalForm, to_set) @@ -626,7 +626,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, to_set) EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType()))); Node input1 = d_nodeManager->mkNode(BAG_TO_SET, emptybag); Node output1 = emptyset; - ASSERT_EQ(output1, NormalForm::evaluate(input1)); + ASSERT_EQ(output1, BagsUtils::evaluate(input1)); Node x = d_nodeManager->mkConst(String("x")); Node y = d_nodeManager->mkConst(String("y")); @@ -645,13 +645,13 @@ TEST_F(TestTheoryWhiteBagsNormalForm, to_set) Node input2 = d_nodeManager->mkNode(BAG_TO_SET, x_4); Node output2 = xSingleton; - ASSERT_EQ(output2, NormalForm::evaluate(input2)); + ASSERT_EQ(output2, BagsUtils::evaluate(input2)); // for normal sets, the first node is the largest, not smallest Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5); Node input3 = d_nodeManager->mkNode(BAG_TO_SET, normalBag); Node output3 = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton); - ASSERT_EQ(output3, NormalForm::evaluate(input3)); + ASSERT_EQ(output3, BagsUtils::evaluate(input3)); } } // namespace test } // namespace cvc5