// no need to rewrite n if it is already in a normal form
response = BagsRewriteResponse(n, Rewrite::NONE);
}
- else if (NormalForm::AreChildrenConstants(n))
+ else if (NormalForm::areChildrenConstants(n))
{
Node value = NormalForm::evaluate(n);
response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
#include "normal_form.h"
+#include "theory/sets/normal_form.h"
+#include "theory/type_enumerator.h"
+
+using namespace CVC4::kind;
+
namespace CVC4 {
namespace theory {
namespace bags {
-bool NormalForm::checkNormalConstant(TNode n)
+bool NormalForm::isConstant(TNode n)
{
- // TODO(projects#223): complete this function
+ if (n.getKind() == EMPTYBAG)
+ {
+ // empty bags are already normalized
+ return true;
+ }
+ if (n.getKind() == MK_BAG)
+ {
+ // see the implementation in MkBagTypeRule::computeIsConst
+ return n.isConst();
+ }
+ if (n.getKind() == UNION_DISJOINT)
+ {
+ if (!(n[0].getKind() == kind::MK_BAG && n[0].isConst()))
+ {
+ // the first child is not a constant
+ return false;
+ }
+ // store the previous element to check the ordering of elements
+ Node previousElement = n[0][0];
+ Node current = n[1];
+ while (current.getKind() == UNION_DISJOINT)
+ {
+ if (!(current[0].getKind() == kind::MK_BAG && current[0].isConst()))
+ {
+ // the current element is not a constant
+ return false;
+ }
+ if (previousElement >= current[0][0])
+ {
+ // the ordering is violated
+ return false;
+ }
+ previousElement = current[0][0];
+ current = current[1];
+ }
+ // check last element
+ if (!(current.getKind() == kind::MK_BAG && current.isConst()))
+ {
+ // the last element is not a constant
+ return false;
+ }
+ if (previousElement >= current[0])
+ {
+ // the ordering is violated
+ return false;
+ }
+ return true;
+ }
+
+ // only nodes with kinds EMPTY_BAG, MK_BAG, and UNION_DISJOINT can be
+ // constants
return false;
}
-bool NormalForm::AreChildrenConstants(TNode n)
+bool NormalForm::areChildrenConstants(TNode n)
{
return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
}
Node NormalForm::evaluate(TNode n)
{
- // TODO(projects#223): complete this function
- return CVC4::Node();
+ Assert(areChildrenConstants(n));
+ if (n.isConst())
+ {
+ // a constant node is already in a normal form
+ return n;
+ }
+ switch (n.getKind())
+ {
+ case MK_BAG: return evaluateMakeBag(n);
+ case BAG_COUNT: return evaluateBagCount(n);
+ case UNION_DISJOINT: return evaluateUnionDisjoint(n);
+ case UNION_MAX: return evaluateUnionMax(n);
+ case INTERSECTION_MIN: return evaluateIntersectionMin(n);
+ case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
+ case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
+ case BAG_CHOOSE: return evaluateChoose(n);
+ case BAG_CARD: return evaluateCard(n);
+ case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
+ case BAG_FROM_SET: return evaluateFromSet(n);
+ case BAG_TO_SET: return evaluateToSet(n);
+ default: break;
+ }
+ Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
+ << std::endl;
+}
+
+template <typename T1, typename T2, typename T3, typename T4, typename T5>
+Node NormalForm::evaluateBinaryOperation(const TNode& n,
+ T1&& equal,
+ T2&& less,
+ T3&& greaterOrEqual,
+ T4&& remainderOfA,
+ T5&& remainderOfB)
+{
+ std::map<Node, Rational> elementsA = getBagElements(n[0]);
+ std::map<Node, Rational> elementsB = getBagElements(n[1]);
+ std::map<Node, Rational> elements;
+
+ std::map<Node, Rational>::const_iterator itA = elementsA.begin();
+ std::map<Node, Rational>::const_iterator itB = elementsB.begin();
+
+ Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
+ << n.getKind() << "] " << std::endl
+ << "elements A: " << elementsA << std::endl
+ << "elements B: " << elementsB << std::endl;
+
+ while (itA != elementsA.end() && itB != elementsB.end())
+ {
+ if (itA->first == itB->first)
+ {
+ equal(elements, itA, itB);
+ itA++;
+ itB++;
+ }
+ else if (itA->first < itB->first)
+ {
+ less(elements, itA, itB);
+ itA++;
+ }
+ else
+ {
+ greaterOrEqual(elements, itA, itB);
+ itB++;
+ }
+ }
+
+ // handle the remaining elements from A
+ remainderOfA(elements, elementsA, itA);
+ // handle the remaining elements from B
+ remainderOfA(elements, elementsB, itB);
+
+ Trace("bags-evaluate") << "elements: " << elements << std::endl;
+ Node bag = constructBagFromElements(n.getType(), elements);
+ Trace("bags-evaluate") << "bag: " << bag << std::endl;
+ return bag;
+}
+
+std::map<Node, Rational> NormalForm::getBagElements(TNode n)
+{
+ Assert(n.isConst()) << "node " << n << " is not in a normal form"
+ << std::endl;
+ std::map<Node, Rational> elements;
+ if (n.getKind() == EMPTYBAG)
+ {
+ return elements;
+ }
+ while (n.getKind() == kind::UNION_DISJOINT)
+ {
+ Assert(n[0].getKind() == kind::MK_BAG);
+ Node element = n[0][0];
+ Rational count = n[0][1].getConst<Rational>();
+ elements[element] = count;
+ n = n[1];
+ }
+ Assert(n.getKind() == kind::MK_BAG);
+ Node lastElement = n[0];
+ Rational lastCount = n[1].getConst<Rational>();
+ elements[lastElement] = lastCount;
+ return elements;
+}
+
+Node NormalForm::constructBagFromElements(
+ TypeNode t, const std::map<Node, Rational>& elements)
+{
+ Assert(t.isBag());
+ NodeManager* nm = NodeManager::currentNM();
+ if (elements.empty())
+ {
+ return nm->mkConst(EmptyBag(t));
+ }
+ TypeNode elementType = t.getBagElementType();
+ std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
+ Node bag =
+ nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
+ while (++it != elements.rend())
+ {
+ Node n =
+ nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
+ bag = nm->mkNode(UNION_DISJOINT, n, bag);
+ }
+ return bag;
+}
+
+Node NormalForm::evaluateMakeBag(TNode n)
+{
+ // the case where n is const should be handled earlier.
+ // here we handle the case where the multiplicity is zero or negative
+ Assert(n.getKind() == MK_BAG && !n.isConst()
+ && n[1].getConst<Rational>().sgn() < 1);
+ Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
+ return emptybag;
+}
+
+Node NormalForm::evaluateBagCount(TNode n)
+{
+ Assert(n.getKind() == BAG_COUNT);
+ // Examples
+ // --------
+ // - (bag.count "x" (emptybag String)) = 0
+ // - (bag.count "x" (mkBag "y" 5)) = 0
+ // - (bag.count "x" (mkBag "x" 4)) = 4
+ // - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
+ // - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0
+
+ std::map<Node, Rational> elements = getBagElements(n[1]);
+ std::map<Node, Rational>::iterator it = elements.find(n[0]);
+
+ NodeManager* nm = NodeManager::currentNM();
+ if (it != elements.end())
+ {
+ Node count = nm->mkConst(it->second);
+ return count;
+ }
+ return nm->mkConst(Rational(0));
+}
+
+Node NormalForm::evaluateUnionDisjoint(TNode n)
+{
+ Assert(n.getKind() == UNION_DISJOINT);
+ // Example
+ // -------
+ // input: (union_disjoint A B)
+ // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+ // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // output:
+ // (union_disjoint A B)
+ // where A = (MK_BAG "x" 7)
+ // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
+
+ auto equal = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // compute the sum of the multiplicities
+ elements[itA->first] = itA->second + itB->second;
+ };
+
+ auto less = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // add the element to the result
+ elements[itA->first] = itA->second;
+ };
+
+ auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // add the element to the result
+ elements[itB->first] = itB->second;
+ };
+
+ auto remainderOfA = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsA,
+ std::map<Node, Rational>::const_iterator& itA) {
+ // append the remainder of A
+ while (itA != elementsA.end())
+ {
+ elements[itA->first] = itA->second;
+ itA++;
+ }
+ };
+
+ auto remainderOfB = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsB,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // append the remainder of B
+ while (itB != elementsB.end())
+ {
+ elements[itB->first] = itB->second;
+ itB++;
+ }
+ };
+
+ return evaluateBinaryOperation(
+ n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node NormalForm::evaluateUnionMax(TNode n)
+{
+ Assert(n.getKind() == UNION_MAX);
+ // Example
+ // -------
+ // input: (union_max A B)
+ // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+ // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // output:
+ // (union_disjoint A B)
+ // where A = (MK_BAG "x" 4)
+ // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
+
+ auto equal = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // compute the maximum multiplicity
+ elements[itA->first] = std::max(itA->second, itB->second);
+ };
+
+ auto less = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // add to the result
+ elements[itA->first] = itA->second;
+ };
+
+ auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // add to the result
+ elements[itB->first] = itB->second;
+ };
+
+ auto remainderOfA = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsA,
+ std::map<Node, Rational>::const_iterator& itA) {
+ // append the remainder of A
+ while (itA != elementsA.end())
+ {
+ elements[itA->first] = itA->second;
+ itA++;
+ }
+ };
+
+ auto remainderOfB = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsB,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // append the remainder of B
+ while (itB != elementsB.end())
+ {
+ elements[itB->first] = itB->second;
+ itB++;
+ }
+ };
+
+ return evaluateBinaryOperation(
+ n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
}
+
+Node NormalForm::evaluateIntersectionMin(TNode n)
+{
+ Assert(n.getKind() == INTERSECTION_MIN);
+ // Example
+ // -------
+ // input: (intersectionMin A B)
+ // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+ // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // output:
+ // (MK_BAG "x" 3)
+
+ auto equal = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // compute the minimum multiplicity
+ elements[itA->first] = std::min(itA->second, itB->second);
+ };
+
+ auto less = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // do nothing
+ };
+
+ auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // do nothing
+ };
+
+ auto remainderOfA = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsA,
+ std::map<Node, Rational>::const_iterator& itA) {
+ // do nothing
+ };
+
+ auto remainderOfB = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsB,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // do nothing
+ };
+
+ return evaluateBinaryOperation(
+ n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node NormalForm::evaluateDifferenceSubtract(TNode n)
+{
+ Assert(n.getKind() == DIFFERENCE_SUBTRACT);
+ // Example
+ // -------
+ // input: (difference_subtract A B)
+ // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+ // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // output:
+ // (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))
+
+ auto equal = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // subtract the multiplicities
+ elements[itA->first] = itA->second - itB->second;
+ };
+
+ auto less = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // itA->first is not in B, so we add it to the difference subtract
+ elements[itA->first] = itA->second;
+ };
+
+ auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // itB->first is not in A, so we just skip it
+ };
+
+ auto remainderOfA = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsA,
+ std::map<Node, Rational>::const_iterator& itA) {
+ // append the remainder of A
+ while (itA != elementsA.end())
+ {
+ elements[itA->first] = itA->second;
+ itA++;
+ }
+ };
+
+ auto remainderOfB = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsB,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // do nothing
+ };
+
+ return evaluateBinaryOperation(
+ n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node NormalForm::evaluateDifferenceRemove(TNode n)
+{
+ Assert(n.getKind() == DIFFERENCE_REMOVE);
+ // Example
+ // -------
+ // input: (difference_subtract A B)
+ // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+ // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // output:
+ // (MK_BAG "z" 2)
+
+ auto equal = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // skip the shared element by doing nothing
+ };
+
+ auto less = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // itA->first is not in B, so we add it to the difference remove
+ elements[itA->first] = itA->second;
+ };
+
+ auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>::const_iterator& itA,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // itB->first is not in A, so we just skip it
+ };
+
+ auto remainderOfA = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsA,
+ std::map<Node, Rational>::const_iterator& itA) {
+ // append the remainder of A
+ while (itA != elementsA.end())
+ {
+ elements[itA->first] = itA->second;
+ itA++;
+ }
+ };
+
+ auto remainderOfB = [](std::map<Node, Rational>& elements,
+ std::map<Node, Rational>& elementsB,
+ std::map<Node, Rational>::const_iterator& itB) {
+ // do nothing
+ };
+
+ return evaluateBinaryOperation(
+ n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node NormalForm::evaluateChoose(TNode n)
+{
+ Assert(n.getKind() == BAG_CHOOSE);
+ // Examples
+ // --------
+ // - (choose (emptyBag String)) = "" // the empty string which is the first
+ // element returned by the type enumerator
+ // - (choose (MK_BAG "x" 4)) = "x"
+ // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x"
+ // deterministically return the first element
+
+ if (n[0].getKind() == EMPTYBAG)
+ {
+ TypeNode elementType = n[0].getType().getBagElementType();
+ TypeEnumerator typeEnumerator(elementType);
+ // get the first value from the typeEnumerator
+ Node element = *typeEnumerator;
+ return element;
+ }
+
+ if (n[0].getKind() == MK_BAG)
+ {
+ return n[0][0];
+ }
+ Assert(n[0].getKind() == UNION_DISJOINT);
+ // return the first element
+ // e.g. (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1)))
+ return n[0][0][0];
+}
+
+Node NormalForm::evaluateCard(TNode n)
+{
+ Assert(n.getKind() == BAG_CARD);
+ // Examples
+ // --------
+ // - (card (emptyBag String)) = 0
+ // - (choose (MK_BAG "x" 4)) = 4
+ // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5
+
+ std::map<Node, Rational> elements = getBagElements(n[0]);
+ Rational sum(0);
+ for (std::pair<Node, Rational> element : elements)
+ {
+ sum += element.second;
+ }
+
+ NodeManager* nm = NodeManager::currentNM();
+ Node sumNode = nm->mkConst(sum);
+ return sumNode;
+}
+
+Node NormalForm::evaluateIsSingleton(TNode n)
+{
+ Assert(n.getKind() == BAG_IS_SINGLETON);
+ // Examples
+ // --------
+ // - (bag.is_singleton (emptyBag String)) = false
+ // - (bag.is_singleton (MK_BAG "x" 1)) = true
+ // - (bag.is_singleton (MK_BAG "x" 4)) = false
+ // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false
+
+ if (n[0].getKind() == MK_BAG && n[0][1].getConst<Rational>().isOne())
+ {
+ return NodeManager::currentNM()->mkConst(true);
+ }
+ return NodeManager::currentNM()->mkConst(false);
+}
+
+Node NormalForm::evaluateFromSet(TNode n)
+{
+ Assert(n.getKind() == BAG_FROM_SET);
+
+ // Examples
+ // --------
+ // - (bag.from_set (emptyset String)) = (emptybag String)
+ // - (bag.from_set (singleton "x")) = (mkBag "x" 1)
+ // - (bag.from_set (union (singleton "x") (singleton "y"))) =
+ // (disjoint_union (mkBag "x" 1) (mkBag "y" 1))
+
+ NodeManager* nm = NodeManager::currentNM();
+ std::set<Node> setElements =
+ sets::NormalForm::getElementsFromNormalConstant(n[0]);
+ Rational one = Rational(1);
+ std::map<Node, Rational> bagElements;
+ for (const Node& element : setElements)
+ {
+ bagElements[element] = one;
+ }
+ TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
+ Node bag = constructBagFromElements(bagType, bagElements);
+ return bag;
+}
+
+Node NormalForm::evaluateToSet(TNode n)
+{
+ Assert(n.getKind() == BAG_TO_SET);
+
+ // Examples
+ // --------
+ // - (bag.to_set (emptybag String)) = (emptyset String)
+ // - (bag.to_set (mkBag "x" 4)) = (singleton "x")
+ // - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
+ // (union (singleton "x") (singleton "y")))
+
+ NodeManager* nm = NodeManager::currentNM();
+ std::map<Node, Rational> bagElements = getBagElements(n[0]);
+ std::set<Node> setElements;
+ std::map<Node, Rational>::const_reverse_iterator it;
+ for (it = bagElements.rbegin(); it != bagElements.rend(); it++)
+ {
+ setElements.insert(it->first);
+ }
+ TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
+ Node set = sets::NormalForm::elementsToSet(setElements, setType);
+ return set;
+}
+
} // namespace bags
} // namespace theory
} // namespace CVC4
\ No newline at end of file
/**
* Returns true if n is considered a to be a (canonical) constant bag value.
* A canonical bag value is one whose AST is:
- * (disjoint-union (mk-bag e1 n1) ...
- * (disjoint-union (mk-bag e_{n-1} n_{n-1}) (mk-bag e_n n_n))))
- * where c1 ... cn are constants and the node identifier of these constants
- * are such that:
- * c1 > ... > cn.
- * Also handles the corner cases of empty bag and singleton bag.
+ * (union_disjoint (mkBag e1 c1) ...
+ * (union_disjoint (mkBag e_{n-1} c_{n-1}) (mkBag e_n c_n))))
+ * where c1 ... cn are positive integers, e1 ... en are constants, and the
+ * node identifier of these constants are such that: e1 < ... < en.
+ * Also handles the corner cases of empty bag and bag constructed by mkBag
*/
- static bool checkNormalConstant(TNode n);
+ static bool isConstant(TNode n);
/**
- * check whether all children of the given node are in normal form
+ * check whether all children of the given node are constants
*/
- static bool AreChildrenConstants(TNode n);
+ static bool areChildrenConstants(TNode n);
/**
- * evaluate the node n to a constant value
+ * evaluate the node n to a constant value.
+ * As a precondition, children of n should be constants.
*/
static Node evaluate(TNode n);
+
+ /**
+ * get the elements along with their multiplicities in a given bag
+ * @param n a constant node whose type is a bag
+ * @return a map whose keys are constant elements and values are
+ * multiplicities
+ */
+ static std::map<Node, Rational> getBagElements(TNode n);
+
+ /**
+ * construct a constant bag from constant elements
+ * @param t the type of the returned bag
+ * @param elements a map whose keys are constant elements and values are
+ * multiplicities
+ * @return a constant bag that contains
+ */
+ static Node constructBagFromElements(
+ TypeNode t, const std::map<Node, Rational>& elements);
+
+ private:
+ /**
+ * a high order helper function that return a constant bag that is the result
+ * of (op A B) where op is a binary operator and A, B are constant bags.
+ * The result is computed from the elements of A (elementsA with iterator itA)
+ * and elements of B (elementsB with iterator itB).
+ * The arguments below specify how these iterators are used to generate the
+ * elements of the result (elements).
+ * @param n a node whose kind is a binary operator (union_disjoint, union_max,
+ * intersection_min, difference_subtract, difference_remove) and whose
+ * children are constant bags.
+ * @param equal a lambda expression that receives (elements, itA, itB) and
+ * specify the action that needs to be taken when the elements of itA, itB are
+ * equal.
+ * @param less a lambda expression that receives (elements, itA, itB) and
+ * specify the action that needs to be taken when the element itA is less than
+ * the element of itB.
+ * @param greaterOrEqual less a lambda expression that receives (elements,
+ * itA, itB) and specify the action that needs to be taken when the element
+ * itA is greater than or equal than the element of itB.
+ * @param remainderOfA a lambda expression that receives (elements, elementsA,
+ * itA) and specify the action that needs to be taken to the remaining
+ * elements of A when all elements of B are visited.
+ * @param remainderOfB a lambda expression that receives (elements, elementsB,
+ * itB) and specify the action that needs to be taken to the remaining
+ * elements of B when all elements of A are visited.
+ * @return a constant bag that the result of (op n[0] n[1])
+ */
+ template <typename T1, typename T2, typename T3, typename T4, typename T5>
+ static Node evaluateBinaryOperation(const TNode& n,
+ T1&& equal,
+ T2&& less,
+ T3&& greaterOrEqual,
+ T4&& remainderOfA,
+ T5&& remainderOfB);
+ /**
+ * evaluate n as follows:
+ * - (mkBag a 0) = (emptybag T) where T is the type of the original bag
+ * - (mkBag a (-c)) = (emptybag T) where T is the type the original bag,
+ * and c > 0 is a constant
+ */
+ static Node evaluateMakeBag(TNode n);
+
+ /**
+ * returns the multiplicity in a constant bag
+ * @param n has the form (bag.count x A) where x, A are constants
+ * @return the multiplicity of element x in bag A.
+ */
+ static Node evaluateBagCount(TNode n);
+
+ /**
+ * evaluates union disjoint node such that the returned node is a canonical
+ * bag that has the form
+ * (union_disjoint (mkBag e1 c1) ...
+ * (union_disjoint * (mkBag e_{n-1} c_{n-1}) (mkBag e_n c_n)))) where
+ * c1... cn are positive integers, e1 ... en are constants, and the node
+ * identifier of these constants are such that: e1 < ... < en.
+ * @param n has the form (union_disjoint A B) where A, B are constant bags
+ * @return the union disjoint of A and B
+ */
+ static Node evaluateUnionDisjoint(TNode n);
+ /**
+ * @param n has the form (union_max A B) where A, B are constant bags
+ * @return the union max of A and B
+ */
+ static Node evaluateUnionMax(TNode n);
+ /**
+ * @param n has the form (intersection_min A B) where A, B are constant bags
+ * @return the intersection min of A and B
+ */
+ static Node evaluateIntersectionMin(TNode n);
+ /**
+ * @param n has the form (difference_subtract A B) where A, B are constant
+ * bags
+ * @return the difference subtract of A and B
+ */
+ static Node evaluateDifferenceSubtract(TNode n);
+ /**
+ * @param n has the form (difference_remove A B) where A, B are constant bags
+ * @return the difference remove of A and B
+ */
+ static Node evaluateDifferenceRemove(TNode n);
+ /**
+ * @param n has the form (bag.choose A) where A is a constant bag
+ * @return the first element of A if A is not empty. Otherwise, it returns the
+ * first element returned by the type enumerator for the elements
+ */
+ static Node evaluateChoose(TNode n);
+ /**
+ * @param n has the form (bag.card A) where A is a constant bag
+ * @return the number of elements in bag A
+ */
+ static Node evaluateCard(TNode n);
+ /**
+ * @param n has the form (bag.is_singleton A) where A is a constant bag
+ * @return whether the bag A has cardinality one.
+ */
+ static Node evaluateIsSingleton(TNode n);
+ /**
+ * @param n has the form (bag.from_set A) where A is a constant set
+ * @return a constant bag that contains exactly the elements in A.
+ */
+ static Node evaluateFromSet(TNode n);
+ /**
+ * @param n has the form (bag.to_set A) where A is a constant bag
+ * @return a constant set constructed from the elements in A.
+ */
+ static Node evaluateToSet(TNode n);
};
} // namespace bags
} // namespace theory
// only UNION_DISJOINT has a const rule in kinds.
// Other binary operators do not have const rules in kinds
Assert(n.getKind() == kind::UNION_DISJOINT);
- return NormalForm::checkNormalConstant(n);
+ return NormalForm::isConstant(n);
}
}; /* struct BinaryOperatorTypeRule */
cvc4_add_unit_test_white(logic_info_white theory)
cvc4_add_unit_test_white(sequences_rewriter_white theory)
cvc4_add_unit_test_white(theory_arith_white theory)
+cvc4_add_unit_test_white(theory_bags_normal_form_white theory)
cvc4_add_unit_test_white(theory_bags_rewriter_white theory)
cvc4_add_unit_test_white(theory_bags_type_rules_white theory)
cvc4_add_unit_test_white(theory_bv_rewriter_white theory)
--- /dev/null
+/********************* */
+/*! \file theory_bags_normal_form_white.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ ** Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved. See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief White box testing of bags normal form
+ **/
+
+#include <cxxtest/TestSuite.h>
+
+#include "expr/dtype.h"
+#include "smt/smt_engine.h"
+#include "theory/bags/bags_rewriter.h"
+#include "theory/bags/normal_form.h"
+#include "theory/strings/type_enumerator.h"
+
+using namespace CVC4;
+using namespace CVC4::smt;
+using namespace CVC4::theory;
+using namespace CVC4::kind;
+using namespace CVC4::theory::bags;
+using namespace std;
+
+typedef expr::Attribute<Node, Node> attribute;
+
+class BagsNormalFormWhite : public CxxTest::TestSuite
+{
+ public:
+ void setUp() override
+ {
+ d_em.reset(new ExprManager());
+ d_smt.reset(new SmtEngine(d_em.get()));
+ d_nm.reset(NodeManager::fromExprManager(d_em.get()));
+ d_smt->finishInit();
+ d_rewriter.reset(new BagsRewriter(nullptr));
+ }
+
+ void tearDown() override
+ {
+ d_rewriter.reset();
+ d_smt.reset();
+ d_nm.release();
+ d_em.reset();
+ }
+
+ std::vector<Node> getNStrings(size_t n)
+ {
+ std::vector<Node> elements(n);
+ CVC4::theory::strings::StringEnumerator enumerator(d_nm->stringType());
+
+ for (size_t i = 0; i < n; i++)
+ {
+ ++enumerator;
+ elements[i] = *enumerator;
+ }
+
+ return elements;
+ }
+
+ void testEmptyBagNormalForm()
+ {
+ Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType()));
+ // empty bags are in normal form
+ TS_ASSERT(emptybag.isConst());
+ Node n = NormalForm::evaluate(emptybag);
+ TS_ASSERT(emptybag == n);
+ }
+
+ void testBagEquality() {}
+
+ void testMkBagConstantElement()
+ {
+ vector<Node> elements = getNStrings(1);
+ Node negative = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(-1)));
+ Node zero = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(0)));
+ Node positive = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(1)));
+ Node emptybag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+
+ TS_ASSERT(!negative.isConst());
+ TS_ASSERT(!zero.isConst());
+ TS_ASSERT(emptybag == NormalForm::evaluate(negative));
+ TS_ASSERT(emptybag == NormalForm::evaluate(zero));
+ TS_ASSERT(positive == NormalForm::evaluate(positive));
+ }
+
+ void testBagCount()
+ {
+ // Examples
+ // -------
+ // (bag.count "x" (emptybag String)) = 0
+ // (bag.count "x" (mkBag "y" 5)) = 0
+ // (bag.count "x" (mkBag "x" 4)) = 4
+ // (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
+ // (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0
+
+ Node zero = d_nm->mkConst(Rational(0));
+ Node four = d_nm->mkConst(Rational(4));
+ Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+ Node z = d_nm->mkConst(String("z"));
+ Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+ Node y_5 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(5)));
+ Node z_5 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(5)));
+
+ Node input1 = d_nm->mkNode(BAG_COUNT, x, empty);
+ Node output1 = zero;
+ TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+ Node input2 = d_nm->mkNode(BAG_COUNT, x, y_5);
+ Node output2 = zero;
+ TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+ Node input3 = d_nm->mkNode(BAG_COUNT, x, x_4);
+ Node output3 = four;
+ TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+ Node unionDisjointXY = d_nm->mkNode(UNION_DISJOINT, x_4, y_5);
+ Node input4 = d_nm->mkNode(BAG_COUNT, x, unionDisjointXY);
+ Node output4 = four;
+ TS_ASSERT(output3 == NormalForm::evaluate(input3));
+
+ Node unionDisjointYZ = d_nm->mkNode(UNION_DISJOINT, y_5, z_5);
+ Node input5 = d_nm->mkNode(BAG_COUNT, x, unionDisjointYZ);
+ Node output5 = zero;
+ TS_ASSERT(output4 == NormalForm::evaluate(input4));
+ }
+
+ void testUnionMax()
+ {
+ // Example
+ // -------
+ // input: (union_max A B)
+ // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+ // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // output:
+ // (union_disjoint A B)
+ // where A = (MK_BAG "x" 4)
+ // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
+
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+ Node z = d_nm->mkConst(String("z"));
+ Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+ Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3)));
+ Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7)));
+ Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2)));
+ Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+ Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2);
+ Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1);
+ Node input = d_nm->mkNode(UNION_MAX, A, B);
+
+ // output
+ Node output = d_nm->mkNode(
+ UNION_DISJOINT, x_4, d_nm->mkNode(UNION_DISJOINT, y_1, z_2));
+
+ TS_ASSERT(output.isConst());
+ TS_ASSERT(output == NormalForm::evaluate(input));
+ }
+
+ void testUnionDisjoint1()
+ {
+ vector<Node> elements = getNStrings(3);
+ Node emptybag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node A = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(2)));
+ Node B = d_nm->mkBag(
+ d_nm->stringType(), elements[1], d_nm->mkConst(Rational(3)));
+ Node C = d_nm->mkBag(
+ d_nm->stringType(), elements[2], d_nm->mkConst(Rational(4)));
+
+ Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+ // unionDisjointAB is already in a normal form
+ TS_ASSERT(unionDisjointAB.isConst());
+ TS_ASSERT(unionDisjointAB == NormalForm::evaluate(unionDisjointAB));
+
+ Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+ // unionDisjointAB is is the normal form of unionDisjointBA
+ TS_ASSERT(!unionDisjointBA.isConst());
+ TS_ASSERT(unionDisjointAB == NormalForm::evaluate(unionDisjointBA));
+
+ Node unionDisjointAB_C = d_nm->mkNode(UNION_DISJOINT, unionDisjointAB, C);
+ Node unionDisjointBC = d_nm->mkNode(UNION_DISJOINT, B, C);
+ Node unionDisjointA_BC = d_nm->mkNode(UNION_DISJOINT, A, unionDisjointBC);
+ // unionDisjointA_BC is the normal form of unionDisjointAB_C
+ TS_ASSERT(!unionDisjointAB_C.isConst());
+ TS_ASSERT(unionDisjointA_BC.isConst());
+ TS_ASSERT(unionDisjointA_BC == NormalForm::evaluate(unionDisjointAB_C));
+
+ Node unionDisjointAA = d_nm->mkNode(UNION_DISJOINT, A, A);
+ Node AA = d_nm->mkBag(
+ d_nm->stringType(), elements[0], d_nm->mkConst(Rational(4)));
+ TS_ASSERT(!unionDisjointAA.isConst());
+ TS_ASSERT(AA.isConst());
+ TS_ASSERT(AA == NormalForm::evaluate(unionDisjointAA));
+ }
+
+ void testUnionDisjoint2()
+ {
+ // Example
+ // -------
+ // input: (union_disjoint A B)
+ // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+ // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // output:
+ // (union_disjoint A B)
+ // where A = (MK_BAG "x" 7)
+ // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
+
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+ Node z = d_nm->mkConst(String("z"));
+ Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+ Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3)));
+ Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7)));
+ Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2)));
+ Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+ Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2);
+ Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1);
+ Node input = d_nm->mkNode(UNION_DISJOINT, A, B);
+
+ // output
+ Node output = d_nm->mkNode(
+ UNION_DISJOINT, x_7, d_nm->mkNode(UNION_DISJOINT, y_1, z_2));
+
+ TS_ASSERT(output.isConst());
+ TS_ASSERT(output == NormalForm::evaluate(input));
+ }
+
+ void testIntersectionMin()
+ {
+ // Example
+ // -------
+ // input: (intersection_min A B)
+ // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+ // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // output:
+ // (MK_BAG "x" 3)
+
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+ Node z = d_nm->mkConst(String("z"));
+ Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+ Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3)));
+ Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7)));
+ Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2)));
+ Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+ Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2);
+ Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1);
+ Node input = d_nm->mkNode(INTERSECTION_MIN, A, B);
+
+ // output
+ Node output = x_3;
+
+ TS_ASSERT(output.isConst());
+ TS_ASSERT(output == NormalForm::evaluate(input));
+ }
+
+ void testDifferenceSubtract()
+ {
+ // Example
+ // -------
+ // input: (difference_subtract A B)
+ // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+ // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // output:
+ // (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))
+
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+ Node z = d_nm->mkConst(String("z"));
+ Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1)));
+ Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+ Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3)));
+ Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7)));
+ Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2)));
+ Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+ Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2);
+ Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1);
+ Node input = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, B);
+
+ // output
+ Node output = d_nm->mkNode(UNION_DISJOINT, x_1, z_2);
+
+ TS_ASSERT(output.isConst());
+ TS_ASSERT(output == NormalForm::evaluate(input));
+ }
+
+ void testDifferenceRemove()
+ {
+ // Example
+ // -------
+ // input: (difference_remove A B)
+ // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+ // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+ // output:
+ // (MK_BAG "z" 2)
+
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+ Node z = d_nm->mkConst(String("z"));
+ Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1)));
+ Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+ Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3)));
+ Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7)));
+ Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2)));
+ Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+ Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2);
+ Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1);
+ Node input = d_nm->mkNode(DIFFERENCE_REMOVE, A, B);
+
+ // output
+ Node output = z_2;
+
+ TS_ASSERT(output.isConst());
+ TS_ASSERT(output == NormalForm::evaluate(input));
+ }
+
+ void testChoose()
+ {
+ // Example
+ // -------
+ // input: (choose (emptybag String))
+ // output: "A"; the first element returned by the type enumerator
+ // input: (choose (MK_BAG "x" 4))
+ // output: "x"
+ // input: (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1)))
+ // output: "x"; deterministically return the first element
+ Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+ Node z = d_nm->mkConst(String("z"));
+ Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+ Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+ Node input1 = d_nm->mkNode(BAG_CHOOSE, empty);
+ Node output1 = d_nm->mkConst(String(""));
+
+ TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+ Node input2 = d_nm->mkNode(BAG_CHOOSE, x_4);
+ Node output2 = x;
+ TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+ Node union_disjoint = d_nm->mkNode(UNION_DISJOINT, x_4, y_1);
+ Node input3 = d_nm->mkNode(BAG_CHOOSE, union_disjoint);
+ Node output3 = x;
+ TS_ASSERT(output3 == NormalForm::evaluate(input3));
+ }
+
+ void testBagCard()
+ {
+ // Examples
+ // --------
+ // - (card (emptybag String)) = 0
+ // - (choose (MK_BAG "x" 4)) = 4
+ // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5
+ Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+ Node z = d_nm->mkConst(String("z"));
+ Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+ Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+ Node input1 = d_nm->mkNode(BAG_CARD, empty);
+ Node output1 = d_nm->mkConst(Rational(0));
+
+ TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+ Node input2 = d_nm->mkNode(BAG_CARD, x_4);
+ Node output2 = d_nm->mkConst(Rational(4));
+ TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+ Node union_disjoint = d_nm->mkNode(UNION_DISJOINT, x_4, y_1);
+ Node input3 = d_nm->mkNode(BAG_CARD, union_disjoint);
+ Node output3 = d_nm->mkConst(Rational(5));
+ TS_ASSERT(output3 == NormalForm::evaluate(input3));
+ }
+
+ void testIsSingleton()
+ {
+ // Examples
+ // --------
+ // - (bag.is_singleton (emptybag String)) = false
+ // - (bag.is_singleton (MK_BAG "x" 1)) = true
+ // - (bag.is_singleton (MK_BAG "x" 4)) = false
+ // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) =
+ // false
+ Node falseNode = d_nm->mkConst(false);
+ Node trueNode = d_nm->mkConst(true);
+ Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+ Node z = d_nm->mkConst(String("z"));
+ Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1)));
+ Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+ Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+ Node input1 = d_nm->mkNode(BAG_IS_SINGLETON, empty);
+ Node output1 = falseNode;
+ TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+ Node input2 = d_nm->mkNode(BAG_IS_SINGLETON, x_1);
+ Node output2 = trueNode;
+ TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+ Node input3 = d_nm->mkNode(BAG_IS_SINGLETON, x_4);
+ Node output3 = falseNode;
+ TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+ Node union_disjoint = d_nm->mkNode(UNION_DISJOINT, x_1, y_1);
+ Node input4 = d_nm->mkNode(BAG_IS_SINGLETON, union_disjoint);
+ Node output4 = falseNode;
+ TS_ASSERT(output3 == NormalForm::evaluate(input3));
+ }
+
+ void testFromSet()
+ {
+ // Examples
+ // --------
+ // - (bag.from_set (emptyset String)) = (emptybag String)
+ // - (bag.from_set (singleton "x")) = (mkBag "x" 1)
+ // - (bag.to_set (union (singleton "x") (singleton "y"))) =
+ // (disjoint_union (mkBag "x" 1) (mkBag "y" 1))
+
+ Node emptyset =
+ d_nm->mkConst(EmptySet(d_nm->mkSetType(d_nm->stringType())));
+ Node emptybag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node input1 = d_nm->mkNode(BAG_FROM_SET, emptyset);
+ Node output1 = emptybag;
+ TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+
+ Node xSingleton = d_nm->mkSingleton(d_nm->stringType(), x);
+ Node ySingleton = d_nm->mkSingleton(d_nm->stringType(), y);
+
+ Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1)));
+ Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+ Node input2 = d_nm->mkNode(BAG_FROM_SET, xSingleton);
+ Node output2 = x_1;
+ TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+ // for normal sets, the first node is the largest, not smallest
+ Node normalSet = d_nm->mkNode(UNION, ySingleton, xSingleton);
+ Node input3 = d_nm->mkNode(BAG_FROM_SET, normalSet);
+ Node output3 = d_nm->mkNode(UNION_DISJOINT, x_1, y_1);
+ TS_ASSERT(output3 == NormalForm::evaluate(input3));
+ }
+
+ void testToSet()
+ {
+ // Examples
+ // --------
+ // - (bag.to_set (emptybag String)) = (emptyset String)
+ // - (bag.to_set (mkBag "x" 4)) = (singleton "x")
+ // - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
+ // (union (singleton "x") (singleton "y")))
+
+ Node emptyset =
+ d_nm->mkConst(EmptySet(d_nm->mkSetType(d_nm->stringType())));
+ Node emptybag =
+ d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+ Node input1 = d_nm->mkNode(BAG_TO_SET, emptybag);
+ Node output1 = emptyset;
+ TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+ Node x = d_nm->mkConst(String("x"));
+ Node y = d_nm->mkConst(String("y"));
+
+ Node xSingleton = d_nm->mkSingleton(d_nm->stringType(), x);
+ Node ySingleton = d_nm->mkSingleton(d_nm->stringType(), y);
+
+ Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+ Node y_5 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(5)));
+
+ Node input2 = d_nm->mkNode(BAG_TO_SET, x_4);
+ Node output2 = xSingleton;
+ TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+ // for normal sets, the first node is the largest, not smallest
+ Node normalBag = d_nm->mkNode(UNION_DISJOINT, x_4, y_5);
+ Node input3 = d_nm->mkNode(BAG_TO_SET, normalBag);
+ Node output3 = d_nm->mkNode(UNION, ySingleton, xSingleton);
+ TS_ASSERT(output3 == NormalForm::evaluate(input3));
+ }
+
+ private:
+ std::unique_ptr<ExprManager> d_em;
+ std::unique_ptr<SmtEngine> d_smt;
+ std::unique_ptr<NodeManager> d_nm;
+ std::unique_ptr<BagsRewriter> d_rewriter;
+}; /* class BagsTypeRuleBlack */
Node bag = d_nm->mkBag(d_nm->stringType(), elements[0], d_nm->mkConst(Rational(10)));
TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_TO_SET, bag));
TS_ASSERT(d_nm->mkNode(BAG_TO_SET, bag).getType().isSet());
+ std::cout<<"Rational(4, 4).isIntegral() " << d_nm->mkConst(Rational(4,4)).getType()<< std::endl;
+ std::cout<<"Rational(8, 4).isIntegral() " << d_nm->mkConst(Rational(8,4)).getType()<< std::endl;
+ std::cout<<"Rational(1, 4).isIntegral() " << d_nm->mkConst(Rational(1,4)).getType()<< std::endl;
+
+ std::cout<<"Rational(4, 4).isIntegral() " << d_nm->mkNode(TO_REAL, d_nm->mkConst(Rational(4,4))).getType()<< std::endl;
+ std::cout<<"Rational(8, 4).isIntegral() " << d_nm->mkNode(TO_REAL, d_nm->mkConst(Rational(8,4))).getType()<< std::endl;
+ std::cout<<"Rational(1, 4).isIntegral() " << d_nm->mkNode(TO_REAL, d_nm->mkConst(Rational(1,4))).getType()<< std::endl;
}
private: