KIND_ENUM(BAG_MAP, internal::Kind::BAG_MAP),
KIND_ENUM(BAG_FILTER, internal::Kind::BAG_FILTER),
KIND_ENUM(BAG_FOLD, internal::Kind::BAG_FOLD),
+ KIND_ENUM(BAG_PARTITION, internal::Kind::BAG_PARTITION),
KIND_ENUM(TABLE_PRODUCT, internal::Kind::TABLE_PRODUCT),
KIND_ENUM(TABLE_PROJECT, internal::Kind::TABLE_PROJECT),
/* Strings ---------------------------------------------------------- */
{internal::Kind::BAG_MAP, BAG_MAP},
{internal::Kind::BAG_FILTER, BAG_FILTER},
{internal::Kind::BAG_FOLD, BAG_FOLD},
+ {internal::Kind::BAG_PARTITION, BAG_PARTITION},
{internal::Kind::TABLE_PRODUCT, TABLE_PRODUCT},
{internal::Kind::TABLE_PROJECT, TABLE_PROJECT},
{internal::Kind::TABLE_PROJECT_OP, TABLE_PROJECT},
* \endrst
*/
BAG_FOLD,
+ /**
+ * Bag partition.
+ *
+ * \rst
+ * This operator partitions of a bag of elements into disjoint bags.
+ * (bag.partition :math:`r \; B`) partitions the elements of bag :math:`B`
+ * of type :math:`(Bag \; E)` based on the equivalence relations :math:`r` of
+ * type :math:`(\rightarrow \; E \; E \; Bool)`.
+ * It returns a bag of bags of type :math:`(Bag \; (Bag \; E))`.
+ *
+ * - Arity: ``2``
+ *
+ * - ``1:`` Term of function Sort :math:`(\rightarrow \; E \; E \; Bool)`
+ * - ``2:`` Term of bag Sort (Bag :math:`E`)
+ * \endrst
+ *
+ * - Create Term of this Kind with:
+ *
+ * - Solver::mkTerm(Kind, const std::vector<Term>&) const
+ * - Solver::mkTerm(const Op&, const std::vector<Term>&) const
+ *
+ * \rst
+ * .. warning:: This kind is experimental and may be changed or removed in
+ * future versions.
+ * \endrst
+ */
+ BAG_PARTITION,
/**
* Table cross product.
*
addOperator(cvc5::BAG_MAP, "bag.map");
addOperator(cvc5::BAG_FILTER, "bag.filter");
addOperator(cvc5::BAG_FOLD, "bag.fold");
+ addOperator(cvc5::BAG_PARTITION, "bag.partition");
addOperator(cvc5::TABLE_PRODUCT, "table.product");
}
if (d_logic.isTheoryEnabled(internal::theory::THEORY_STRINGS))
case kind::BAG_MAP: return "bag.map";
case kind::BAG_FILTER: return "bag.filter";
case kind::BAG_FOLD: return "bag.fold";
+ case kind::BAG_PARTITION: return "bag.partition";
case kind::TABLE_PRODUCT: return "table.product";
case kind::TABLE_PROJECT: return "table.project";
#include "expr/emptybag.h"
#include "theory/bags/bags_utils.h"
+#include "theory/rewriter.h"
#include "util/rational.h"
#include "util/statistics_registry.h"
{
}
-BagsRewriter::BagsRewriter(HistogramStat<Rewrite>* statistics)
- : d_statistics(statistics)
+BagsRewriter::BagsRewriter(Rewriter* r, HistogramStat<Rewrite>* statistics)
+ : d_rewriter(r), d_statistics(statistics)
{
d_nm = NodeManager::currentNM();
d_zero = d_nm->mkConstInt(Rational(0));
case BAG_MAP: response = postRewriteMap(n); break;
case BAG_FILTER: response = postRewriteFilter(n); break;
case BAG_FOLD: response = postRewriteFold(n); break;
+ case BAG_PARTITION: response = postRewritePartition(n); break;
case TABLE_PRODUCT: response = postRewriteProduct(n); break;
default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
}
return BagsRewriteResponse(n, Rewrite::NONE);
}
+BagsRewriteResponse BagsRewriter::postRewritePartition(const TNode& n) const
+{
+ Assert(n.getKind() == kind::BAG_PARTITION);
+ if (n[1].isConst())
+ {
+ Node ret = BagsUtils::evaluateBagPartition(d_rewriter, n);
+ if (ret != n)
+ {
+ return BagsRewriteResponse(ret, Rewrite::PARTITION_CONST);
+ }
+ }
+
+ return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
BagsRewriteResponse BagsRewriter::postRewriteProduct(const TNode& n) const
{
Assert(n.getKind() == TABLE_PRODUCT);
class BagsRewriter : public TheoryRewriter
{
public:
- BagsRewriter(HistogramStat<Rewrite>* statistics = nullptr);
+ BagsRewriter(Rewriter* r, HistogramStat<Rewrite>* statistics = nullptr);
/**
* postRewrite nodes with kinds: BAG_MAKE, BAG_COUNT, BAG_UNION_MAX,
* where f: T1 -> T2 -> T2
*/
BagsRewriteResponse postRewriteFold(const TNode& n) const;
+ BagsRewriteResponse postRewritePartition(const TNode& n) const;
/**
* rewrites for n include:
* - (bag.product A (as bag.empty T2)) = (as bag.empty T)
NodeManager* d_nm;
Node d_zero;
Node d_one;
+ /**
+ * Pointer to the rewriter. NOTE this is a cyclic dependency, and should
+ * be removed.
+ */
+ Rewriter* d_rewriter;
/** Reference to the rewriter statistics. */
HistogramStat<Rewrite>* d_statistics;
}; /* class TheoryBagsRewriter */
#include "smt/logic_exception.h"
#include "table_project_op.h"
#include "theory/datatypes/tuple_utils.h"
+#include "theory/rewriter.h"
#include "theory/sets/normal_form.h"
#include "theory/type_enumerator.h"
+#include "theory/uf/equality_engine.h"
#include "util/rational.h"
using namespace cvc5::internal::kind;
return ret;
}
+Node BagsUtils::evaluateBagPartition(Rewriter* rewriter, TNode n)
+{
+ Assert(n.getKind() == BAG_PARTITION);
+ NodeManager* nm = NodeManager::currentNM();
+
+ // Examples
+ // --------
+ // minimum string
+ // - (bag.partition
+ // ((lambda ((x Int) (y Int)) (= 0 (+ x y)))
+ // (bag.union_disjoint
+ // (bag 1 20) (bag (- 1) 50)
+ // (bag 2 30) (bag (- 2) 60)
+ // (bag 3 40) (bag (- 3) 70)
+ // (bag 4 100)))
+ // = (bag.union_disjoint
+ // (bag (bag 4 100) 1)
+ // (bag (bag.union_disjoint (bag 1 20) (bag (- 1) 50)) 1)
+ // (bag (bag.union_disjoint (bag 2 30) (bag (- 2) 60)) 1)
+ // (bag (bag.union_disjoint (bag 3 40) (bag (- 3) 70)) 1)))
+
+ Node r = n[0]; // equivalence relation
+ Node A = n[1]; // bag
+ TypeNode bagType = A.getType();
+ TypeNode partitionType = n.getType();
+ std::map<Node, Rational> elements = BagsUtils::getBagElements(A);
+ Trace("bags-partition") << "elements: " << elements << std::endl;
+ // a simple map from elements to equivalent classes with this invariant:
+ // each key element must appear exactly once in one of the values.
+ std::map<Node, std::set<Node>> sets;
+ std::set<Node> emptyClass;
+ for (const auto& pair : elements)
+ {
+ // initially each singleton element is an equivalence class
+ sets[pair.first] = {pair.first};
+ }
+ for (std::map<Node, Rational>::iterator i = elements.begin();
+ i != elements.end();
+ ++i)
+ {
+ if (sets[i->first].empty())
+ {
+ // skip this element since its equivalent class has already been processed
+ continue;
+ }
+ std::map<Node, Rational>::iterator j = i;
+ ++j;
+ while (j != elements.end())
+ {
+ Node sameClass = nm->mkNode(APPLY_UF, r, i->first, j->first);
+ sameClass = rewriter->rewrite(sameClass);
+ if (!sameClass.isConst())
+ {
+ // we can not pursue further, so we return n itself
+ return n;
+ }
+ if (sameClass.getConst<bool>())
+ {
+ // add element j to the equivalent class
+ sets[i->first].insert(j->first);
+ // mark the equivalent class of j as processed
+ sets[j->first] = emptyClass;
+ }
+ ++j;
+ }
+ }
+
+ // construct the partition parts
+ std::map<Node, Rational> parts;
+ for (std::pair<Node, std::set<Node>> pair : sets)
+ {
+ const std::set<Node>& eqc = pair.second;
+ if (eqc.empty())
+ {
+ continue;
+ }
+ std::vector<Node> bags;
+ for (const Node& node : eqc)
+ {
+ Node bag = nm->mkBag(
+ bagType.getBagElementType(), node, nm->mkConstInt(elements[node]));
+ bags.push_back(bag);
+ }
+ Node part = computeDisjointUnion(bagType, bags);
+ // each part in the partitions has multiplicity one
+ parts[part] = Rational(1);
+ }
+ Node ret = constructConstantBagFromElements(partitionType, parts);
+ Trace("bags-partition") << "ret: " << ret << std::endl;
+ return ret;
+}
+
Node BagsUtils::constructProductTuple(TNode n, TNode e1, TNode e2)
{
Assert(n.getKind() == TABLE_PRODUCT);
#ifndef CVC5__THEORY__BAGS__UTILS_H
#define CVC5__THEORY__BAGS__UTILS_H
+#include "theory/theory_rewriter.h"
+
namespace cvc5::internal {
namespace theory {
namespace bags {
*/
static Node evaluateBagFold(TNode n);
+ /**
+ * @param n has the form (bag.partition r A) where A is a constant bag
+ * @return a partition of A based on the equivalence relation r
+ */
+ static Node evaluateBagPartition(Rewriter *rewriter, TNode n);
+
/**
* @param n has the form (bag.filter p A) where A is a constant bag
* @return A filtered with predicate p
# B: a bag of type (Bag T1)
operator BAG_FOLD 3 "bag fold operator"
+# bag.partition operator partitions a bag into a bag of bags based on an equivalence relation such that
+# each element occurs exactly in one these bags.
+operator BAG_PARTITION 2 "bag partition operator"
+
typerule BAG_UNION_MAX ::cvc5::internal::theory::bags::BinaryOperatorTypeRule
typerule BAG_UNION_DISJOINT ::cvc5::internal::theory::bags::BinaryOperatorTypeRule
typerule BAG_INTER_MIN ::cvc5::internal::theory::bags::BinaryOperatorTypeRule
typerule BAG_MAP ::cvc5::internal::theory::bags::BagMapTypeRule
typerule BAG_FILTER ::cvc5::internal::theory::bags::BagFilterTypeRule
typerule BAG_FOLD ::cvc5::internal::theory::bags::BagFoldTypeRule
+typerule BAG_PARTITION ::cvc5::internal::theory::bags::BagPartitionTypeRule
construle BAG_UNION_DISJOINT ::cvc5::internal::theory::bags::BinaryOperatorTypeRule
construle BAG_MAKE ::cvc5::internal::theory::bags::BagMakeTypeRule
case Rewrite::MAP_BAG_MAKE: return "MAP_BAG_MAKE";
case Rewrite::MAP_UNION_DISJOINT: return "MAP_UNION_DISJOINT";
case Rewrite::MEMBER: return "MEMBER";
+ case Rewrite::PARTITION_CONST: return "PARTITION_CONST";
case Rewrite::PRODUCT_EMPTY: return "PRODUCT_EMPTY";
case Rewrite::REMOVE_FROM_UNION: return "REMOVE_FROM_UNION";
case Rewrite::REMOVE_MIN: return "REMOVE_MIN";
MAP_BAG_MAKE,
MAP_UNION_DISJOINT,
MEMBER,
+ PARTITION_CONST,
PRODUCT_EMPTY,
REMOVE_FROM_UNION,
REMOVE_MIN,
d_ig(&d_state, &d_im),
d_notify(*this, d_im),
d_statistics(),
- d_rewriter(&d_statistics.d_rewrites),
+ d_rewriter(env.getRewriter(), &d_statistics.d_rewrites),
d_termReg(env, d_state, d_im),
d_solver(env, d_state, d_im, d_termReg),
d_cardSolver(env, d_state, d_im),
d_equalityEngine->addFunctionKind(BAG_CARD);
d_equalityEngine->addFunctionKind(BAG_FROM_SET);
d_equalityEngine->addFunctionKind(BAG_TO_SET);
+ d_equalityEngine->addFunctionKind(BAG_PARTITION);
d_equalityEngine->addFunctionKind(TABLE_PRODUCT);
d_equalityEngine->addFunctionKind(TABLE_PROJECT);
}
case BAG_FROM_SET:
case BAG_TO_SET:
case BAG_IS_SINGLETON:
+ case BAG_PARTITION:
case TABLE_PROJECT:
{
std::stringstream ss;
return retType;
}
+TypeNode BagPartitionTypeRule::computeType(NodeManager* nodeManager,
+ TNode n,
+ bool check)
+{
+ Assert(n.getKind() == kind::BAG_PARTITION);
+ TypeNode functionType = n[0].getType(check);
+ TypeNode bagType = n[1].getType(check);
+ NodeManager* nm = NodeManager::currentNM();
+ if (check)
+ {
+ if (!bagType.isBag())
+ {
+ throw TypeCheckingExceptionPrivate(
+ n,
+ "bag.partition operator expects a bag in the second argument, "
+ "a non-bag is found");
+ }
+
+ TypeNode elementType = bagType.getBagElementType();
+
+ if (!(functionType.isFunction()))
+ {
+ std::stringstream ss;
+ ss << "Operator " << n.getKind() << " expects a function of type (-> "
+ << elementType << " " << elementType << " Bool) as a first argument. "
+ << "Found a term of type '" << functionType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ std::vector<TypeNode> argTypes = functionType.getArgTypes();
+ TypeNode rangeType = functionType.getRangeType();
+ if (!(argTypes.size() == 2 && elementType.isSubtypeOf(argTypes[0])
+ && elementType.isSubtypeOf(argTypes[1])
+ && rangeType == nm->booleanType()))
+ {
+ std::stringstream ss;
+ ss << "Operator " << n.getKind() << " expects a function of type (-> "
+ << elementType << " " << elementType << " Bool) as a first argument. "
+ << "Found a term of type '" << functionType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ }
+ TypeNode retType = nm->mkBagType(bagType);
+ return retType;
+}
+
TypeNode TableProductTypeRule::computeType(NodeManager* nodeManager,
TNode n,
bool check)
/**
* Type rule for (bag.fold f t A) to make sure f is a binary operation of type
- * (-> T1 T2 T2), t of type T2, and B is a bag of type (Bag T1)
+ * (-> T1 T2 T2), t of type T2, and A is a bag of type (Bag T1)
*/
struct BagFoldTypeRule
{
static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
}; /* struct BagFoldTypeRule */
+/**
+ * Type rule for (bag.partition r A) to make sure r is a binary operation of type
+ * (-> T1 T1 Bool), and A is a bag of type (Bag T1)
+ */
+struct BagPartitionTypeRule
+{
+ static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct BagFoldTypeRule */
+
/**
* Type rule for (table.product A B) to make sure A,B are bags of tuples,
* and get the type of the cross product
regress1/bug694-Unapply1.scala-0.smt2
regress1/bug800.smt2
regress1/bags/bag_member.smt2
+ regress1/bags/bag_partition1.smt2
regress1/bags/bags-of-bags-subtypes.smt2
regress1/bags/card1.smt2
regress1/bags/card2.smt2
--- /dev/null
+(set-logic HO_ALL)
+
+(set-info :status sat)
+
+(set-option :fmf-bound true)
+(set-option :uf-lazy-ll true)
+
+; equivalence relation : inverse
+(define-fun r ((x Int) (y Int)) Bool (= 0 (+ x y)))
+
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag (Bag Int)))
+(declare-fun C () (Bag (Bag Int)))
+
+(assert
+ (= A
+ (bag.union_disjoint
+ (bag 1 20) (bag (- 1) 50)
+ (bag 2 30) (bag (- 2) 60)
+ (bag 3 40) (bag (- 3) 70)
+ (bag 4 100))))
+
+;(define-fun B () (Bag (Bag Int))
+; (bag.union_disjoint (bag (bag 4 100) 1)
+; (bag (bag.union_disjoint (bag 1 20) (bag (- 1) 50)) 1)
+; (bag (bag.union_disjoint (bag 2 30) (bag (- 2) 60)) 1)
+; (bag (bag.union_disjoint (bag 3 40) (bag (- 3) 70)) 1)))
+
+(assert (= B (bag.partition r A)))
+; (define-fun C () (Bag (Bag Int)) (as bag.empty (Bag (Bag Int))))
+(assert (= C (bag.partition r (as bag.empty (Bag Int)))))
+
+(check-sat)
+