From f29f07d2c3ac15fe55f0055c9a001dc24d13bdce Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Wed, 20 Apr 2022 23:46:43 -0500 Subject: [PATCH] Add bag.partition evaluation (#8637) --- src/api/cpp/cvc5.cpp | 2 + src/api/cpp/cvc5_kind.h | 27 ++++++ src/parser/smt2/smt2.cpp | 1 + src/printer/smt2/smt2_printer.cpp | 1 + src/theory/bags/bags_rewriter.cpp | 21 ++++- src/theory/bags/bags_rewriter.h | 8 +- src/theory/bags/bags_utils.cpp | 94 +++++++++++++++++++ src/theory/bags/bags_utils.h | 8 ++ src/theory/bags/kinds | 5 + src/theory/bags/rewrites.cpp | 1 + src/theory/bags/rewrites.h | 1 + src/theory/bags/theory_bags.cpp | 4 +- src/theory/bags/theory_bags_type_rules.cpp | 45 +++++++++ src/theory/bags/theory_bags_type_rules.h | 11 ++- test/regress/cli/CMakeLists.txt | 1 + .../cli/regress1/bags/bag_partition1.smt2 | 34 +++++++ 16 files changed, 259 insertions(+), 5 deletions(-) create mode 100644 test/regress/cli/regress1/bags/bag_partition1.smt2 diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index d3c28aa06..84967b5c9 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -328,6 +328,7 @@ const static std::unordered_map> KIND_ENUM(BAG_MAP, internal::Kind::BAG_MAP), KIND_ENUM(BAG_FILTER, internal::Kind::BAG_FILTER), KIND_ENUM(BAG_FOLD, internal::Kind::BAG_FOLD), + KIND_ENUM(BAG_PARTITION, internal::Kind::BAG_PARTITION), KIND_ENUM(TABLE_PRODUCT, internal::Kind::TABLE_PRODUCT), KIND_ENUM(TABLE_PROJECT, internal::Kind::TABLE_PROJECT), /* Strings ---------------------------------------------------------- */ @@ -644,6 +645,7 @@ const static std::unordered_map&) const + * - Solver::mkTerm(const Op&, const std::vector&) const + * + * \rst + * .. warning:: This kind is experimental and may be changed or removed in + * future versions. + * \endrst + */ + BAG_PARTITION, /** * Table cross product. * diff --git a/src/parser/smt2/smt2.cpp b/src/parser/smt2/smt2.cpp index 93a518df0..a4a16c214 100644 --- a/src/parser/smt2/smt2.cpp +++ b/src/parser/smt2/smt2.cpp @@ -634,6 +634,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand) addOperator(cvc5::BAG_MAP, "bag.map"); addOperator(cvc5::BAG_FILTER, "bag.filter"); addOperator(cvc5::BAG_FOLD, "bag.fold"); + addOperator(cvc5::BAG_PARTITION, "bag.partition"); addOperator(cvc5::TABLE_PRODUCT, "table.product"); } if (d_logic.isTheoryEnabled(internal::theory::THEORY_STRINGS)) diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index bc35f639f..41fa39575 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -1183,6 +1183,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v) case kind::BAG_MAP: return "bag.map"; case kind::BAG_FILTER: return "bag.filter"; case kind::BAG_FOLD: return "bag.fold"; + case kind::BAG_PARTITION: return "bag.partition"; case kind::TABLE_PRODUCT: return "table.product"; case kind::TABLE_PROJECT: return "table.project"; diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index cdf2dde02..6b7e49a31 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -17,6 +17,7 @@ #include "expr/emptybag.h" #include "theory/bags/bags_utils.h" +#include "theory/rewriter.h" #include "util/rational.h" #include "util/statistics_registry.h" @@ -41,8 +42,8 @@ BagsRewriteResponse::BagsRewriteResponse(const BagsRewriteResponse& r) { } -BagsRewriter::BagsRewriter(HistogramStat* statistics) - : d_statistics(statistics) +BagsRewriter::BagsRewriter(Rewriter* r, HistogramStat* statistics) + : d_rewriter(r), d_statistics(statistics) { d_nm = NodeManager::currentNM(); d_zero = d_nm->mkConstInt(Rational(0)); @@ -92,6 +93,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) case BAG_MAP: response = postRewriteMap(n); break; case BAG_FILTER: response = postRewriteFilter(n); break; case BAG_FOLD: response = postRewriteFold(n); break; + case BAG_PARTITION: response = postRewritePartition(n); break; case TABLE_PRODUCT: response = postRewriteProduct(n); break; default: response = BagsRewriteResponse(n, Rewrite::NONE); break; } @@ -648,6 +650,21 @@ BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const return BagsRewriteResponse(n, Rewrite::NONE); } +BagsRewriteResponse BagsRewriter::postRewritePartition(const TNode& n) const +{ + Assert(n.getKind() == kind::BAG_PARTITION); + if (n[1].isConst()) + { + Node ret = BagsUtils::evaluateBagPartition(d_rewriter, n); + if (ret != n) + { + return BagsRewriteResponse(ret, Rewrite::PARTITION_CONST); + } + } + + return BagsRewriteResponse(n, Rewrite::NONE); +} + BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const { Assert(n.getKind() == TABLE_PRODUCT); diff --git a/src/theory/bags/bags_rewriter.h b/src/theory/bags/bags_rewriter.h index 00c8b6d0c..3c08208a8 100644 --- a/src/theory/bags/bags_rewriter.h +++ b/src/theory/bags/bags_rewriter.h @@ -42,7 +42,7 @@ struct BagsRewriteResponse class BagsRewriter : public TheoryRewriter { public: - BagsRewriter(HistogramStat* statistics = nullptr); + BagsRewriter(Rewriter* r, HistogramStat* statistics = nullptr); /** * postRewrite nodes with kinds: BAG_MAKE, BAG_COUNT, BAG_UNION_MAX, @@ -246,6 +246,7 @@ class BagsRewriter : public TheoryRewriter * where f: T1 -> T2 -> T2 */ BagsRewriteResponse postRewriteFold(const TNode& n) const; + BagsRewriteResponse postRewritePartition(const TNode& n) const; /** * rewrites for n include: * - (bag.product A (as bag.empty T2)) = (as bag.empty T) @@ -262,6 +263,11 @@ class BagsRewriter : public TheoryRewriter NodeManager* d_nm; Node d_zero; Node d_one; + /** + * Pointer to the rewriter. NOTE this is a cyclic dependency, and should + * be removed. + */ + Rewriter* d_rewriter; /** Reference to the rewriter statistics. */ HistogramStat* d_statistics; }; /* class TheoryBagsRewriter */ diff --git a/src/theory/bags/bags_utils.cpp b/src/theory/bags/bags_utils.cpp index e71923248..fd5a98c25 100644 --- a/src/theory/bags/bags_utils.cpp +++ b/src/theory/bags/bags_utils.cpp @@ -20,8 +20,10 @@ #include "smt/logic_exception.h" #include "table_project_op.h" #include "theory/datatypes/tuple_utils.h" +#include "theory/rewriter.h" #include "theory/sets/normal_form.h" #include "theory/type_enumerator.h" +#include "theory/uf/equality_engine.h" #include "util/rational.h" using namespace cvc5::internal::kind; @@ -785,6 +787,98 @@ Node BagsUtils::evaluateBagFold(TNode n) return ret; } +Node BagsUtils::evaluateBagPartition(Rewriter* rewriter, TNode n) +{ + Assert(n.getKind() == BAG_PARTITION); + NodeManager* nm = NodeManager::currentNM(); + + // Examples + // -------- + // minimum string + // - (bag.partition + // ((lambda ((x Int) (y Int)) (= 0 (+ x y))) + // (bag.union_disjoint + // (bag 1 20) (bag (- 1) 50) + // (bag 2 30) (bag (- 2) 60) + // (bag 3 40) (bag (- 3) 70) + // (bag 4 100))) + // = (bag.union_disjoint + // (bag (bag 4 100) 1) + // (bag (bag.union_disjoint (bag 1 20) (bag (- 1) 50)) 1) + // (bag (bag.union_disjoint (bag 2 30) (bag (- 2) 60)) 1) + // (bag (bag.union_disjoint (bag 3 40) (bag (- 3) 70)) 1))) + + Node r = n[0]; // equivalence relation + Node A = n[1]; // bag + TypeNode bagType = A.getType(); + TypeNode partitionType = n.getType(); + std::map elements = BagsUtils::getBagElements(A); + Trace("bags-partition") << "elements: " << elements << std::endl; + // a simple map from elements to equivalent classes with this invariant: + // each key element must appear exactly once in one of the values. + std::map> sets; + std::set emptyClass; + for (const auto& pair : elements) + { + // initially each singleton element is an equivalence class + sets[pair.first] = {pair.first}; + } + for (std::map::iterator i = elements.begin(); + i != elements.end(); + ++i) + { + if (sets[i->first].empty()) + { + // skip this element since its equivalent class has already been processed + continue; + } + std::map::iterator j = i; + ++j; + while (j != elements.end()) + { + Node sameClass = nm->mkNode(APPLY_UF, r, i->first, j->first); + sameClass = rewriter->rewrite(sameClass); + if (!sameClass.isConst()) + { + // we can not pursue further, so we return n itself + return n; + } + if (sameClass.getConst()) + { + // add element j to the equivalent class + sets[i->first].insert(j->first); + // mark the equivalent class of j as processed + sets[j->first] = emptyClass; + } + ++j; + } + } + + // construct the partition parts + std::map parts; + for (std::pair> pair : sets) + { + const std::set& eqc = pair.second; + if (eqc.empty()) + { + continue; + } + std::vector bags; + for (const Node& node : eqc) + { + Node bag = nm->mkBag( + bagType.getBagElementType(), node, nm->mkConstInt(elements[node])); + bags.push_back(bag); + } + Node part = computeDisjointUnion(bagType, bags); + // each part in the partitions has multiplicity one + parts[part] = Rational(1); + } + Node ret = constructConstantBagFromElements(partitionType, parts); + Trace("bags-partition") << "ret: " << ret << std::endl; + return ret; +} + Node BagsUtils::constructProductTuple(TNode n, TNode e1, TNode e2) { Assert(n.getKind() == TABLE_PRODUCT); diff --git a/src/theory/bags/bags_utils.h b/src/theory/bags/bags_utils.h index 42e7b0caf..21de8e959 100644 --- a/src/theory/bags/bags_utils.h +++ b/src/theory/bags/bags_utils.h @@ -20,6 +20,8 @@ #ifndef CVC5__THEORY__BAGS__UTILS_H #define CVC5__THEORY__BAGS__UTILS_H +#include "theory/theory_rewriter.h" + namespace cvc5::internal { namespace theory { namespace bags { @@ -88,6 +90,12 @@ class BagsUtils */ static Node evaluateBagFold(TNode n); + /** + * @param n has the form (bag.partition r A) where A is a constant bag + * @return a partition of A based on the equivalence relation r + */ + static Node evaluateBagPartition(Rewriter *rewriter, TNode n); + /** * @param n has the form (bag.filter p A) where A is a constant bag * @return A filtered with predicate p diff --git a/src/theory/bags/kinds b/src/theory/bags/kinds index 49bca83fb..1e875e998 100644 --- a/src/theory/bags/kinds +++ b/src/theory/bags/kinds @@ -89,6 +89,10 @@ operator BAG_FILTER 2 "bag filter operator" # B: a bag of type (Bag T1) operator BAG_FOLD 3 "bag fold operator" +# bag.partition operator partitions a bag into a bag of bags based on an equivalence relation such that +# each element occurs exactly in one these bags. +operator BAG_PARTITION 2 "bag partition operator" + typerule BAG_UNION_MAX ::cvc5::internal::theory::bags::BinaryOperatorTypeRule typerule BAG_UNION_DISJOINT ::cvc5::internal::theory::bags::BinaryOperatorTypeRule typerule BAG_INTER_MIN ::cvc5::internal::theory::bags::BinaryOperatorTypeRule @@ -109,6 +113,7 @@ typerule BAG_TO_SET ::cvc5::internal::theory::bags::ToSetTypeRule typerule BAG_MAP ::cvc5::internal::theory::bags::BagMapTypeRule typerule BAG_FILTER ::cvc5::internal::theory::bags::BagFilterTypeRule typerule BAG_FOLD ::cvc5::internal::theory::bags::BagFoldTypeRule +typerule BAG_PARTITION ::cvc5::internal::theory::bags::BagPartitionTypeRule construle BAG_UNION_DISJOINT ::cvc5::internal::theory::bags::BinaryOperatorTypeRule construle BAG_MAKE ::cvc5::internal::theory::bags::BagMakeTypeRule diff --git a/src/theory/bags/rewrites.cpp b/src/theory/bags/rewrites.cpp index 287e879a7..0c634351a 100644 --- a/src/theory/bags/rewrites.cpp +++ b/src/theory/bags/rewrites.cpp @@ -56,6 +56,7 @@ const char* toString(Rewrite r) case Rewrite::MAP_BAG_MAKE: return "MAP_BAG_MAKE"; case Rewrite::MAP_UNION_DISJOINT: return "MAP_UNION_DISJOINT"; case Rewrite::MEMBER: return "MEMBER"; + case Rewrite::PARTITION_CONST: return "PARTITION_CONST"; case Rewrite::PRODUCT_EMPTY: return "PRODUCT_EMPTY"; case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION"; case Rewrite::REMOVE_MIN: return "REMOVE_MIN"; diff --git a/src/theory/bags/rewrites.h b/src/theory/bags/rewrites.h index 467de36db..461ea8703 100644 --- a/src/theory/bags/rewrites.h +++ b/src/theory/bags/rewrites.h @@ -60,6 +60,7 @@ enum class Rewrite : uint32_t MAP_BAG_MAKE, MAP_UNION_DISJOINT, MEMBER, + PARTITION_CONST, PRODUCT_EMPTY, REMOVE_FROM_UNION, REMOVE_MIN, diff --git a/src/theory/bags/theory_bags.cpp b/src/theory/bags/theory_bags.cpp index 92ea5ecca..adcf3d468 100644 --- a/src/theory/bags/theory_bags.cpp +++ b/src/theory/bags/theory_bags.cpp @@ -38,7 +38,7 @@ TheoryBags::TheoryBags(Env& env, OutputChannel& out, Valuation valuation) d_ig(&d_state, &d_im), d_notify(*this, d_im), d_statistics(), - d_rewriter(&d_statistics.d_rewrites), + d_rewriter(env.getRewriter(), &d_statistics.d_rewrites), d_termReg(env, d_state, d_im), d_solver(env, d_state, d_im, d_termReg), d_cardSolver(env, d_state, d_im), @@ -80,6 +80,7 @@ void TheoryBags::finishInit() d_equalityEngine->addFunctionKind(BAG_CARD); d_equalityEngine->addFunctionKind(BAG_FROM_SET); d_equalityEngine->addFunctionKind(BAG_TO_SET); + d_equalityEngine->addFunctionKind(BAG_PARTITION); d_equalityEngine->addFunctionKind(TABLE_PRODUCT); d_equalityEngine->addFunctionKind(TABLE_PROJECT); } @@ -455,6 +456,7 @@ void TheoryBags::preRegisterTerm(TNode n) case BAG_FROM_SET: case BAG_TO_SET: case BAG_IS_SINGLETON: + case BAG_PARTITION: case TABLE_PROJECT: { std::stringstream ss; diff --git a/src/theory/bags/theory_bags_type_rules.cpp b/src/theory/bags/theory_bags_type_rules.cpp index ef2a5a350..e786a6afc 100644 --- a/src/theory/bags/theory_bags_type_rules.cpp +++ b/src/theory/bags/theory_bags_type_rules.cpp @@ -454,6 +454,51 @@ TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager, return retType; } +TypeNode BagPartitionTypeRule::computeType(NodeManager* nodeManager, + TNode n, + bool check) +{ + Assert(n.getKind() == kind::BAG_PARTITION); + TypeNode functionType = n[0].getType(check); + TypeNode bagType = n[1].getType(check); + NodeManager* nm = NodeManager::currentNM(); + if (check) + { + if (!bagType.isBag()) + { + throw TypeCheckingExceptionPrivate( + n, + "bag.partition operator expects a bag in the second argument, " + "a non-bag is found"); + } + + TypeNode elementType = bagType.getBagElementType(); + + if (!(functionType.isFunction())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " " << elementType << " Bool) as a first argument. " + << "Found a term of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + std::vector argTypes = functionType.getArgTypes(); + TypeNode rangeType = functionType.getRangeType(); + if (!(argTypes.size() == 2 && elementType.isSubtypeOf(argTypes[0]) + && elementType.isSubtypeOf(argTypes[1]) + && rangeType == nm->booleanType())) + { + std::stringstream ss; + ss << "Operator " << n.getKind() << " expects a function of type (-> " + << elementType << " " << elementType << " Bool) as a first argument. " + << "Found a term of type '" << functionType << "'."; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + } + TypeNode retType = nm->mkBagType(bagType); + return retType; +} + TypeNode TableProductTypeRule::computeType(NodeManager* nodeManager, TNode n, bool check) diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index 54329b405..04e5bfd04 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -152,13 +152,22 @@ struct BagFilterTypeRule /** * Type rule for (bag.fold f t A) to make sure f is a binary operation of type - * (-> T1 T2 T2), t of type T2, and B is a bag of type (Bag T1) + * (-> T1 T2 T2), t of type T2, and A is a bag of type (Bag T1) */ struct BagFoldTypeRule { static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; /* struct BagFoldTypeRule */ +/** + * Type rule for (bag.partition r A) to make sure r is a binary operation of type + * (-> T1 T1 Bool), and A is a bag of type (Bag T1) + */ +struct BagPartitionTypeRule +{ + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); +}; /* struct BagFoldTypeRule */ + /** * Type rule for (table.product A B) to make sure A,B are bags of tuples, * and get the type of the cross product diff --git a/test/regress/cli/CMakeLists.txt b/test/regress/cli/CMakeLists.txt index 59891768a..ca74fc74f 100644 --- a/test/regress/cli/CMakeLists.txt +++ b/test/regress/cli/CMakeLists.txt @@ -1753,6 +1753,7 @@ set(regress_1_tests regress1/bug694-Unapply1.scala-0.smt2 regress1/bug800.smt2 regress1/bags/bag_member.smt2 + regress1/bags/bag_partition1.smt2 regress1/bags/bags-of-bags-subtypes.smt2 regress1/bags/card1.smt2 regress1/bags/card2.smt2 diff --git a/test/regress/cli/regress1/bags/bag_partition1.smt2 b/test/regress/cli/regress1/bags/bag_partition1.smt2 new file mode 100644 index 000000000..84c73232f --- /dev/null +++ b/test/regress/cli/regress1/bags/bag_partition1.smt2 @@ -0,0 +1,34 @@ +(set-logic HO_ALL) + +(set-info :status sat) + +(set-option :fmf-bound true) +(set-option :uf-lazy-ll true) + +; equivalence relation : inverse +(define-fun r ((x Int) (y Int)) Bool (= 0 (+ x y))) + +(declare-fun A () (Bag Int)) +(declare-fun B () (Bag (Bag Int))) +(declare-fun C () (Bag (Bag Int))) + +(assert + (= A + (bag.union_disjoint + (bag 1 20) (bag (- 1) 50) + (bag 2 30) (bag (- 2) 60) + (bag 3 40) (bag (- 3) 70) + (bag 4 100)))) + +;(define-fun B () (Bag (Bag Int)) +; (bag.union_disjoint (bag (bag 4 100) 1) +; (bag (bag.union_disjoint (bag 1 20) (bag (- 1) 50)) 1) +; (bag (bag.union_disjoint (bag 2 30) (bag (- 2) 60)) 1) +; (bag (bag.union_disjoint (bag 3 40) (bag (- 3) 70)) 1))) + +(assert (= B (bag.partition r A))) +; (define-fun C () (Bag (Bag Int)) (as bag.empty (Bag (Bag Int)))) +(assert (= C (bag.partition r (as bag.empty (Bag Int))))) + +(check-sat) + -- 2.30.2