theory/bags/bag_reduction.h
theory/bags/bags_statistics.cpp
theory/bags/bags_statistics.h
+ theory/bags/bags_utils.cpp
+ theory/bags/bags_utils.h
theory/bags/card_solver.cpp
theory/bags/card_solver.h
theory/bags/infer_info.cpp
theory/bags/inference_generator.h
theory/bags/inference_manager.cpp
theory/bags/inference_manager.h
- theory/bags/normal_form.cpp
- theory/bags/normal_form.h
theory/bags/rewrites.cpp
theory/bags/rewrites.h
theory/bags/solver_state.cpp
{BAG_FROM_SET, cvc5::Kind::BAG_FROM_SET},
{BAG_TO_SET, cvc5::Kind::BAG_TO_SET},
{BAG_MAP, cvc5::Kind::BAG_MAP},
+ {BAG_FILTER, cvc5::Kind::BAG_FILTER},
{BAG_FOLD, cvc5::Kind::BAG_FOLD},
/* Strings ------------------------------------------------------------- */
{STRING_CONCAT, cvc5::Kind::STRING_CONCAT},
{cvc5::Kind::BAG_FROM_SET, BAG_FROM_SET},
{cvc5::Kind::BAG_TO_SET, BAG_TO_SET},
{cvc5::Kind::BAG_MAP, BAG_MAP},
+ {cvc5::Kind::BAG_FILTER, BAG_FILTER},
{cvc5::Kind::BAG_FOLD, BAG_FOLD},
/* Strings --------------------------------------------------------- */
{cvc5::Kind::STRING_CONCAT, STRING_CONCAT},
* - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
*/
BAG_MAP,
+ /**
+ * bag.filter operator filters the elements of a bag.
+ * (bag.filter p B) takes a predicate p of type (-> T Bool) as a first
+ * argument, and a bag B of type (Bag T) as a second argument, and returns a
+ * subbag of type (Bag T) that includes all elements of B that satisfy p
+ * with the same multiplicity.
+ *
+ * Parameters:
+ * - 1: a function of type (-> T Bool)
+ * - 2: a bag of type (Bag T)
+ *
+ * Create with:
+ * - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2)
+ * const`
+ * - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
+ */
+ BAG_FILTER,
/**
* bag.fold operator combines elements of a bag into a single value.
* (bag.fold f t B) folds the elements of bag B starting with term t and using
addOperator(api::BAG_FROM_SET, "bag.from_set");
addOperator(api::BAG_TO_SET, "bag.to_set");
addOperator(api::BAG_MAP, "bag.map");
+ addOperator(api::BAG_FILTER, "bag.filter");
addOperator(api::BAG_FOLD, "bag.fold");
}
if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) {
case kind::BAG_FROM_SET: return "bag.from_set";
case kind::BAG_TO_SET: return "bag.to_set";
case kind::BAG_MAP: return "bag.map";
+ case kind::BAG_FILTER: return "bag.filter";
case kind::BAG_FOLD: return "bag.fold";
// fp theory
#include "theory/bags/bag_solver.h"
#include "expr/emptybag.h"
+#include "theory/bags/bags_utils.h"
#include "theory/bags/inference_generator.h"
#include "theory/bags/inference_manager.h"
-#include "theory/bags/normal_form.h"
#include "theory/bags/solver_state.h"
#include "theory/bags/term_registry.h"
#include "theory/uf/equality_engine_iterator.h"
case kind::BAG_DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
case kind::BAG_DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
case kind::BAG_DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
+ case kind::BAG_FILTER: checkFilter(n); break;
case kind::BAG_MAP: checkMap(n); break;
default: break;
}
}
}
+void BagSolver::checkFilter(Node n)
+{
+ Assert(n.getKind() == BAG_FILTER);
+
+ set<Node> elements;
+ const set<Node>& downwards = d_state.getElements(n);
+ const set<Node>& upwards = d_state.getElements(n[0]);
+ elements.insert(downwards.begin(), downwards.end());
+ elements.insert(upwards.begin(), upwards.end());
+
+ for (const Node& e : elements)
+ {
+ InferInfo i = d_ig.filterDownwards(n, d_state.getRepresentative(e));
+ d_im.lemmaTheoryInference(&i);
+ }
+ for (const Node& e : elements)
+ {
+ InferInfo i = d_ig.filterUpwards(n, d_state.getRepresentative(e));
+ d_im.lemmaTheoryInference(&i);
+ }
+}
+
} // namespace bags
} // namespace theory
} // namespace cvc5
void checkDisequalBagTerms();
/** apply inference rules for map operator */
void checkMap(Node n);
+ /** apply inference rules for filter operator */
+ void checkFilter(Node n);
/** The solver state object */
SolverState& d_state;
#include "theory/bags/bags_rewriter.h"
#include "expr/emptybag.h"
-#include "theory/bags/normal_form.h"
+#include "theory/bags/bags_utils.h"
#include "util/rational.h"
#include "util/statistics_registry.h"
{
response = rewriteChoose(n);
}
- else if (NormalForm::areChildrenConstants(n))
+ else if (BagsUtils::areChildrenConstants(n))
{
- Node value = NormalForm::evaluate(n);
+ Node value = BagsUtils::evaluate(n);
response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
}
else
case BAG_FROM_SET: response = rewriteFromSet(n); break;
case BAG_TO_SET: response = rewriteToSet(n); break;
case BAG_MAP: response = postRewriteMap(n); break;
+ case BAG_FILTER: response = postRewriteFilter(n); break;
case BAG_FOLD: response = postRewriteFold(n); break;
default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
}
{
// (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2))
// (bag.map f (bag "a" 3)) = (bag (f "a") 3)
- std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
+ std::map<Node, Rational> elements = BagsUtils::getBagElements(n[1]);
std::map<Node, Rational> mappedElements;
std::map<Node, Rational>::iterator it = elements.begin();
while (it != elements.end())
++it;
}
TypeNode t = d_nm->mkBagType(n[0].getType().getRangeType());
- Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements);
+ Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements);
return BagsRewriteResponse(ret, Rewrite::MAP_CONST);
}
Kind k = n[1].getKind();
}
}
+BagsRewriteResponse BagsRewriter::postRewriteFilter(const TNode& n) const
+{
+ Assert(n.getKind() == kind::BAG_FILTER);
+ Node P = n[0];
+ Node A = n[1];
+ TypeNode t = A.getType();
+ if (A.isConst())
+ {
+ // (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
+ // (bag.filter p (bag "a" 3) ((bag "b" 2))) =
+ // (bag.union_disjoint
+ // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T)))
+ // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T)))
+
+ Node ret = BagsUtils::evaluateBagFilter(n);
+ return BagsRewriteResponse(ret, Rewrite::FILTER_CONST);
+ }
+ Kind k = A.getKind();
+ switch (k)
+ {
+ case BAG_MAKE:
+ {
+ // (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T)))
+ Node empty = d_nm->mkConst(EmptyBag(t));
+ Node pOfe = d_nm->mkNode(APPLY_UF, P, A[0]);
+ Node ret = d_nm->mkNode(ITE, pOfe, A, empty);
+ return BagsRewriteResponse(ret, Rewrite::FILTER_BAG_MAKE);
+ }
+
+ case BAG_UNION_DISJOINT:
+ {
+ // (bag.filter p (bag.union_disjoint A B)) =
+ // (bag.union_disjoint (bag.filter p A) (bag.filter p B))
+ Node a = d_nm->mkNode(BAG_FILTER, n[0], n[1][0]);
+ Node b = d_nm->mkNode(BAG_FILTER, n[0], n[1][1]);
+ Node ret = d_nm->mkNode(BAG_UNION_DISJOINT, a, b);
+ return BagsRewriteResponse(ret, Rewrite::FILTER_UNION_DISJOINT);
+ }
+
+ default: return BagsRewriteResponse(n, Rewrite::NONE);
+ }
+}
+
BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
{
Assert(n.getKind() == kind::BAG_FOLD);
Node bag = n[2];
if (bag.isConst())
{
- Node value = NormalForm::evaluateBagFold(n);
+ Node value = BagsUtils::evaluateBagFold(n);
return BagsRewriteResponse(value, Rewrite::FOLD_CONST);
}
Kind k = bag.getKind();
if (bag[1].isConst() && bag[1].getConst<Rational>() > Rational(0))
{
// (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, n > 0
- Node value = NormalForm::evaluateBagFold(n);
+ Node value = BagsUtils::evaluateBagFold(n);
return BagsRewriteResponse(value, Rewrite::FOLD_BAG);
}
break;
*/
BagsRewriteResponse postRewriteMap(const TNode& n) const;
+ /**
+ * rewrites for n include:
+ * - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
+ * - (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T)))
+ * - (bag.filter p (bag.union_disjoint A B)) =
+ * (bag.union_disjoint (bag.filter p A) (bag.filter p B))
+ * where p: T -> Bool
+ */
+ BagsRewriteResponse postRewriteFilter(const TNode& n) const;
+
/**
* rewrites for n include:
* - (bag.fold f t (as bag.empty (Bag T1))) = t
--- /dev/null
+/******************************************************************************
+ * Top contributors (to current version):
+ * Mudathir Mohamed, Aina Niemetz
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved. See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Utility functions for bags.
+ */
+#include "bags_utils.h"
+
+#include "expr/emptybag.h"
+#include "smt/logic_exception.h"
+#include "theory/sets/normal_form.h"
+#include "theory/type_enumerator.h"
+#include "util/rational.h"
+
+using namespace cvc5::kind;
+
+namespace cvc5 {
+namespace theory {
+namespace bags {
+
+Node BagsUtils::computeDisjointUnion(TypeNode bagType,
+ const std::vector<Node>& bags)
+{
+ NodeManager* nm = NodeManager::currentNM();
+ if (bags.empty())
+ {
+ return nm->mkConst(EmptyBag(bagType));
+ }
+ if (bags.size() == 1)
+ {
+ return bags[0];
+ }
+ Node unionDisjoint = bags[0];
+ for (size_t i = 1; i < bags.size(); i++)
+ {
+ if (bags[i].getKind() == BAG_EMPTY)
+ {
+ continue;
+ }
+ unionDisjoint = nm->mkNode(BAG_UNION_DISJOINT, unionDisjoint, bags[i]);
+ }
+ return unionDisjoint;
+}
+
+bool BagsUtils::isConstant(TNode n)
+{
+ if (n.getKind() == BAG_EMPTY)
+ {
+ // empty bags are already normalized
+ return true;
+ }
+ if (n.getKind() == BAG_MAKE)
+ {
+ // see the implementation in MkBagTypeRule::computeIsConst
+ return n.isConst();
+ }
+ if (n.getKind() == BAG_UNION_DISJOINT)
+ {
+ if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst()))
+ {
+ // the first child is not a constant
+ return false;
+ }
+ // store the previous element to check the ordering of elements
+ Node previousElement = n[0][0];
+ Node current = n[1];
+ while (current.getKind() == BAG_UNION_DISJOINT)
+ {
+ if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst()))
+ {
+ // the current element is not a constant
+ return false;
+ }
+ if (previousElement >= current[0][0])
+ {
+ // the ordering is violated
+ return false;
+ }
+ previousElement = current[0][0];
+ current = current[1];
+ }
+ // check last element
+ if (!(current.getKind() == kind::BAG_MAKE && current.isConst()))
+ {
+ // the last element is not a constant
+ return false;
+ }
+ if (previousElement >= current[0])
+ {
+ // the ordering is violated
+ return false;
+ }
+ return true;
+ }
+
+ // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be
+ // constants
+ return false;
+}
+
+bool BagsUtils::areChildrenConstants(TNode n)
+{
+ return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
+}
+
+Node BagsUtils::evaluate(TNode n)
+{
+ Assert(areChildrenConstants(n));
+ if (n.isConst())
+ {
+ // a constant node is already in a normal form
+ return n;
+ }
+ switch (n.getKind())
+ {
+ case BAG_MAKE: return evaluateMakeBag(n);
+ case BAG_COUNT: return evaluateBagCount(n);
+ case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
+ case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n);
+ case BAG_UNION_MAX: return evaluateUnionMax(n);
+ case BAG_INTER_MIN: return evaluateIntersectionMin(n);
+ case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
+ case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
+ case BAG_CARD: return evaluateCard(n);
+ case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
+ case BAG_FROM_SET: return evaluateFromSet(n);
+ case BAG_TO_SET: return evaluateToSet(n);
+ case BAG_MAP: return evaluateBagMap(n);
+ case BAG_FILTER: return evaluateBagFilter(n);
+ case BAG_FOLD: return evaluateBagFold(n);
+ default: break;
+ }
+ Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
+ << std::endl;
+}
+
+template <typename T1, typename T2, typename T3, typename T4, typename T5>
+Node BagsUtils::evaluateBinaryOperation(const TNode& n,
+ T1&& equal,
+ T2&& less,
+ T3&& greaterOrEqual,
+ T4&& remainderOfA,
+ T5&& remainderOfB)
+{
+ std::map<Node, Rational> elementsA = getBagElements(n[0]);
+ std::map<Node, Rational> elementsB = getBagElements(n[1]);
+ std::map<Node, Rational> elements;
+
+ std::map<Node, Rational>::const_iterator itA = elementsA.begin();
+ std::map<Node, Rational>::const_iterator itB = elementsB.begin();
+
+ Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
+ << n.getKind() << "] " << std::endl
+ << "elements A: " << elementsA << std::endl
+ << "elements B: " << elementsB << std::endl;
+
+ while (itA != elementsA.end() && itB != elementsB.end())
+ {
+ if (itA->first == itB->first)
+ {
+ equal(elements, itA, itB);
+ itA++;
+ itB++;
+ }
+ else if (itA->first < itB->first)
+ {
+ less(elements, itA, itB);
+ itA++;
+ }
+ else
+ {
+ greaterOrEqual(elements, itA, itB);
+ itB++;
+ }
+ }
+
+ // handle the remaining elements from A
+ remainderOfA(elements, elementsA, itA);
+ // handle the remaining elements from B
+ remainderOfB(elements, elementsB, itB);
+
+ Trace("bags-evaluate") << "elements: " << elements << std::endl;
+ Node bag = constructConstantBagFromElements(n.getType(), elements);
+ Trace("bags-evaluate") << "bag: " << bag << std::endl;
+ return bag;
+}
+
+std::map<Node, Rational> BagsUtils::getBagElements(TNode n)
+{
+ std::map<Node, Rational> elements;
+ if (n.getKind() == BAG_EMPTY)
+ {
+ return elements;
+ }
+ while (n.getKind() == kind::BAG_UNION_DISJOINT)
+ {
+ Assert(n[0].getKind() == kind::BAG_MAKE);
+ Node element = n[0][0];
+ Rational count = n[0][1].getConst<Rational>();
+ elements[element] = count;
+ n = n[1];
+ }
+ Assert(n.getKind() == kind::BAG_MAKE);
+ Node lastElement = n[0];
+ Rational lastCount = n[1].getConst<Rational>();
+ elements[lastElement] = lastCount;
+ return elements;
+}
+
+Node BagsUtils::constructConstantBagFromElements(
+ TypeNode t, const std::map<Node, Rational>& elements)
+{
+ Assert(t.isBag());
+ NodeManager* nm = NodeManager::currentNM();
+ if (elements.empty())
+ {
+ return nm->mkConst(EmptyBag(t));
+ }
+ TypeNode elementType = t.getBagElementType();
+ std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
+ Node bag = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second));
+ while (++it != elements.rend())
+ {
+ Node n = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second));
+ bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
+ }
+ return bag;
+}
+
+Node BagsUtils::constructBagFromElements(TypeNode t,
+ const std::map<Node, Node>& elements)
+{
+ Assert(t.isBag());
+ NodeManager* nm = NodeManager::currentNM();
+ if (elements.empty())
+ {
+ return nm->mkConst(EmptyBag(t));
+ }
+ TypeNode elementType = t.getBagElementType();
+ std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
+ Node bag = nm->mkBag(elementType, it->first, it->second);
+ while (++it != elements.rend())
+ {
+ Node n = nm->mkBag(elementType, it->first, it->second);
+ bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
+ }
+ return bag;
+}
+
+Node BagsUtils::evaluateMakeBag(TNode n)
+{
+ // the case where n is const should be handled earlier.
+ // here we handle the case where the multiplicity is zero or negative
+ Assert(n.getKind() == BAG_MAKE && !n.isConst()
+ && n[1].getConst<Rational>().sgn() < 1);
+ Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
+ return emptybag;
+}
+
+Node BagsUtils::evaluateBagCount(TNode n)
+{
+ Assert(n.getKind() == BAG_COUNT);
+ // Examples
+ // --------
+ // - (bag.count "x" (as bag.empty (Bag String))) = 0
+ // - (bag.count "x" (bag "y" 5)) = 0
+ // - (bag.count "x" (bag "x" 4)) = 4
+ // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4
+ // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0
+
+ std::map<Node, Rational> elements = getBagElements(n[1]);
+ std::map<Node, Rational>::iterator it = elements.find(n[0]);
+
+ NodeManager* nm = NodeManager::currentNM();
+ if (it != elements.end())
+ {
+ Node count = nm->mkConstInt(it->second);
+ return count;
+ }
+ return nm->mkConstInt(Rational(0));
+}
+
+Node BagsUtils::evaluateDuplicateRemoval(TNode n)
+{
+ Assert(n.getKind() == BAG_DUPLICATE_REMOVAL);
+
+ // Examples
+ // --------
+ // - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag
+ // String))
+ // - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1)
+ // - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
+ // (bag.disjoint_union (bag "x" 1) (bag "y" 1)
+
+ std::map<Node, Rational> oldElements = getBagElements(n[0]);
+ // copy elements from the old bag
+ std::map<Node, Rational> newElements(oldElements);
+ Rational one = Rational(1);
+ std::map<Node, Rational>::iterator it;
+ for (it = newElements.begin(); it != newElements.end(); it++)
+ {
+ it->second = one;
+ }
+ Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
+ return bag;
+}
+
+Node BagsUtils::evaluateUnionDisjoint(TNode n)
+{
+ Assert(n.getKind() == BAG_UNION_DISJOINT);
+ // Example
+ // -------
+ // input: (bag.union_disjoint A B)
+ // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+ // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
+ // output:
+ // (bag.union_disjoint A B)
+ // where A = (bag "x" 7)
+ // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
+
+ auto equal = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // compute the sum of the multiplicities
+ elements[itA->first] = itA->second + itB->second;
+ };
+
+ auto less = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // add the element to the result
+ elements[itA->first] = itA->second;
+ };
+
+ auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // add the element to the result
+ elements[itB->first] = itB->second;
+ };
+
+ auto remainderOfA = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsA,
+ std::map<Node, Rational>::const_iterator& itA) {
+ // append the remainder of A
+ while (itA != elementsA.end())
+ {
+ elements[itA->first] = itA->second;
+ itA++;
+ }
+ };
+
+ auto remainderOfB = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsB,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // append the remainder of B
+ while (itB != elementsB.end())
+ {
+ elements[itB->first] = itB->second;
+ itB++;
+ }
+ };
+
+ return evaluateBinaryOperation(
+ n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node BagsUtils::evaluateUnionMax(TNode n)
+{
+ Assert(n.getKind() == BAG_UNION_MAX);
+ // Example
+ // -------
+ // input: (bag.union_max A B)
+ // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+ // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
+ // output:
+ // (bag.union_disjoint A B)
+ // where A = (bag "x" 4)
+ // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
+
+ auto equal = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // compute the maximum multiplicity
+ elements[itA->first] = std::max(itA->second, itB->second);
+ };
+
+ auto less = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // add to the result
+ elements[itA->first] = itA->second;
+ };
+
+ auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // add to the result
+ elements[itB->first] = itB->second;
+ };
+
+ auto remainderOfA = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsA,
+ std::map<Node, Rational>::const_iterator& itA) {
+ // append the remainder of A
+ while (itA != elementsA.end())
+ {
+ elements[itA->first] = itA->second;
+ itA++;
+ }
+ };
+
+ auto remainderOfB = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsB,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // append the remainder of B
+ while (itB != elementsB.end())
+ {
+ elements[itB->first] = itB->second;
+ itB++;
+ }
+ };
+
+ return evaluateBinaryOperation(
+ n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node BagsUtils::evaluateIntersectionMin(TNode n)
+{
+ Assert(n.getKind() == BAG_INTER_MIN);
+ // Example
+ // -------
+ // input: (bag.inter_min A B)
+ // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+ // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
+ // output:
+ // (bag "x" 3)
+
+ auto equal = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // compute the minimum multiplicity
+ elements[itA->first] = std::min(itA->second, itB->second);
+ };
+
+ auto less = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // do nothing
+ };
+
+ auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // do nothing
+ };
+
+ auto remainderOfA = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsA,
+ std::map<Node, Rational>::const_iterator& itA) {
+ // do nothing
+ };
+
+ auto remainderOfB = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsB,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // do nothing
+ };
+
+ return evaluateBinaryOperation(
+ n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node BagsUtils::evaluateDifferenceSubtract(TNode n)
+{
+ Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT);
+ // Example
+ // -------
+ // input: (bag.difference_subtract A B)
+ // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+ // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
+ // output:
+ // (bag.union_disjoint (bag "x" 1) (bag "z" 2))
+
+ auto equal = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // subtract the multiplicities
+ elements[itA->first] = itA->second - itB->second;
+ };
+
+ auto less = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // itA->first is not in B, so we add it to the difference subtract
+ elements[itA->first] = itA->second;
+ };
+
+ auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // itB->first is not in A, so we just skip it
+ };
+
+ auto remainderOfA = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsA,
+ std::map<Node, Rational>::const_iterator& itA) {
+ // append the remainder of A
+ while (itA != elementsA.end())
+ {
+ elements[itA->first] = itA->second;
+ itA++;
+ }
+ };
+
+ auto remainderOfB = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsB,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // do nothing
+ };
+
+ return evaluateBinaryOperation(
+ n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node BagsUtils::evaluateDifferenceRemove(TNode n)
+{
+ Assert(n.getKind() == BAG_DIFFERENCE_REMOVE);
+ // Example
+ // -------
+ // input: (bag.difference_remove A B)
+ // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+ // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
+ // output:
+ // (bag "z" 2)
+
+ auto equal = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // skip the shared element by doing nothing
+ };
+
+ auto less = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // itA->first is not in B, so we add it to the difference remove
+ elements[itA->first] = itA->second;
+ };
+
+ auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // itB->first is not in A, so we just skip it
+ };
+
+ auto remainderOfA = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsA,
+ std::map<Node, Rational>::const_iterator& itA) {
+ // append the remainder of A
+ while (itA != elementsA.end())
+ {
+ elements[itA->first] = itA->second;
+ itA++;
+ }
+ };
+
+ auto remainderOfB = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsB,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // do nothing
+ };
+
+ return evaluateBinaryOperation(
+ n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node BagsUtils::evaluateChoose(TNode n)
+{
+ Assert(n.getKind() == BAG_CHOOSE);
+ // Examples
+ // --------
+ // - (bag.choose (bag "x" 4)) = "x"
+
+ if (n[0].getKind() == BAG_MAKE)
+ {
+ return n[0][0];
+ }
+ throw LogicException("BAG_CHOOSE_TOTAL is not supported yet");
+}
+
+Node BagsUtils::evaluateCard(TNode n)
+{
+ Assert(n.getKind() == BAG_CARD);
+ // Examples
+ // --------
+ // - (card (as bag.empty (Bag String))) = 0
+ // - (bag.choose (bag "x" 4)) = 4
+ // - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5
+
+ std::map<Node, Rational> elements = getBagElements(n[0]);
+ Rational sum(0);
+ for (std::pair<Node, Rational> element : elements)
+ {
+ sum += element.second;
+ }
+
+ NodeManager* nm = NodeManager::currentNM();
+ Node sumNode = nm->mkConstInt(sum);
+ return sumNode;
+}
+
+Node BagsUtils::evaluateIsSingleton(TNode n)
+{
+ Assert(n.getKind() == BAG_IS_SINGLETON);
+ // Examples
+ // --------
+ // - (bag.is_singleton (as bag.empty (Bag String))) = false
+ // - (bag.is_singleton (bag "x" 1)) = true
+ // - (bag.is_singleton (bag "x" 4)) = false
+ // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1)))
+ // = false
+
+ if (n[0].getKind() == BAG_MAKE && n[0][1].getConst<Rational>().isOne())
+ {
+ return NodeManager::currentNM()->mkConst(true);
+ }
+ return NodeManager::currentNM()->mkConst(false);
+}
+
+Node BagsUtils::evaluateFromSet(TNode n)
+{
+ Assert(n.getKind() == BAG_FROM_SET);
+
+ // Examples
+ // --------
+ // - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String))
+ // - (bag.from_set (set.singleton "x")) = (bag "x" 1)
+ // - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) =
+ // (bag.disjoint_union (bag "x" 1) (bag "y" 1))
+
+ NodeManager* nm = NodeManager::currentNM();
+ std::set<Node> setElements =
+ sets::NormalForm::getElementsFromNormalConstant(n[0]);
+ Rational one = Rational(1);
+ std::map<Node, Rational> bagElements;
+ for (const Node& element : setElements)
+ {
+ bagElements[element] = one;
+ }
+ TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
+ Node bag = constructConstantBagFromElements(bagType, bagElements);
+ return bag;
+}
+
+Node BagsUtils::evaluateToSet(TNode n)
+{
+ Assert(n.getKind() == BAG_TO_SET);
+
+ // Examples
+ // --------
+ // - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String))
+ // - (bag.to_set (bag "x" 4)) = (set.singleton "x")
+ // - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
+ // (set.union (set.singleton "x") (set.singleton "y")))
+
+ NodeManager* nm = NodeManager::currentNM();
+ std::map<Node, Rational> bagElements = getBagElements(n[0]);
+ std::set<Node> setElements;
+ std::map<Node, Rational>::const_reverse_iterator it;
+ for (it = bagElements.rbegin(); it != bagElements.rend(); it++)
+ {
+ setElements.insert(it->first);
+ }
+ TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
+ Node set = sets::NormalForm::elementsToSet(setElements, setType);
+ return set;
+}
+
+Node BagsUtils::evaluateBagMap(TNode n)
+{
+ Assert(n.getKind() == BAG_MAP);
+
+ // Examples
+ // --------
+ // - (bag.map ((lambda ((x String)) "z")
+ // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) =
+ // (bag.union_disjoint
+ // (bag ((lambda ((x String)) "z") "a") 2)
+ // (bag ((lambda ((x String)) "z") "b") 3)) =
+ // (bag "z" 5)
+
+ std::map<Node, Rational> elements = BagsUtils::getBagElements(n[1]);
+ std::map<Node, Rational> mappedElements;
+ std::map<Node, Rational>::iterator it = elements.begin();
+ NodeManager* nm = NodeManager::currentNM();
+ while (it != elements.end())
+ {
+ Node mappedElement = nm->mkNode(APPLY_UF, n[0], it->first);
+ mappedElements[mappedElement] = it->second;
+ ++it;
+ }
+ TypeNode t = nm->mkBagType(n[0].getType().getRangeType());
+ Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements);
+ return ret;
+}
+
+Node BagsUtils::evaluateBagFilter(TNode n)
+{
+ Assert(n.getKind() == BAG_FILTER);
+
+ // - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
+ // - (bag.filter p (bag.union_disjoint (bag "a" 3) (bag "b" 2))) =
+ // (bag.union_disjoint
+ // (ite (p "a") (bag "a" 3) (as bag.empty (Bag T)))
+ // (ite (p "b") (bag "b" 2) (as bag.empty (Bag T)))
+
+ Node P = n[0];
+ Node A = n[1];
+ TypeNode bagType = A.getType();
+ NodeManager* nm = NodeManager::currentNM();
+ Node empty = nm->mkConst(EmptyBag(bagType));
+
+ std::map<Node, Rational> elements = getBagElements(n[1]);
+ std::vector<Node> bags;
+
+ for (const auto& [e, count] : elements)
+ {
+ Node multiplicity = nm->mkConst(CONST_RATIONAL, count);
+ Node bag = nm->mkBag(bagType.getBagElementType(), e, multiplicity);
+ Node pOfe = nm->mkNode(APPLY_UF, P, e);
+ Node ite = nm->mkNode(ITE, pOfe, bag, empty);
+ bags.push_back(ite);
+ }
+ Node ret = computeDisjointUnion(bagType, bags);
+ return ret;
+}
+
+Node BagsUtils::evaluateBagFold(TNode n)
+{
+ Assert(n.getKind() == BAG_FOLD);
+
+ // Examples
+ // --------
+ // minimum string
+ // - (bag.fold
+ // ((lambda ((x String) (y String)) (ite (str.< x y) x y))
+ // ""
+ // (bag.union_disjoint (bag "a" 2) (bag "b" 3))
+ // = "a"
+
+ Node f = n[0]; // combining function
+ Node ret = n[1]; // initial value
+ Node A = n[2]; // bag
+ std::map<Node, Rational> elements = BagsUtils::getBagElements(A);
+
+ std::map<Node, Rational>::iterator it = elements.begin();
+ NodeManager* nm = NodeManager::currentNM();
+ while (it != elements.end())
+ {
+ // apply the combination function n times, where n is the multiplicity
+ Rational count = it->second;
+ Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl;
+ while (!count.isZero())
+ {
+ ret = nm->mkNode(APPLY_UF, f, it->first, ret);
+ count = count - 1;
+ }
+ ++it;
+ }
+ return ret;
+}
+
+} // namespace bags
+} // namespace theory
+} // namespace cvc5
--- /dev/null
+/******************************************************************************
+ * Top contributors (to current version):
+ * Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved. See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Utility functions for bags.
+ */
+
+#include <expr/node.h>
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__THEORY__BAGS__NORMAL_FORM_H
+#define CVC5__THEORY__BAGS__NORMAL_FORM_H
+
+namespace cvc5 {
+namespace theory {
+namespace bags {
+
+class BagsUtils
+{
+ public:
+ /**
+ * @param bagType type of bags
+ * @param bags a vector of bag nodes
+ * @return disjoint union of these bags
+ */
+ static Node computeDisjointUnion(TypeNode bagType,
+ const std::vector<Node>& bags);
+ /**
+ * Returns true if n is considered a to be a (canonical) constant bag value.
+ * A canonical bag value is one whose AST is:
+ * (bag.union_disjoint (bag e1 c1) ...
+ * (bag.union_disjoint (bag e_{n-1} c_{n-1}) (bag e_n c_n))))
+ * where c1 ... cn are positive integers, e1 ... en are constants, and the
+ * node identifier of these constants are such that: e1 < ... < en.
+ * Also handles the corner cases of empty bag and bag constructed by bag
+ */
+ static bool isConstant(TNode n);
+ /**
+ * check whether all children of the given node are constants
+ */
+ static bool areChildrenConstants(TNode n);
+ /**
+ * evaluate the node n to a constant value.
+ * As a precondition, children of n should be constants.
+ */
+ static Node evaluate(TNode n);
+
+ /**
+ * get the elements along with their multiplicities in a given bag
+ * @param n a constant node whose type is a bag
+ * @return a map whose keys are constant elements and values are
+ * multiplicities
+ */
+ static std::map<Node, Rational> getBagElements(TNode n);
+
+ /**
+ * construct a constant bag from constant elements
+ * @param t the type of the returned bag
+ * @param elements a map whose keys are constant elements and values are
+ * multiplicities
+ * @return a constant bag that contains
+ */
+ static Node constructConstantBagFromElements(
+ TypeNode t, const std::map<Node, Rational>& elements);
+
+ /**
+ * construct a constant bag from node elements
+ * @param t the type of the returned bag
+ * @param elements a map whose keys are constant elements and values are
+ * multiplicities
+ * @return a constant bag that contains
+ */
+ static Node constructBagFromElements(TypeNode t,
+ const std::map<Node, Node>& elements);
+
+ /**
+ * @param n has the form (bag.fold f t A) where A is a constant bag
+ * @return a single value which is the result of the fold
+ */
+ static Node evaluateBagFold(TNode n);
+
+ /**
+ * @param n has the form (bag.filter p A) where A is a constant bag
+ * @return A filtered with predicate p
+ */
+ static Node evaluateBagFilter(TNode n);
+
+ private:
+ /**
+ * a high order helper function that return a constant bag that is the result
+ * of (op A B) where op is a binary operator and A, B are constant bags.
+ * The result is computed from the elements of A (elementsA with iterator itA)
+ * and elements of B (elementsB with iterator itB).
+ * The arguments below specify how these iterators are used to generate the
+ * elements of the result (elements).
+ * @param n a node whose kind is a binary operator (bag.union_disjoint,
+ * union_max, intersection_min, difference_subtract, difference_remove) and
+ * whose children are constant bags.
+ * @param equal a lambda expression that receives (elements, itA, itB) and
+ * specify the action that needs to be taken when the elements of itA, itB are
+ * equal.
+ * @param less a lambda expression that receives (elements, itA, itB) and
+ * specify the action that needs to be taken when the element itA is less than
+ * the element of itB.
+ * @param greaterOrEqual less a lambda expression that receives (elements,
+ * itA, itB) and specify the action that needs to be taken when the element
+ * itA is greater than or equal than the element of itB.
+ * @param remainderOfA a lambda expression that receives (elements, elementsA,
+ * itA) and specify the action that needs to be taken to the remaining
+ * elements of A when all elements of B are visited.
+ * @param remainderOfB a lambda expression that receives (elements, elementsB,
+ * itB) and specify the action that needs to be taken to the remaining
+ * elements of B when all elements of A are visited.
+ * @return a constant bag that the result of (op n[0] n[1])
+ */
+ template <typename T1, typename T2, typename T3, typename T4, typename T5>
+ static Node evaluateBinaryOperation(const TNode& n,
+ T1&& equal,
+ T2&& less,
+ T3&& greaterOrEqual,
+ T4&& remainderOfA,
+ T5&& remainderOfB);
+ /**
+ * evaluate n as follows:
+ * - (bag a 0) = (as bag.empty T) where T is the type of the original bag
+ * - (bag a (-c)) = (as bag.empty T) where T is the type the original bag,
+ * and c > 0 is a constant
+ */
+ static Node evaluateMakeBag(TNode n);
+
+ /**
+ * returns the multiplicity in a constant bag
+ * @param n has the form (bag.count x A) where x, A are constants
+ * @return the multiplicity of element x in bag A.
+ */
+ static Node evaluateBagCount(TNode n);
+
+ /**
+ * @param n has the form (bag.duplicate_removal A) where A is a constant bag
+ * @return a constant bag constructed from the elements in A where each
+ * element has multiplicity one
+ */
+ static Node evaluateDuplicateRemoval(TNode n);
+
+ /**
+ * evaluates union disjoint node such that the returned node is a canonical
+ * bag that has the form
+ * (bag.union_disjoint (bag e1 c1) ...
+ * (bag.union_disjoint * (bag e_{n-1} c_{n-1}) (bag e_n c_n)))) where
+ * c1... cn are positive integers, e1 ... en are constants, and the node
+ * identifier of these constants are such that: e1 < ... < en.
+ * @param n has the form (bag.union_disjoint A B) where A, B are constant bags
+ * @return the union disjoint of A and B
+ */
+ static Node evaluateUnionDisjoint(TNode n);
+ /**
+ * @param n has the form (bag.union_max A B) where A, B are constant bags
+ * @return the union max of A and B
+ */
+ static Node evaluateUnionMax(TNode n);
+ /**
+ * @param n has the form (bag.inter_min A B) where A, B are constant bags
+ * @return the intersection min of A and B
+ */
+ static Node evaluateIntersectionMin(TNode n);
+ /**
+ * @param n has the form (bag.difference_subtract A B) where A, B are constant
+ * bags
+ * @return the difference subtract of A and B
+ */
+ static Node evaluateDifferenceSubtract(TNode n);
+ /**
+ * @param n has the form (bag.difference_remove A B) where A, B are constant
+ * bags
+ * @return the difference remove of A and B
+ */
+ static Node evaluateDifferenceRemove(TNode n);
+ /**
+ * @param n has the form (bag.choose A) where A is a constant bag
+ * @return x if n has the form (bag.choose (bag x c)). Otherwise an error is
+ * thrown.
+ */
+ static Node evaluateChoose(TNode n);
+ /**
+ * @param n has the form (bag.card A) where A is a constant bag
+ * @return the number of elements in bag A
+ */
+ static Node evaluateCard(TNode n);
+ /**
+ * @param n has the form (bag.is_singleton A) where A is a constant bag
+ * @return whether the bag A has cardinality one.
+ */
+ static Node evaluateIsSingleton(TNode n);
+ /**
+ * @param n has the form (bag.from_set A) where A is a constant set
+ * @return a constant bag that contains exactly the elements in A.
+ */
+ static Node evaluateFromSet(TNode n);
+ /**
+ * @param n has the form (bag.to_set A) where A is a constant bag
+ * @return a constant set constructed from the elements in A.
+ */
+ static Node evaluateToSet(TNode n);
+ /**
+ * @param n has the form (bag.map f A) where A is a constant bag
+ * @return a constant bag constructed from the images of elements in A.
+ */
+ static Node evaluateBagMap(TNode n);
+};
+} // namespace bags
+} // namespace theory
+} // namespace cvc5
+
+#endif /* CVC5__THEORY__BAGS__NORMAL_FORM_H */
#include "expr/emptybag.h"
#include "smt/logic_exception.h"
+#include "theory/bags/bags_utils.h"
#include "theory/bags/inference_generator.h"
#include "theory/bags/inference_manager.h"
-#include "theory/bags/normal_form.h"
#include "theory/bags/solver_state.h"
#include "theory/bags/term_registry.h"
#include "theory/uf/equality_engine_iterator.h"
return inferInfo;
}
+InferInfo InferenceGenerator::filterDownwards(Node n, Node e)
+{
+ Assert(n.getKind() == BAG_FILTER && n[1].getType().isBag());
+ Assert(e.getType().isSubtypeOf(n[1].getType().getBagElementType()));
+
+ Node P = n[0];
+ Node A = n[1];
+ InferInfo inferInfo(d_im, InferenceId::BAGS_FILTER_DOWN);
+
+ Node countA = getMultiplicityTerm(e, A);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
+
+ Node member = d_nm->mkNode(GEQ, count, d_one);
+ Node pOfe = d_nm->mkNode(APPLY_UF, P, e);
+ Node equal = count.eqNode(countA);
+
+ inferInfo.d_conclusion = pOfe.andNode(equal);
+ inferInfo.d_premises.push_back(member);
+ return inferInfo;
+}
+
+InferInfo InferenceGenerator::filterUpwards(Node n, Node e)
+{
+ Assert(n.getKind() == BAG_FILTER && n[1].getType().isBag());
+ Assert(e.getType().isSubtypeOf(n[1].getType().getBagElementType()));
+
+ Node P = n[0];
+ Node A = n[1];
+ InferInfo inferInfo(d_im, InferenceId::BAGS_FILTER_UP);
+
+ Node countA = getMultiplicityTerm(e, A);
+ Node skolem = getSkolem(n, inferInfo);
+ Node count = getMultiplicityTerm(e, skolem);
+
+ Node member = d_nm->mkNode(GEQ, countA, d_one);
+ Node pOfe = d_nm->mkNode(APPLY_UF, P, e);
+ Node equal = count.eqNode(countA);
+ Node included = pOfe.andNode(equal);
+ Node equalZero = count.eqNode(d_zero);
+ Node excluded = pOfe.notNode().andNode(equalZero);
+ inferInfo.d_conclusion = included.orNode(excluded);
+ inferInfo.d_premises.push_back(member);
+ return inferInfo;
+}
+
} // namespace bags
} // namespace theory
} // namespace cvc5
*/
InferInfo mapUpwards(Node n, Node uf, Node preImageSize, Node y, Node x);
+ /**
+ * @param n is (bag.filter p A) where p is a function (-> E Bool),
+ * A a bag of type (Bag E)
+ * @param e is an element of type E
+ * @return an inference that represents the following implication
+ * (=>
+ * (bag.member e skolem)
+ * (and
+ * (p e)
+ * (= (bag.count e skolem) (bag.count A)))
+ * where skolem is a variable equals (bag.filter p A)
+ */
+ InferInfo filterDownwards(Node n, Node e);
+
+ /**
+ * @param n is (bag.filter p A) where p is a function (-> E Bool),
+ * A a bag of type (Bag E)
+ * @param e is an element of type E
+ * @return an inference that represents the following implication
+ * (=>
+ * (bag.member e A)
+ * (or
+ * (and (p e) (= (bag.count e skolem) (bag.count A)))
+ * (and (not (p e)) (= (bag.count e skolem) 0)))
+ * where skolem is a variable equals (bag.filter p A)
+ */
+ InferInfo filterUpwards(Node n, Node e);
+
/**
* @param element of type T
* @param bag of type (bag T)
# of the second argument, a bag of type (Bag T1), and returns a bag of type (Bag T2).
operator BAG_MAP 2 "bag map function"
+# The bag.filter operator takes a predicate of type (-> T Bool) and a bag of type (Bag T)
+# and return the same bag excluding those elements that do not satisfy the predicate
+operator BAG_FILTER 2 "bag filter operator"
+
# bag.fold operator combines elements of a bag into a single value.
# (bag.fold f t B) folds the elements of bag B starting with term t and using
# the combining function f.
typerule BAG_FROM_SET ::cvc5::theory::bags::FromSetTypeRule
typerule BAG_TO_SET ::cvc5::theory::bags::ToSetTypeRule
typerule BAG_MAP ::cvc5::theory::bags::BagMapTypeRule
+typerule BAG_FILTER ::cvc5::theory::bags::BagFilterTypeRule
typerule BAG_FOLD ::cvc5::theory::bags::BagFoldTypeRule
construle BAG_UNION_DISJOINT ::cvc5::theory::bags::BinaryOperatorTypeRule
+++ /dev/null
-/******************************************************************************
- * Top contributors (to current version):
- * Mudathir Mohamed, Aina Niemetz
- *
- * This file is part of the cvc5 project.
- *
- * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
- * in the top-level source directory and their institutional affiliations.
- * All rights reserved. See the file COPYING in the top-level source
- * directory for licensing information.
- * ****************************************************************************
- *
- * Normal form for bag constants.
- */
-#include "normal_form.h"
-
-#include "expr/emptybag.h"
-#include "smt/logic_exception.h"
-#include "theory/sets/normal_form.h"
-#include "theory/type_enumerator.h"
-#include "util/rational.h"
-
-using namespace cvc5::kind;
-
-namespace cvc5 {
-namespace theory {
-namespace bags {
-
-bool NormalForm::isConstant(TNode n)
-{
- if (n.getKind() == BAG_EMPTY)
- {
- // empty bags are already normalized
- return true;
- }
- if (n.getKind() == BAG_MAKE)
- {
- // see the implementation in MkBagTypeRule::computeIsConst
- return n.isConst();
- }
- if (n.getKind() == BAG_UNION_DISJOINT)
- {
- if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst()))
- {
- // the first child is not a constant
- return false;
- }
- // store the previous element to check the ordering of elements
- Node previousElement = n[0][0];
- Node current = n[1];
- while (current.getKind() == BAG_UNION_DISJOINT)
- {
- if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst()))
- {
- // the current element is not a constant
- return false;
- }
- if (previousElement >= current[0][0])
- {
- // the ordering is violated
- return false;
- }
- previousElement = current[0][0];
- current = current[1];
- }
- // check last element
- if (!(current.getKind() == kind::BAG_MAKE && current.isConst()))
- {
- // the last element is not a constant
- return false;
- }
- if (previousElement >= current[0])
- {
- // the ordering is violated
- return false;
- }
- return true;
- }
-
- // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be
- // constants
- return false;
-}
-
-bool NormalForm::areChildrenConstants(TNode n)
-{
- return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
-}
-
-Node NormalForm::evaluate(TNode n)
-{
- Assert(areChildrenConstants(n));
- if (n.isConst())
- {
- // a constant node is already in a normal form
- return n;
- }
- switch (n.getKind())
- {
- case BAG_MAKE: return evaluateMakeBag(n);
- case BAG_COUNT: return evaluateBagCount(n);
- case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
- case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n);
- case BAG_UNION_MAX: return evaluateUnionMax(n);
- case BAG_INTER_MIN: return evaluateIntersectionMin(n);
- case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
- case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
- case BAG_CARD: return evaluateCard(n);
- case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
- case BAG_FROM_SET: return evaluateFromSet(n);
- case BAG_TO_SET: return evaluateToSet(n);
- case BAG_MAP: return evaluateBagMap(n);
- case BAG_FOLD: return evaluateBagFold(n);
- default: break;
- }
- Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
- << std::endl;
-}
-
-template <typename T1, typename T2, typename T3, typename T4, typename T5>
-Node NormalForm::evaluateBinaryOperation(const TNode& n,
- T1&& equal,
- T2&& less,
- T3&& greaterOrEqual,
- T4&& remainderOfA,
- T5&& remainderOfB)
-{
- std::map<Node, Rational> elementsA = getBagElements(n[0]);
- std::map<Node, Rational> elementsB = getBagElements(n[1]);
- std::map<Node, Rational> elements;
-
- std::map<Node, Rational>::const_iterator itA = elementsA.begin();
- std::map<Node, Rational>::const_iterator itB = elementsB.begin();
-
- Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
- << n.getKind() << "] " << std::endl
- << "elements A: " << elementsA << std::endl
- << "elements B: " << elementsB << std::endl;
-
- while (itA != elementsA.end() && itB != elementsB.end())
- {
- if (itA->first == itB->first)
- {
- equal(elements, itA, itB);
- itA++;
- itB++;
- }
- else if (itA->first < itB->first)
- {
- less(elements, itA, itB);
- itA++;
- }
- else
- {
- greaterOrEqual(elements, itA, itB);
- itB++;
- }
- }
-
- // handle the remaining elements from A
- remainderOfA(elements, elementsA, itA);
- // handle the remaining elements from B
- remainderOfB(elements, elementsB, itB);
-
- Trace("bags-evaluate") << "elements: " << elements << std::endl;
- Node bag = constructConstantBagFromElements(n.getType(), elements);
- Trace("bags-evaluate") << "bag: " << bag << std::endl;
- return bag;
-}
-
-std::map<Node, Rational> NormalForm::getBagElements(TNode n)
-{
- std::map<Node, Rational> elements;
- if (n.getKind() == BAG_EMPTY)
- {
- return elements;
- }
- while (n.getKind() == kind::BAG_UNION_DISJOINT)
- {
- Assert(n[0].getKind() == kind::BAG_MAKE);
- Node element = n[0][0];
- Rational count = n[0][1].getConst<Rational>();
- elements[element] = count;
- n = n[1];
- }
- Assert(n.getKind() == kind::BAG_MAKE);
- Node lastElement = n[0];
- Rational lastCount = n[1].getConst<Rational>();
- elements[lastElement] = lastCount;
- return elements;
-}
-
-Node NormalForm::constructConstantBagFromElements(
- TypeNode t, const std::map<Node, Rational>& elements)
-{
- Assert(t.isBag());
- NodeManager* nm = NodeManager::currentNM();
- if (elements.empty())
- {
- return nm->mkConst(EmptyBag(t));
- }
- TypeNode elementType = t.getBagElementType();
- std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
- Node bag = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second));
- while (++it != elements.rend())
- {
- Node n = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second));
- bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
- }
- return bag;
-}
-
-Node NormalForm::constructBagFromElements(TypeNode t,
- const std::map<Node, Node>& elements)
-{
- Assert(t.isBag());
- NodeManager* nm = NodeManager::currentNM();
- if (elements.empty())
- {
- return nm->mkConst(EmptyBag(t));
- }
- TypeNode elementType = t.getBagElementType();
- std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
- Node bag = nm->mkBag(elementType, it->first, it->second);
- while (++it != elements.rend())
- {
- Node n = nm->mkBag(elementType, it->first, it->second);
- bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
- }
- return bag;
-}
-
-Node NormalForm::evaluateMakeBag(TNode n)
-{
- // the case where n is const should be handled earlier.
- // here we handle the case where the multiplicity is zero or negative
- Assert(n.getKind() == BAG_MAKE && !n.isConst()
- && n[1].getConst<Rational>().sgn() < 1);
- Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
- return emptybag;
-}
-
-Node NormalForm::evaluateBagCount(TNode n)
-{
- Assert(n.getKind() == BAG_COUNT);
- // Examples
- // --------
- // - (bag.count "x" (as bag.empty (Bag String))) = 0
- // - (bag.count "x" (bag "y" 5)) = 0
- // - (bag.count "x" (bag "x" 4)) = 4
- // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4
- // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0
-
- std::map<Node, Rational> elements = getBagElements(n[1]);
- std::map<Node, Rational>::iterator it = elements.find(n[0]);
-
- NodeManager* nm = NodeManager::currentNM();
- if (it != elements.end())
- {
- Node count = nm->mkConstInt(it->second);
- return count;
- }
- return nm->mkConstInt(Rational(0));
-}
-
-Node NormalForm::evaluateDuplicateRemoval(TNode n)
-{
- Assert(n.getKind() == BAG_DUPLICATE_REMOVAL);
-
- // Examples
- // --------
- // - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag
- // String))
- // - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1)
- // - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
- // (bag.disjoint_union (bag "x" 1) (bag "y" 1)
-
- std::map<Node, Rational> oldElements = getBagElements(n[0]);
- // copy elements from the old bag
- std::map<Node, Rational> newElements(oldElements);
- Rational one = Rational(1);
- std::map<Node, Rational>::iterator it;
- for (it = newElements.begin(); it != newElements.end(); it++)
- {
- it->second = one;
- }
- Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
- return bag;
-}
-
-Node NormalForm::evaluateUnionDisjoint(TNode n)
-{
- Assert(n.getKind() == BAG_UNION_DISJOINT);
- // Example
- // -------
- // input: (bag.union_disjoint A B)
- // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
- // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
- // output:
- // (bag.union_disjoint A B)
- // where A = (bag "x" 7)
- // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
-
- auto equal = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // compute the sum of the multiplicities
- elements[itA->first] = itA->second + itB->second;
- };
-
- auto less = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // add the element to the result
- elements[itA->first] = itA->second;
- };
-
- auto greaterOrEqual = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // add the element to the result
- elements[itB->first] = itB->second;
- };
-
- auto remainderOfA = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>& elementsA,
- std::map<Node, Rational>::const_iterator& itA) {
- // append the remainder of A
- while (itA != elementsA.end())
- {
- elements[itA->first] = itA->second;
- itA++;
- }
- };
-
- auto remainderOfB = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>& elementsB,
- std::map<Node, Rational>::const_iterator& itB) {
- // append the remainder of B
- while (itB != elementsB.end())
- {
- elements[itB->first] = itB->second;
- itB++;
- }
- };
-
- return evaluateBinaryOperation(
- n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
-}
-
-Node NormalForm::evaluateUnionMax(TNode n)
-{
- Assert(n.getKind() == BAG_UNION_MAX);
- // Example
- // -------
- // input: (bag.union_max A B)
- // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
- // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
- // output:
- // (bag.union_disjoint A B)
- // where A = (bag "x" 4)
- // B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
-
- auto equal = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // compute the maximum multiplicity
- elements[itA->first] = std::max(itA->second, itB->second);
- };
-
- auto less = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // add to the result
- elements[itA->first] = itA->second;
- };
-
- auto greaterOrEqual = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // add to the result
- elements[itB->first] = itB->second;
- };
-
- auto remainderOfA = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>& elementsA,
- std::map<Node, Rational>::const_iterator& itA) {
- // append the remainder of A
- while (itA != elementsA.end())
- {
- elements[itA->first] = itA->second;
- itA++;
- }
- };
-
- auto remainderOfB = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>& elementsB,
- std::map<Node, Rational>::const_iterator& itB) {
- // append the remainder of B
- while (itB != elementsB.end())
- {
- elements[itB->first] = itB->second;
- itB++;
- }
- };
-
- return evaluateBinaryOperation(
- n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
-}
-
-Node NormalForm::evaluateIntersectionMin(TNode n)
-{
- Assert(n.getKind() == BAG_INTER_MIN);
- // Example
- // -------
- // input: (bag.inter_min A B)
- // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
- // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
- // output:
- // (bag "x" 3)
-
- auto equal = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // compute the minimum multiplicity
- elements[itA->first] = std::min(itA->second, itB->second);
- };
-
- auto less = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // do nothing
- };
-
- auto greaterOrEqual = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // do nothing
- };
-
- auto remainderOfA = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>& elementsA,
- std::map<Node, Rational>::const_iterator& itA) {
- // do nothing
- };
-
- auto remainderOfB = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>& elementsB,
- std::map<Node, Rational>::const_iterator& itB) {
- // do nothing
- };
-
- return evaluateBinaryOperation(
- n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
-}
-
-Node NormalForm::evaluateDifferenceSubtract(TNode n)
-{
- Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT);
- // Example
- // -------
- // input: (bag.difference_subtract A B)
- // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
- // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
- // output:
- // (bag.union_disjoint (bag "x" 1) (bag "z" 2))
-
- auto equal = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // subtract the multiplicities
- elements[itA->first] = itA->second - itB->second;
- };
-
- auto less = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // itA->first is not in B, so we add it to the difference subtract
- elements[itA->first] = itA->second;
- };
-
- auto greaterOrEqual = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // itB->first is not in A, so we just skip it
- };
-
- auto remainderOfA = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>& elementsA,
- std::map<Node, Rational>::const_iterator& itA) {
- // append the remainder of A
- while (itA != elementsA.end())
- {
- elements[itA->first] = itA->second;
- itA++;
- }
- };
-
- auto remainderOfB = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>& elementsB,
- std::map<Node, Rational>::const_iterator& itB) {
- // do nothing
- };
-
- return evaluateBinaryOperation(
- n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
-}
-
-Node NormalForm::evaluateDifferenceRemove(TNode n)
-{
- Assert(n.getKind() == BAG_DIFFERENCE_REMOVE);
- // Example
- // -------
- // input: (bag.difference_remove A B)
- // where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
- // B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
- // output:
- // (bag "z" 2)
-
- auto equal = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // skip the shared element by doing nothing
- };
-
- auto less = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // itA->first is not in B, so we add it to the difference remove
- elements[itA->first] = itA->second;
- };
-
- auto greaterOrEqual = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>::const_iterator& itA,
- std::map<Node, Rational>::const_iterator& itB) {
- // itB->first is not in A, so we just skip it
- };
-
- auto remainderOfA = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>& elementsA,
- std::map<Node, Rational>::const_iterator& itA) {
- // append the remainder of A
- while (itA != elementsA.end())
- {
- elements[itA->first] = itA->second;
- itA++;
- }
- };
-
- auto remainderOfB = [](std::map<Node, Rational>& elements,
- std::map<Node, Rational>& elementsB,
- std::map<Node, Rational>::const_iterator& itB) {
- // do nothing
- };
-
- return evaluateBinaryOperation(
- n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
-}
-
-Node NormalForm::evaluateChoose(TNode n)
-{
- Assert(n.getKind() == BAG_CHOOSE);
- // Examples
- // --------
- // - (bag.choose (bag "x" 4)) = "x"
-
- if (n[0].getKind() == BAG_MAKE)
- {
- return n[0][0];
- }
- throw LogicException("BAG_CHOOSE_TOTAL is not supported yet");
-}
-
-Node NormalForm::evaluateCard(TNode n)
-{
- Assert(n.getKind() == BAG_CARD);
- // Examples
- // --------
- // - (card (as bag.empty (Bag String))) = 0
- // - (bag.choose (bag "x" 4)) = 4
- // - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5
-
- std::map<Node, Rational> elements = getBagElements(n[0]);
- Rational sum(0);
- for (std::pair<Node, Rational> element : elements)
- {
- sum += element.second;
- }
-
- NodeManager* nm = NodeManager::currentNM();
- Node sumNode = nm->mkConstInt(sum);
- return sumNode;
-}
-
-Node NormalForm::evaluateIsSingleton(TNode n)
-{
- Assert(n.getKind() == BAG_IS_SINGLETON);
- // Examples
- // --------
- // - (bag.is_singleton (as bag.empty (Bag String))) = false
- // - (bag.is_singleton (bag "x" 1)) = true
- // - (bag.is_singleton (bag "x" 4)) = false
- // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1)))
- // = false
-
- if (n[0].getKind() == BAG_MAKE && n[0][1].getConst<Rational>().isOne())
- {
- return NodeManager::currentNM()->mkConst(true);
- }
- return NodeManager::currentNM()->mkConst(false);
-}
-
-Node NormalForm::evaluateFromSet(TNode n)
-{
- Assert(n.getKind() == BAG_FROM_SET);
-
- // Examples
- // --------
- // - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String))
- // - (bag.from_set (set.singleton "x")) = (bag "x" 1)
- // - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) =
- // (bag.disjoint_union (bag "x" 1) (bag "y" 1))
-
- NodeManager* nm = NodeManager::currentNM();
- std::set<Node> setElements =
- sets::NormalForm::getElementsFromNormalConstant(n[0]);
- Rational one = Rational(1);
- std::map<Node, Rational> bagElements;
- for (const Node& element : setElements)
- {
- bagElements[element] = one;
- }
- TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
- Node bag = constructConstantBagFromElements(bagType, bagElements);
- return bag;
-}
-
-Node NormalForm::evaluateToSet(TNode n)
-{
- Assert(n.getKind() == BAG_TO_SET);
-
- // Examples
- // --------
- // - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String))
- // - (bag.to_set (bag "x" 4)) = (set.singleton "x")
- // - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
- // (set.union (set.singleton "x") (set.singleton "y")))
-
- NodeManager* nm = NodeManager::currentNM();
- std::map<Node, Rational> bagElements = getBagElements(n[0]);
- std::set<Node> setElements;
- std::map<Node, Rational>::const_reverse_iterator it;
- for (it = bagElements.rbegin(); it != bagElements.rend(); it++)
- {
- setElements.insert(it->first);
- }
- TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
- Node set = sets::NormalForm::elementsToSet(setElements, setType);
- return set;
-}
-
-Node NormalForm::evaluateBagMap(TNode n)
-{
- Assert(n.getKind() == BAG_MAP);
-
- // Examples
- // --------
- // - (bag.map ((lambda ((x String)) "z")
- // (bag.union_disjoint (bag "a" 2) (bag "b" 3)) =
- // (bag.union_disjoint
- // (bag ((lambda ((x String)) "z") "a") 2)
- // (bag ((lambda ((x String)) "z") "b") 3)) =
- // (bag "z" 5)
-
- std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
- std::map<Node, Rational> mappedElements;
- std::map<Node, Rational>::iterator it = elements.begin();
- NodeManager* nm = NodeManager::currentNM();
- while (it != elements.end())
- {
- Node mappedElement = nm->mkNode(APPLY_UF, n[0], it->first);
- mappedElements[mappedElement] = it->second;
- ++it;
- }
- TypeNode t = nm->mkBagType(n[0].getType().getRangeType());
- Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements);
- return ret;
-}
-
-Node NormalForm::evaluateBagFold(TNode n)
-{
- Assert(n.getKind() == BAG_FOLD);
-
- // Examples
- // --------
- // minimum string
- // - (bag.fold
- // ((lambda ((x String) (y String)) (ite (str.< x y) x y))
- // ""
- // (bag.union_disjoint (bag "a" 2) (bag "b" 3))
- // = "a"
-
- Node f = n[0]; // combining function
- Node ret = n[1]; // initial value
- Node A = n[2]; // bag
- std::map<Node, Rational> elements = NormalForm::getBagElements(A);
-
- std::map<Node, Rational>::iterator it = elements.begin();
- NodeManager* nm = NodeManager::currentNM();
- while (it != elements.end())
- {
- // apply the combination function n times, where n is the multiplicity
- Rational count = it->second;
- Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl;
- while (!count.isZero())
- {
- ret = nm->mkNode(APPLY_UF, f, it->first, ret);
- count = count - 1;
- }
- ++it;
- }
- return ret;
-}
-
-} // namespace bags
-} // namespace theory
-} // namespace cvc5
+++ /dev/null
-/******************************************************************************
- * Top contributors (to current version):
- * Mudathir Mohamed
- *
- * This file is part of the cvc5 project.
- *
- * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
- * in the top-level source directory and their institutional affiliations.
- * All rights reserved. See the file COPYING in the top-level source
- * directory for licensing information.
- * ****************************************************************************
- *
- * Normal form for bag constants.
- */
-
-#include <expr/node.h>
-
-#include "cvc5_private.h"
-
-#ifndef CVC5__THEORY__BAGS__NORMAL_FORM_H
-#define CVC5__THEORY__BAGS__NORMAL_FORM_H
-
-namespace cvc5 {
-namespace theory {
-namespace bags {
-
-class NormalForm
-{
- public:
- /**
- * Returns true if n is considered a to be a (canonical) constant bag value.
- * A canonical bag value is one whose AST is:
- * (bag.union_disjoint (bag e1 c1) ...
- * (bag.union_disjoint (bag e_{n-1} c_{n-1}) (bag e_n c_n))))
- * where c1 ... cn are positive integers, e1 ... en are constants, and the
- * node identifier of these constants are such that: e1 < ... < en.
- * Also handles the corner cases of empty bag and bag constructed by bag
- */
- static bool isConstant(TNode n);
- /**
- * check whether all children of the given node are constants
- */
- static bool areChildrenConstants(TNode n);
- /**
- * evaluate the node n to a constant value.
- * As a precondition, children of n should be constants.
- */
- static Node evaluate(TNode n);
-
- /**
- * get the elements along with their multiplicities in a given bag
- * @param n a constant node whose type is a bag
- * @return a map whose keys are constant elements and values are
- * multiplicities
- */
- static std::map<Node, Rational> getBagElements(TNode n);
-
- /**
- * construct a constant bag from constant elements
- * @param t the type of the returned bag
- * @param elements a map whose keys are constant elements and values are
- * multiplicities
- * @return a constant bag that contains
- */
- static Node constructConstantBagFromElements(
- TypeNode t, const std::map<Node, Rational>& elements);
-
- /**
- * construct a constant bag from node elements
- * @param t the type of the returned bag
- * @param elements a map whose keys are constant elements and values are
- * multiplicities
- * @return a constant bag that contains
- */
- static Node constructBagFromElements(TypeNode t,
- const std::map<Node, Node>& elements);
-
- /**
- * @param n has the form (bag.fold f t A) where A is a constant bag
- * @return a single value which is the result of the fold
- */
- static Node evaluateBagFold(TNode n);
-
- private:
- /**
- * a high order helper function that return a constant bag that is the result
- * of (op A B) where op is a binary operator and A, B are constant bags.
- * The result is computed from the elements of A (elementsA with iterator itA)
- * and elements of B (elementsB with iterator itB).
- * The arguments below specify how these iterators are used to generate the
- * elements of the result (elements).
- * @param n a node whose kind is a binary operator (bag.union_disjoint,
- * union_max, intersection_min, difference_subtract, difference_remove) and
- * whose children are constant bags.
- * @param equal a lambda expression that receives (elements, itA, itB) and
- * specify the action that needs to be taken when the elements of itA, itB are
- * equal.
- * @param less a lambda expression that receives (elements, itA, itB) and
- * specify the action that needs to be taken when the element itA is less than
- * the element of itB.
- * @param greaterOrEqual less a lambda expression that receives (elements,
- * itA, itB) and specify the action that needs to be taken when the element
- * itA is greater than or equal than the element of itB.
- * @param remainderOfA a lambda expression that receives (elements, elementsA,
- * itA) and specify the action that needs to be taken to the remaining
- * elements of A when all elements of B are visited.
- * @param remainderOfB a lambda expression that receives (elements, elementsB,
- * itB) and specify the action that needs to be taken to the remaining
- * elements of B when all elements of A are visited.
- * @return a constant bag that the result of (op n[0] n[1])
- */
- template <typename T1, typename T2, typename T3, typename T4, typename T5>
- static Node evaluateBinaryOperation(const TNode& n,
- T1&& equal,
- T2&& less,
- T3&& greaterOrEqual,
- T4&& remainderOfA,
- T5&& remainderOfB);
- /**
- * evaluate n as follows:
- * - (bag a 0) = (as bag.empty T) where T is the type of the original bag
- * - (bag a (-c)) = (as bag.empty T) where T is the type the original bag,
- * and c > 0 is a constant
- */
- static Node evaluateMakeBag(TNode n);
-
- /**
- * returns the multiplicity in a constant bag
- * @param n has the form (bag.count x A) where x, A are constants
- * @return the multiplicity of element x in bag A.
- */
- static Node evaluateBagCount(TNode n);
-
- /**
- * @param n has the form (bag.duplicate_removal A) where A is a constant bag
- * @return a constant bag constructed from the elements in A where each
- * element has multiplicity one
- */
- static Node evaluateDuplicateRemoval(TNode n);
-
- /**
- * evaluates union disjoint node such that the returned node is a canonical
- * bag that has the form
- * (bag.union_disjoint (bag e1 c1) ...
- * (bag.union_disjoint * (bag e_{n-1} c_{n-1}) (bag e_n c_n)))) where
- * c1... cn are positive integers, e1 ... en are constants, and the node
- * identifier of these constants are such that: e1 < ... < en.
- * @param n has the form (bag.union_disjoint A B) where A, B are constant bags
- * @return the union disjoint of A and B
- */
- static Node evaluateUnionDisjoint(TNode n);
- /**
- * @param n has the form (bag.union_max A B) where A, B are constant bags
- * @return the union max of A and B
- */
- static Node evaluateUnionMax(TNode n);
- /**
- * @param n has the form (bag.inter_min A B) where A, B are constant bags
- * @return the intersection min of A and B
- */
- static Node evaluateIntersectionMin(TNode n);
- /**
- * @param n has the form (bag.difference_subtract A B) where A, B are constant
- * bags
- * @return the difference subtract of A and B
- */
- static Node evaluateDifferenceSubtract(TNode n);
- /**
- * @param n has the form (bag.difference_remove A B) where A, B are constant
- * bags
- * @return the difference remove of A and B
- */
- static Node evaluateDifferenceRemove(TNode n);
- /**
- * @param n has the form (bag.choose A) where A is a constant bag
- * @return x if n has the form (bag.choose (bag x c)). Otherwise an error is
- * thrown.
- */
- static Node evaluateChoose(TNode n);
- /**
- * @param n has the form (bag.card A) where A is a constant bag
- * @return the number of elements in bag A
- */
- static Node evaluateCard(TNode n);
- /**
- * @param n has the form (bag.is_singleton A) where A is a constant bag
- * @return whether the bag A has cardinality one.
- */
- static Node evaluateIsSingleton(TNode n);
- /**
- * @param n has the form (bag.from_set A) where A is a constant set
- * @return a constant bag that contains exactly the elements in A.
- */
- static Node evaluateFromSet(TNode n);
- /**
- * @param n has the form (bag.to_set A) where A is a constant bag
- * @return a constant set constructed from the elements in A.
- */
- static Node evaluateToSet(TNode n);
- /**
- * @param n has the form (bag.map f A) where A is a constant bag
- * @return a constant bag constructed from the images of elements in A.
- */
- static Node evaluateBagMap(TNode n);
-};
-} // namespace bags
-} // namespace theory
-} // namespace cvc5
-
-#endif /* CVC5__THEORY__BAGS__NORMAL_FORM_H */
case Rewrite::EQ_CONST_FALSE: return "EQ_CONST_FALSE";
case Rewrite::EQ_REFL: return "EQ_REFL";
case Rewrite::EQ_SYM: return "EQ_SYM";
+ case Rewrite::FILTER_CONST: return "FILTER_CONST";
+ case Rewrite::FILTER_BAG_MAKE: return "FILTER_BAG_MAKE";
+ case Rewrite::FILTER_UNION_DISJOINT: return "FILTER_UNION_DISJOINT";
case Rewrite::FROM_SINGLETON: return "FROM_SINGLETON";
case Rewrite::FOLD_BAG: return "FOLD_BAG";
case Rewrite::FOLD_CONST: return "FOLD_CONST";
EQ_CONST_FALSE,
EQ_REFL,
EQ_SYM,
+ FILTER_CONST,
+ FILTER_BAG_MAKE,
+ FILTER_UNION_DISJOINT,
FROM_SINGLETON,
FOLD_BAG,
FOLD_CONST,
#include "expr/skolem_manager.h"
#include "proof/proof_checker.h"
#include "smt/logic_exception.h"
-#include "theory/bags/normal_form.h"
+#include "theory/bags/bags_utils.h"
#include "theory/quantifiers/fmf/bounded_integers.h"
#include "theory/rewriter.h"
#include "theory/theory_model.h"
Node value = m->getRepresentative(countSkolem);
elementReps[key] = value;
}
- Node constructedBag = NormalForm::constructBagFromElements(tn, elementReps);
+ Node constructedBag = BagsUtils::constructBagFromElements(tn, elementReps);
constructedBag = rewrite(constructedBag);
Trace("bags-model") << "constructed bag for " << n
<< " is: " << constructedBag << std::endl;
if (constructedRational < rCardRational
&& !d_env.isFiniteType(elementType))
{
- Node newElement = nm->getSkolemManager()->mkDummySkolem("slack", elementType);
+ Node newElement =
+ nm->getSkolemManager()->mkDummySkolem("slack", elementType);
Trace("bags-model") << "newElement is " << newElement << std::endl;
Rational difference = rCardRational - constructedRational;
Node multiplicity = nm->mkConst(CONST_RATIONAL, difference);
#include "theory/bags/theory_bags_type_enumerator.h"
#include "expr/emptybag.h"
-#include "theory/bags/normal_form.h"
+#include "theory/bags/bags_utils.h"
#include "theory_bags_type_enumerator.h"
#include "util/rational.h"
else
{
// increase the multiplicity of one of the elements in the current bag
- std::map<Node, Rational> elements =
- NormalForm::getBagElements(d_currentBag);
+ std::map<Node, Rational> elements = BagsUtils::getBagElements(d_currentBag);
Node element = elements.begin()->first;
elements[element] = elements[element] + Rational(1);
- d_currentBag = NormalForm::constructConstantBagFromElements(
+ d_currentBag = BagsUtils::constructConstantBagFromElements(
d_currentBag.getType(), elements);
}
#include "base/check.h"
#include "expr/emptybag.h"
#include "theory/bags/bag_make_op.h"
-#include "theory/bags/normal_form.h"
+#include "theory/bags/bags_utils.h"
#include "util/cardinality.h"
#include "util/rational.h"
// only UNION_DISJOINT has a const rule in kinds.
// Other binary operators do not have const rules in kinds
Assert(n.getKind() == kind::BAG_UNION_DISJOINT);
- return NormalForm::isConstant(n);
+ return BagsUtils::isConstant(n);
}
TypeNode SubBagTypeRule::computeType(NodeManager* nodeManager,
return retType;
}
+TypeNode BagFilterTypeRule::computeType(NodeManager* nodeManager,
+ TNode n,
+ bool check)
+{
+ Assert(n.getKind() == kind::BAG_FILTER);
+ TypeNode functionType = n[0].getType(check);
+ TypeNode bagType = n[1].getType(check);
+ if (check)
+ {
+ if (!bagType.isBag())
+ {
+ throw TypeCheckingExceptionPrivate(
+ n,
+ "bag.filter operator expects a bag in the second argument, "
+ "a non-bag is found");
+ }
+
+ TypeNode elementType = bagType.getBagElementType();
+
+ if (!(functionType.isFunction()))
+ {
+ std::stringstream ss;
+ ss << "Operator " << n.getKind() << " expects a function of type (-> "
+ << elementType << " Bool) as a first argument. "
+ << "Found a term of type '" << functionType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ std::vector<TypeNode> argTypes = functionType.getArgTypes();
+ NodeManager* nm = NodeManager::currentNM();
+ if (!(argTypes.size() == 1 && argTypes[0] == elementType
+ && functionType.getRangeType() == nm->booleanType()))
+ {
+ std::stringstream ss;
+ ss << "Operator " << n.getKind() << " expects a function of type (-> "
+ << elementType << " Bool). "
+ << "Found a function of type '" << functionType << "'.";
+ throw TypeCheckingExceptionPrivate(n, ss.str());
+ }
+ }
+ return bagType;
+}
+
TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager,
TNode n,
bool check)
static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
}; /* struct BagMapTypeRule */
+/**
+ * Type rule for (bag.filter p B) to make sure p is a unary predicate of type
+ * (-> T Bool) where B is a bag of type (Bag T)
+ */
+struct BagFilterTypeRule
+{
+ static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct BagFilterTypeRule */
+
/**
* Type rule for (bag.fold f t A) to make sure f is a binary operation of type
* (-> T1 T2 T2), t of type T2, and B is a bag of type (Bag T1)
case InferenceId::BAGS_DIFFERENCE_REMOVE: return "BAGS_DIFFERENCE_REMOVE";
case InferenceId::BAGS_DUPLICATE_REMOVAL: return "BAGS_DUPLICATE_REMOVAL";
case InferenceId::BAGS_MAP: return "BAGS_MAP";
+ case InferenceId::BAGS_FILTER_DOWN: return "BAGS_FILTER_DOWN";
+ case InferenceId::BAGS_FILTER_UP: return "BAGS_FILTER_UP";
case InferenceId::BAGS_FOLD: return "BAGS_FOLD";
case InferenceId::BAGS_CARD: return "BAGS_CARD";
BAGS_DIFFERENCE_REMOVE,
BAGS_DUPLICATE_REMOVAL,
BAGS_MAP,
+ BAGS_FILTER_DOWN,
+ BAGS_FILTER_UP,
BAGS_FOLD,
BAGS_CARD,
// ---------------------------------- end bags theory
regress1/bags/duplicate_removal1.smt2
regress1/bags/duplicate_removal2.smt2
regress1/bags/emptybag1.smt2
+ regress1/bags/filter1.smt2
+ regress1/bags/filter2.smt2
+ regress1/bags/filter3.smt2
+ regress1/bags/filter4.smt2
+ regress1/bags/filter5.smt2
regress1/bags/fol_0000119.smt2
regress1/bags/fold1.smt2
regress1/bags/fuzzy1.smt2
--- /dev/null
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(declare-fun p (Int) Bool)
+(assert (= A (bag.union_max (bag x 1) (bag y 2))))
+(assert (= B (bag.filter p A)))
+(assert (distinct (p x) (p y)))
+(check-sat)
--- /dev/null
+(set-logic HO_ALL)
+(set-info :status sat)
+(set-option :fmf-bound true)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun p (Int) Bool)
+(assert (= B (bag.filter p A)))
+(assert (= (bag.count (- 2) B) 57))
+(check-sat)
--- /dev/null
+(set-logic HO_ALL)
+(set-info :status unsat)
+(set-option :fmf-bound true)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(define-fun p ((x Int)) Bool (> x 1))
+(assert (= B (bag.filter p A)))
+(assert (= (bag.count 3 B) 57))
+(assert (= (bag.count 3 B) 58))
+(check-sat)
--- /dev/null
+(set-logic HO_ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun element () Int)
+(declare-fun p (Int) Bool)
+(assert (= B (bag.filter p A)))
+(assert (p element))
+(assert (not (bag.member element B)))
+(assert (bag.member element A))
+(check-sat)
--- /dev/null
+(set-logic HO_ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun element () Int)
+(declare-fun p (Int) Bool)
+(assert (= B (bag.filter p A)))
+(assert (p element))
+(assert (not (bag.member element A)))
+(assert (bag.member element B))
+(check-sat)
(declare-fun y () Int)
(declare-fun f (Int) Int)
(assert (= A (bag.union_max (bag x 1) (bag y 2))))
-(assert (= A (bag.union_max (bag x 1) (bag y 2))))
(assert (= B (bag.map f A)))
(assert (distinct (f x) (f y) x y))
(check-sat)
#include "expr/emptyset.h"
#include "test_smt.h"
#include "theory/bags/bags_rewriter.h"
-#include "theory/bags/normal_form.h"
+#include "theory/bags/bags_utils.h"
#include "theory/strings/type_enumerator.h"
#include "util/rational.h"
#include "util/string.h"
Node emptybag = d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType()));
// empty bags are in normal form
ASSERT_TRUE(emptybag.isConst());
- Node n = NormalForm::evaluate(emptybag);
+ Node n = BagsUtils::evaluate(emptybag);
ASSERT_EQ(emptybag, n);
}
ASSERT_FALSE(negative.isConst());
ASSERT_FALSE(zero.isConst());
- ASSERT_EQ(emptybag, NormalForm::evaluate(negative));
- ASSERT_EQ(emptybag, NormalForm::evaluate(zero));
- ASSERT_EQ(positive, NormalForm::evaluate(positive));
+ ASSERT_EQ(emptybag, BagsUtils::evaluate(negative));
+ ASSERT_EQ(emptybag, BagsUtils::evaluate(zero));
+ ASSERT_EQ(positive, BagsUtils::evaluate(positive));
}
TEST_F(TestTheoryWhiteBagsNormalForm, bag_count)
Node input1 = d_nodeManager->mkNode(BAG_COUNT, x, empty);
Node output1 = zero;
- ASSERT_EQ(output1, NormalForm::evaluate(input1));
+ ASSERT_EQ(output1, BagsUtils::evaluate(input1));
Node input2 = d_nodeManager->mkNode(BAG_COUNT, x, y_5);
Node output2 = zero;
- ASSERT_EQ(output2, NormalForm::evaluate(input2));
+ ASSERT_EQ(output2, BagsUtils::evaluate(input2));
Node input3 = d_nodeManager->mkNode(BAG_COUNT, x, x_4);
Node output3 = four;
- ASSERT_EQ(output2, NormalForm::evaluate(input2));
+ ASSERT_EQ(output2, BagsUtils::evaluate(input2));
Node unionDisjointXY = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
Node input4 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointXY);
Node output4 = four;
- ASSERT_EQ(output3, NormalForm::evaluate(input3));
+ ASSERT_EQ(output3, BagsUtils::evaluate(input3));
Node unionDisjointYZ = d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_5, z_5);
Node input5 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointYZ);
Node output5 = zero;
- ASSERT_EQ(output4, NormalForm::evaluate(input4));
+ ASSERT_EQ(output4, BagsUtils::evaluate(input4));
}
TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal)
EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
Node input1 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, emptybag);
Node output1 = emptybag;
- ASSERT_EQ(output1, NormalForm::evaluate(input1));
+ ASSERT_EQ(output1, BagsUtils::evaluate(input1));
Node x = d_nodeManager->mkConst(String("x"));
Node y = d_nodeManager->mkConst(String("y"));
Node input2 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, x_4);
Node output2 = x_1;
- ASSERT_EQ(output2, NormalForm::evaluate(input2));
+ ASSERT_EQ(output2, BagsUtils::evaluate(input2));
Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
Node input3 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, normalBag);
Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
- ASSERT_EQ(output3, NormalForm::evaluate(input3));
+ ASSERT_EQ(output3, BagsUtils::evaluate(input3));
}
TEST_F(TestTheoryWhiteBagsNormalForm, union_max)
d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2));
ASSERT_TRUE(output.isConst());
- ASSERT_EQ(output, NormalForm::evaluate(input));
+ ASSERT_EQ(output, BagsUtils::evaluate(input));
}
TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1)
Node unionDisjointAB = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, B);
// unionDisjointAB is already in a normal form
ASSERT_TRUE(unionDisjointAB.isConst());
- ASSERT_EQ(unionDisjointAB, NormalForm::evaluate(unionDisjointAB));
+ ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointAB));
Node unionDisjointBA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, B, A);
// unionDisjointAB is the normal form of unionDisjointBA
ASSERT_FALSE(unionDisjointBA.isConst());
- ASSERT_EQ(unionDisjointAB, NormalForm::evaluate(unionDisjointBA));
+ ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointBA));
Node unionDisjointAB_C =
d_nodeManager->mkNode(BAG_UNION_DISJOINT, unionDisjointAB, C);
// unionDisjointA_BC is the normal form of unionDisjointAB_C
ASSERT_FALSE(unionDisjointAB_C.isConst());
ASSERT_TRUE(unionDisjointA_BC.isConst());
- ASSERT_EQ(unionDisjointA_BC, NormalForm::evaluate(unionDisjointAB_C));
+ ASSERT_EQ(unionDisjointA_BC, BagsUtils::evaluate(unionDisjointAB_C));
Node unionDisjointAA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, A);
Node AA =
d_nodeManager->mkConst(CONST_RATIONAL, Rational(4)));
ASSERT_FALSE(unionDisjointAA.isConst());
ASSERT_TRUE(AA.isConst());
- ASSERT_EQ(AA, NormalForm::evaluate(unionDisjointAA));
+ ASSERT_EQ(AA, BagsUtils::evaluate(unionDisjointAA));
}
TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2)
d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2));
ASSERT_TRUE(output.isConst());
- ASSERT_EQ(output, NormalForm::evaluate(input));
+ ASSERT_EQ(output, BagsUtils::evaluate(input));
}
TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min)
Node output = x_3;
ASSERT_TRUE(output.isConst());
- ASSERT_EQ(output, NormalForm::evaluate(input));
+ ASSERT_EQ(output, BagsUtils::evaluate(input));
}
TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract)
Node output = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, z_2);
ASSERT_TRUE(output.isConst());
- ASSERT_EQ(output, NormalForm::evaluate(input));
+ ASSERT_EQ(output, BagsUtils::evaluate(input));
}
TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove)
Node output = z_2;
ASSERT_TRUE(output.isConst());
- ASSERT_EQ(output, NormalForm::evaluate(input));
+ ASSERT_EQ(output, BagsUtils::evaluate(input));
}
TEST_F(TestTheoryWhiteBagsNormalForm, bag_card)
Node input1 = d_nodeManager->mkNode(BAG_CARD, empty);
Node output1 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(0));
- ASSERT_EQ(output1, NormalForm::evaluate(input1));
+ ASSERT_EQ(output1, BagsUtils::evaluate(input1));
Node input2 = d_nodeManager->mkNode(BAG_CARD, x_4);
Node output2 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(4));
- ASSERT_EQ(output2, NormalForm::evaluate(input2));
+ ASSERT_EQ(output2, BagsUtils::evaluate(input2));
Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_1);
Node input3 = d_nodeManager->mkNode(BAG_CARD, union_disjoint);
Node output3 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(5));
- ASSERT_EQ(output3, NormalForm::evaluate(input3));
+ ASSERT_EQ(output3, BagsUtils::evaluate(input3));
}
TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton)
Node input1 = d_nodeManager->mkNode(BAG_IS_SINGLETON, empty);
Node output1 = falseNode;
- ASSERT_EQ(output1, NormalForm::evaluate(input1));
+ ASSERT_EQ(output1, BagsUtils::evaluate(input1));
Node input2 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_1);
Node output2 = trueNode;
- ASSERT_EQ(output2, NormalForm::evaluate(input2));
+ ASSERT_EQ(output2, BagsUtils::evaluate(input2));
Node input3 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_4);
Node output3 = falseNode;
- ASSERT_EQ(output2, NormalForm::evaluate(input2));
+ ASSERT_EQ(output2, BagsUtils::evaluate(input2));
Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
Node input4 = d_nodeManager->mkNode(BAG_IS_SINGLETON, union_disjoint);
Node output4 = falseNode;
- ASSERT_EQ(output3, NormalForm::evaluate(input3));
+ ASSERT_EQ(output3, BagsUtils::evaluate(input3));
}
TEST_F(TestTheoryWhiteBagsNormalForm, from_set)
EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
Node input1 = d_nodeManager->mkNode(BAG_FROM_SET, emptyset);
Node output1 = emptybag;
- ASSERT_EQ(output1, NormalForm::evaluate(input1));
+ ASSERT_EQ(output1, BagsUtils::evaluate(input1));
Node x = d_nodeManager->mkConst(String("x"));
Node y = d_nodeManager->mkConst(String("y"));
Node input2 = d_nodeManager->mkNode(BAG_FROM_SET, xSingleton);
Node output2 = x_1;
- ASSERT_EQ(output2, NormalForm::evaluate(input2));
+ ASSERT_EQ(output2, BagsUtils::evaluate(input2));
// for normal sets, the first node is the largest, not smallest
Node normalSet = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton);
Node input3 = d_nodeManager->mkNode(BAG_FROM_SET, normalSet);
Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
- ASSERT_EQ(output3, NormalForm::evaluate(input3));
+ ASSERT_EQ(output3, BagsUtils::evaluate(input3));
}
TEST_F(TestTheoryWhiteBagsNormalForm, to_set)
EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
Node input1 = d_nodeManager->mkNode(BAG_TO_SET, emptybag);
Node output1 = emptyset;
- ASSERT_EQ(output1, NormalForm::evaluate(input1));
+ ASSERT_EQ(output1, BagsUtils::evaluate(input1));
Node x = d_nodeManager->mkConst(String("x"));
Node y = d_nodeManager->mkConst(String("y"));
Node input2 = d_nodeManager->mkNode(BAG_TO_SET, x_4);
Node output2 = xSingleton;
- ASSERT_EQ(output2, NormalForm::evaluate(input2));
+ ASSERT_EQ(output2, BagsUtils::evaluate(input2));
// for normal sets, the first node is the largest, not smallest
Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
Node input3 = d_nodeManager->mkNode(BAG_TO_SET, normalBag);
Node output3 = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton);
- ASSERT_EQ(output3, NormalForm::evaluate(input3));
+ ASSERT_EQ(output3, BagsUtils::evaluate(input3));
}
} // namespace test
} // namespace cvc5