From 31983bd41f8c6ec736e374946de355fd1a9bc6f1 Mon Sep 17 00:00:00 2001 From: mudathirmahgoub Date: Wed, 21 Oct 2020 17:33:57 -0500 Subject: [PATCH] Implement bags evaluator (#5322) This PR implements NormalForm::evaluate for bags --- src/theory/bags/bags_rewriter.cpp | 2 +- src/theory/bags/normal_form.cpp | 604 +++++++++++++++++- src/theory/bags/normal_form.h | 147 ++++- src/theory/bags/theory_bags_type_rules.h | 2 +- test/unit/theory/CMakeLists.txt | 1 + .../theory/theory_bags_normal_form_white.h | 512 +++++++++++++++ .../theory/theory_bags_type_rules_white.h | 7 + 7 files changed, 1258 insertions(+), 17 deletions(-) create mode 100644 test/unit/theory/theory_bags_normal_form_white.h diff --git a/src/theory/bags/bags_rewriter.cpp b/src/theory/bags/bags_rewriter.cpp index c413a5e7e..26c54d4ec 100644 --- a/src/theory/bags/bags_rewriter.cpp +++ b/src/theory/bags/bags_rewriter.cpp @@ -51,7 +51,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n) // no need to rewrite n if it is already in a normal form response = BagsRewriteResponse(n, Rewrite::NONE); } - else if (NormalForm::AreChildrenConstants(n)) + else if (NormalForm::areChildrenConstants(n)) { Node value = NormalForm::evaluate(n); response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION); diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp index facad3c92..f2dea62d3 100644 --- a/src/theory/bags/normal_form.cpp +++ b/src/theory/bags/normal_form.cpp @@ -12,26 +12,620 @@ #include "normal_form.h" +#include "theory/sets/normal_form.h" +#include "theory/type_enumerator.h" + +using namespace CVC4::kind; + namespace CVC4 { namespace theory { namespace bags { -bool NormalForm::checkNormalConstant(TNode n) +bool NormalForm::isConstant(TNode n) { - // TODO(projects#223): complete this function + if (n.getKind() == EMPTYBAG) + { + // empty bags are already normalized + return true; + } + if (n.getKind() == MK_BAG) + { + // see the implementation in MkBagTypeRule::computeIsConst + return n.isConst(); + } + if (n.getKind() == UNION_DISJOINT) + { + if (!(n[0].getKind() == kind::MK_BAG && n[0].isConst())) + { + // the first child is not a constant + return false; + } + // store the previous element to check the ordering of elements + Node previousElement = n[0][0]; + Node current = n[1]; + while (current.getKind() == UNION_DISJOINT) + { + if (!(current[0].getKind() == kind::MK_BAG && current[0].isConst())) + { + // the current element is not a constant + return false; + } + if (previousElement >= current[0][0]) + { + // the ordering is violated + return false; + } + previousElement = current[0][0]; + current = current[1]; + } + // check last element + if (!(current.getKind() == kind::MK_BAG && current.isConst())) + { + // the last element is not a constant + return false; + } + if (previousElement >= current[0]) + { + // the ordering is violated + return false; + } + return true; + } + + // only nodes with kinds EMPTY_BAG, MK_BAG, and UNION_DISJOINT can be + // constants return false; } -bool NormalForm::AreChildrenConstants(TNode n) +bool NormalForm::areChildrenConstants(TNode n) { return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); }); } Node NormalForm::evaluate(TNode n) { - // TODO(projects#223): complete this function - return CVC4::Node(); + Assert(areChildrenConstants(n)); + if (n.isConst()) + { + // a constant node is already in a normal form + return n; + } + switch (n.getKind()) + { + case MK_BAG: return evaluateMakeBag(n); + case BAG_COUNT: return evaluateBagCount(n); + case UNION_DISJOINT: return evaluateUnionDisjoint(n); + case UNION_MAX: return evaluateUnionMax(n); + case INTERSECTION_MIN: return evaluateIntersectionMin(n); + case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n); + case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n); + case BAG_CHOOSE: return evaluateChoose(n); + case BAG_CARD: return evaluateCard(n); + case BAG_IS_SINGLETON: return evaluateIsSingleton(n); + case BAG_FROM_SET: return evaluateFromSet(n); + case BAG_TO_SET: return evaluateToSet(n); + default: break; + } + Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n + << std::endl; +} + +template +Node NormalForm::evaluateBinaryOperation(const TNode& n, + T1&& equal, + T2&& less, + T3&& greaterOrEqual, + T4&& remainderOfA, + T5&& remainderOfB) +{ + std::map elementsA = getBagElements(n[0]); + std::map elementsB = getBagElements(n[1]); + std::map elements; + + std::map::const_iterator itA = elementsA.begin(); + std::map::const_iterator itB = elementsB.begin(); + + Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation " + << n.getKind() << "] " << std::endl + << "elements A: " << elementsA << std::endl + << "elements B: " << elementsB << std::endl; + + while (itA != elementsA.end() && itB != elementsB.end()) + { + if (itA->first == itB->first) + { + equal(elements, itA, itB); + itA++; + itB++; + } + else if (itA->first < itB->first) + { + less(elements, itA, itB); + itA++; + } + else + { + greaterOrEqual(elements, itA, itB); + itB++; + } + } + + // handle the remaining elements from A + remainderOfA(elements, elementsA, itA); + // handle the remaining elements from B + remainderOfA(elements, elementsB, itB); + + Trace("bags-evaluate") << "elements: " << elements << std::endl; + Node bag = constructBagFromElements(n.getType(), elements); + Trace("bags-evaluate") << "bag: " << bag << std::endl; + return bag; +} + +std::map NormalForm::getBagElements(TNode n) +{ + Assert(n.isConst()) << "node " << n << " is not in a normal form" + << std::endl; + std::map elements; + if (n.getKind() == EMPTYBAG) + { + return elements; + } + while (n.getKind() == kind::UNION_DISJOINT) + { + Assert(n[0].getKind() == kind::MK_BAG); + Node element = n[0][0]; + Rational count = n[0][1].getConst(); + elements[element] = count; + n = n[1]; + } + Assert(n.getKind() == kind::MK_BAG); + Node lastElement = n[0]; + Rational lastCount = n[1].getConst(); + elements[lastElement] = lastCount; + return elements; +} + +Node NormalForm::constructBagFromElements( + TypeNode t, const std::map& elements) +{ + Assert(t.isBag()); + NodeManager* nm = NodeManager::currentNM(); + if (elements.empty()) + { + return nm->mkConst(EmptyBag(t)); + } + TypeNode elementType = t.getBagElementType(); + std::map::const_reverse_iterator it = elements.rbegin(); + Node bag = + nm->mkBag(elementType, it->first, nm->mkConst(it->second)); + while (++it != elements.rend()) + { + Node n = + nm->mkBag(elementType, it->first, nm->mkConst(it->second)); + bag = nm->mkNode(UNION_DISJOINT, n, bag); + } + return bag; +} + +Node NormalForm::evaluateMakeBag(TNode n) +{ + // the case where n is const should be handled earlier. + // here we handle the case where the multiplicity is zero or negative + Assert(n.getKind() == MK_BAG && !n.isConst() + && n[1].getConst().sgn() < 1); + Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType())); + return emptybag; +} + +Node NormalForm::evaluateBagCount(TNode n) +{ + Assert(n.getKind() == BAG_COUNT); + // Examples + // -------- + // - (bag.count "x" (emptybag String)) = 0 + // - (bag.count "x" (mkBag "y" 5)) = 0 + // - (bag.count "x" (mkBag "x" 4)) = 4 + // - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4 + // - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0 + + std::map elements = getBagElements(n[1]); + std::map::iterator it = elements.find(n[0]); + + NodeManager* nm = NodeManager::currentNM(); + if (it != elements.end()) + { + Node count = nm->mkConst(it->second); + return count; + } + return nm->mkConst(Rational(0)); +} + +Node NormalForm::evaluateUnionDisjoint(TNode n) +{ + Assert(n.getKind() == UNION_DISJOINT); + // Example + // ------- + // input: (union_disjoint A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (union_disjoint A B) + // where A = (MK_BAG "x" 7) + // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2))) + + auto equal = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // compute the sum of the multiplicities + elements[itA->first] = itA->second + itB->second; + }; + + auto less = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // add the element to the result + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // add the element to the result + elements[itB->first] = itB->second; + }; + + auto remainderOfA = [](std::map& elements, + std::map& elementsA, + std::map::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map& elements, + std::map& elementsB, + std::map::const_iterator& itB) { + // append the remainder of B + while (itB != elementsB.end()) + { + elements[itB->first] = itB->second; + itB++; + } + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node NormalForm::evaluateUnionMax(TNode n) +{ + Assert(n.getKind() == UNION_MAX); + // Example + // ------- + // input: (union_max A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (union_disjoint A B) + // where A = (MK_BAG "x" 4) + // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2))) + + auto equal = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // compute the maximum multiplicity + elements[itA->first] = std::max(itA->second, itB->second); + }; + + auto less = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // add to the result + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // add to the result + elements[itB->first] = itB->second; + }; + + auto remainderOfA = [](std::map& elements, + std::map& elementsA, + std::map::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map& elements, + std::map& elementsB, + std::map::const_iterator& itB) { + // append the remainder of B + while (itB != elementsB.end()) + { + elements[itB->first] = itB->second; + itB++; + } + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); } + +Node NormalForm::evaluateIntersectionMin(TNode n) +{ + Assert(n.getKind() == INTERSECTION_MIN); + // Example + // ------- + // input: (intersectionMin A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (MK_BAG "x" 3) + + auto equal = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // compute the minimum multiplicity + elements[itA->first] = std::min(itA->second, itB->second); + }; + + auto less = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // do nothing + }; + + auto greaterOrEqual = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // do nothing + }; + + auto remainderOfA = [](std::map& elements, + std::map& elementsA, + std::map::const_iterator& itA) { + // do nothing + }; + + auto remainderOfB = [](std::map& elements, + std::map& elementsB, + std::map::const_iterator& itB) { + // do nothing + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node NormalForm::evaluateDifferenceSubtract(TNode n) +{ + Assert(n.getKind() == DIFFERENCE_SUBTRACT); + // Example + // ------- + // input: (difference_subtract A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2)) + + auto equal = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // subtract the multiplicities + elements[itA->first] = itA->second - itB->second; + }; + + auto less = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // itA->first is not in B, so we add it to the difference subtract + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // itB->first is not in A, so we just skip it + }; + + auto remainderOfA = [](std::map& elements, + std::map& elementsA, + std::map::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map& elements, + std::map& elementsB, + std::map::const_iterator& itB) { + // do nothing + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node NormalForm::evaluateDifferenceRemove(TNode n) +{ + Assert(n.getKind() == DIFFERENCE_REMOVE); + // Example + // ------- + // input: (difference_subtract A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (MK_BAG "z" 2) + + auto equal = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // skip the shared element by doing nothing + }; + + auto less = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // itA->first is not in B, so we add it to the difference remove + elements[itA->first] = itA->second; + }; + + auto greaterOrEqual = [](std::map& elements, + std::map::const_iterator& itA, + std::map::const_iterator& itB) { + // itB->first is not in A, so we just skip it + }; + + auto remainderOfA = [](std::map& elements, + std::map& elementsA, + std::map::const_iterator& itA) { + // append the remainder of A + while (itA != elementsA.end()) + { + elements[itA->first] = itA->second; + itA++; + } + }; + + auto remainderOfB = [](std::map& elements, + std::map& elementsB, + std::map::const_iterator& itB) { + // do nothing + }; + + return evaluateBinaryOperation( + n, equal, less, greaterOrEqual, remainderOfA, remainderOfB); +} + +Node NormalForm::evaluateChoose(TNode n) +{ + Assert(n.getKind() == BAG_CHOOSE); + // Examples + // -------- + // - (choose (emptyBag String)) = "" // the empty string which is the first + // element returned by the type enumerator + // - (choose (MK_BAG "x" 4)) = "x" + // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x" + // deterministically return the first element + + if (n[0].getKind() == EMPTYBAG) + { + TypeNode elementType = n[0].getType().getBagElementType(); + TypeEnumerator typeEnumerator(elementType); + // get the first value from the typeEnumerator + Node element = *typeEnumerator; + return element; + } + + if (n[0].getKind() == MK_BAG) + { + return n[0][0]; + } + Assert(n[0].getKind() == UNION_DISJOINT); + // return the first element + // e.g. (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) + return n[0][0][0]; +} + +Node NormalForm::evaluateCard(TNode n) +{ + Assert(n.getKind() == BAG_CARD); + // Examples + // -------- + // - (card (emptyBag String)) = 0 + // - (choose (MK_BAG "x" 4)) = 4 + // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5 + + std::map elements = getBagElements(n[0]); + Rational sum(0); + for (std::pair element : elements) + { + sum += element.second; + } + + NodeManager* nm = NodeManager::currentNM(); + Node sumNode = nm->mkConst(sum); + return sumNode; +} + +Node NormalForm::evaluateIsSingleton(TNode n) +{ + Assert(n.getKind() == BAG_IS_SINGLETON); + // Examples + // -------- + // - (bag.is_singleton (emptyBag String)) = false + // - (bag.is_singleton (MK_BAG "x" 1)) = true + // - (bag.is_singleton (MK_BAG "x" 4)) = false + // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false + + if (n[0].getKind() == MK_BAG && n[0][1].getConst().isOne()) + { + return NodeManager::currentNM()->mkConst(true); + } + return NodeManager::currentNM()->mkConst(false); +} + +Node NormalForm::evaluateFromSet(TNode n) +{ + Assert(n.getKind() == BAG_FROM_SET); + + // Examples + // -------- + // - (bag.from_set (emptyset String)) = (emptybag String) + // - (bag.from_set (singleton "x")) = (mkBag "x" 1) + // - (bag.from_set (union (singleton "x") (singleton "y"))) = + // (disjoint_union (mkBag "x" 1) (mkBag "y" 1)) + + NodeManager* nm = NodeManager::currentNM(); + std::set setElements = + sets::NormalForm::getElementsFromNormalConstant(n[0]); + Rational one = Rational(1); + std::map bagElements; + for (const Node& element : setElements) + { + bagElements[element] = one; + } + TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType()); + Node bag = constructBagFromElements(bagType, bagElements); + return bag; +} + +Node NormalForm::evaluateToSet(TNode n) +{ + Assert(n.getKind() == BAG_TO_SET); + + // Examples + // -------- + // - (bag.to_set (emptybag String)) = (emptyset String) + // - (bag.to_set (mkBag "x" 4)) = (singleton "x") + // - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) = + // (union (singleton "x") (singleton "y"))) + + NodeManager* nm = NodeManager::currentNM(); + std::map bagElements = getBagElements(n[0]); + std::set setElements; + std::map::const_reverse_iterator it; + for (it = bagElements.rbegin(); it != bagElements.rend(); it++) + { + setElements.insert(it->first); + } + TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType()); + Node set = sets::NormalForm::elementsToSet(setElements, setType); + return set; +} + } // namespace bags } // namespace theory } // namespace CVC4 \ No newline at end of file diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/normal_form.h index 8c719fe81..ef0edefff 100644 --- a/src/theory/bags/normal_form.h +++ b/src/theory/bags/normal_form.h @@ -29,22 +29,149 @@ class NormalForm /** * Returns true if n is considered a to be a (canonical) constant bag value. * A canonical bag value is one whose AST is: - * (disjoint-union (mk-bag e1 n1) ... - * (disjoint-union (mk-bag e_{n-1} n_{n-1}) (mk-bag e_n n_n)))) - * where c1 ... cn are constants and the node identifier of these constants - * are such that: - * c1 > ... > cn. - * Also handles the corner cases of empty bag and singleton bag. + * (union_disjoint (mkBag e1 c1) ... + * (union_disjoint (mkBag e_{n-1} c_{n-1}) (mkBag e_n c_n)))) + * where c1 ... cn are positive integers, e1 ... en are constants, and the + * node identifier of these constants are such that: e1 < ... < en. + * Also handles the corner cases of empty bag and bag constructed by mkBag */ - static bool checkNormalConstant(TNode n); + static bool isConstant(TNode n); /** - * check whether all children of the given node are in normal form + * check whether all children of the given node are constants */ - static bool AreChildrenConstants(TNode n); + static bool areChildrenConstants(TNode n); /** - * evaluate the node n to a constant value + * evaluate the node n to a constant value. + * As a precondition, children of n should be constants. */ static Node evaluate(TNode n); + + /** + * get the elements along with their multiplicities in a given bag + * @param n a constant node whose type is a bag + * @return a map whose keys are constant elements and values are + * multiplicities + */ + static std::map getBagElements(TNode n); + + /** + * construct a constant bag from constant elements + * @param t the type of the returned bag + * @param elements a map whose keys are constant elements and values are + * multiplicities + * @return a constant bag that contains + */ + static Node constructBagFromElements( + TypeNode t, const std::map& elements); + + private: + /** + * a high order helper function that return a constant bag that is the result + * of (op A B) where op is a binary operator and A, B are constant bags. + * The result is computed from the elements of A (elementsA with iterator itA) + * and elements of B (elementsB with iterator itB). + * The arguments below specify how these iterators are used to generate the + * elements of the result (elements). + * @param n a node whose kind is a binary operator (union_disjoint, union_max, + * intersection_min, difference_subtract, difference_remove) and whose + * children are constant bags. + * @param equal a lambda expression that receives (elements, itA, itB) and + * specify the action that needs to be taken when the elements of itA, itB are + * equal. + * @param less a lambda expression that receives (elements, itA, itB) and + * specify the action that needs to be taken when the element itA is less than + * the element of itB. + * @param greaterOrEqual less a lambda expression that receives (elements, + * itA, itB) and specify the action that needs to be taken when the element + * itA is greater than or equal than the element of itB. + * @param remainderOfA a lambda expression that receives (elements, elementsA, + * itA) and specify the action that needs to be taken to the remaining + * elements of A when all elements of B are visited. + * @param remainderOfB a lambda expression that receives (elements, elementsB, + * itB) and specify the action that needs to be taken to the remaining + * elements of B when all elements of A are visited. + * @return a constant bag that the result of (op n[0] n[1]) + */ + template + static Node evaluateBinaryOperation(const TNode& n, + T1&& equal, + T2&& less, + T3&& greaterOrEqual, + T4&& remainderOfA, + T5&& remainderOfB); + /** + * evaluate n as follows: + * - (mkBag a 0) = (emptybag T) where T is the type of the original bag + * - (mkBag a (-c)) = (emptybag T) where T is the type the original bag, + * and c > 0 is a constant + */ + static Node evaluateMakeBag(TNode n); + + /** + * returns the multiplicity in a constant bag + * @param n has the form (bag.count x A) where x, A are constants + * @return the multiplicity of element x in bag A. + */ + static Node evaluateBagCount(TNode n); + + /** + * evaluates union disjoint node such that the returned node is a canonical + * bag that has the form + * (union_disjoint (mkBag e1 c1) ... + * (union_disjoint * (mkBag e_{n-1} c_{n-1}) (mkBag e_n c_n)))) where + * c1... cn are positive integers, e1 ... en are constants, and the node + * identifier of these constants are such that: e1 < ... < en. + * @param n has the form (union_disjoint A B) where A, B are constant bags + * @return the union disjoint of A and B + */ + static Node evaluateUnionDisjoint(TNode n); + /** + * @param n has the form (union_max A B) where A, B are constant bags + * @return the union max of A and B + */ + static Node evaluateUnionMax(TNode n); + /** + * @param n has the form (intersection_min A B) where A, B are constant bags + * @return the intersection min of A and B + */ + static Node evaluateIntersectionMin(TNode n); + /** + * @param n has the form (difference_subtract A B) where A, B are constant + * bags + * @return the difference subtract of A and B + */ + static Node evaluateDifferenceSubtract(TNode n); + /** + * @param n has the form (difference_remove A B) where A, B are constant bags + * @return the difference remove of A and B + */ + static Node evaluateDifferenceRemove(TNode n); + /** + * @param n has the form (bag.choose A) where A is a constant bag + * @return the first element of A if A is not empty. Otherwise, it returns the + * first element returned by the type enumerator for the elements + */ + static Node evaluateChoose(TNode n); + /** + * @param n has the form (bag.card A) where A is a constant bag + * @return the number of elements in bag A + */ + static Node evaluateCard(TNode n); + /** + * @param n has the form (bag.is_singleton A) where A is a constant bag + * @return whether the bag A has cardinality one. + */ + static Node evaluateIsSingleton(TNode n); + /** + * @param n has the form (bag.from_set A) where A is a constant set + * @return a constant bag that contains exactly the elements in A. + */ + static Node evaluateFromSet(TNode n); + /** + * @param n has the form (bag.to_set A) where A is a constant bag + * @return a constant set constructed from the elements in A. + */ + static Node evaluateToSet(TNode n); }; } // namespace bags } // namespace theory diff --git a/src/theory/bags/theory_bags_type_rules.h b/src/theory/bags/theory_bags_type_rules.h index 75f57ec88..7767938ed 100644 --- a/src/theory/bags/theory_bags_type_rules.h +++ b/src/theory/bags/theory_bags_type_rules.h @@ -57,7 +57,7 @@ struct BinaryOperatorTypeRule // only UNION_DISJOINT has a const rule in kinds. // Other binary operators do not have const rules in kinds Assert(n.getKind() == kind::UNION_DISJOINT); - return NormalForm::checkNormalConstant(n); + return NormalForm::isConstant(n); } }; /* struct BinaryOperatorTypeRule */ diff --git a/test/unit/theory/CMakeLists.txt b/test/unit/theory/CMakeLists.txt index 481c80f26..8cfd43989 100644 --- a/test/unit/theory/CMakeLists.txt +++ b/test/unit/theory/CMakeLists.txt @@ -14,6 +14,7 @@ cvc4_add_unit_test_white(evaluator_white theory) cvc4_add_unit_test_white(logic_info_white theory) cvc4_add_unit_test_white(sequences_rewriter_white theory) cvc4_add_unit_test_white(theory_arith_white theory) +cvc4_add_unit_test_white(theory_bags_normal_form_white theory) cvc4_add_unit_test_white(theory_bags_rewriter_white theory) cvc4_add_unit_test_white(theory_bags_type_rules_white theory) cvc4_add_unit_test_white(theory_bv_rewriter_white theory) diff --git a/test/unit/theory/theory_bags_normal_form_white.h b/test/unit/theory/theory_bags_normal_form_white.h new file mode 100644 index 000000000..6f7d5bd8d --- /dev/null +++ b/test/unit/theory/theory_bags_normal_form_white.h @@ -0,0 +1,512 @@ +/********************* */ +/*! \file theory_bags_normal_form_white.h + ** \verbatim + ** Top contributors (to current version): + ** Mudathir Mohamed + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief White box testing of bags normal form + **/ + +#include + +#include "expr/dtype.h" +#include "smt/smt_engine.h" +#include "theory/bags/bags_rewriter.h" +#include "theory/bags/normal_form.h" +#include "theory/strings/type_enumerator.h" + +using namespace CVC4; +using namespace CVC4::smt; +using namespace CVC4::theory; +using namespace CVC4::kind; +using namespace CVC4::theory::bags; +using namespace std; + +typedef expr::Attribute attribute; + +class BagsNormalFormWhite : public CxxTest::TestSuite +{ + public: + void setUp() override + { + d_em.reset(new ExprManager()); + d_smt.reset(new SmtEngine(d_em.get())); + d_nm.reset(NodeManager::fromExprManager(d_em.get())); + d_smt->finishInit(); + d_rewriter.reset(new BagsRewriter(nullptr)); + } + + void tearDown() override + { + d_rewriter.reset(); + d_smt.reset(); + d_nm.release(); + d_em.reset(); + } + + std::vector getNStrings(size_t n) + { + std::vector elements(n); + CVC4::theory::strings::StringEnumerator enumerator(d_nm->stringType()); + + for (size_t i = 0; i < n; i++) + { + ++enumerator; + elements[i] = *enumerator; + } + + return elements; + } + + void testEmptyBagNormalForm() + { + Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType())); + // empty bags are in normal form + TS_ASSERT(emptybag.isConst()); + Node n = NormalForm::evaluate(emptybag); + TS_ASSERT(emptybag == n); + } + + void testBagEquality() {} + + void testMkBagConstantElement() + { + vector elements = getNStrings(1); + Node negative = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(-1))); + Node zero = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(0))); + Node positive = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(1))); + Node emptybag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + + TS_ASSERT(!negative.isConst()); + TS_ASSERT(!zero.isConst()); + TS_ASSERT(emptybag == NormalForm::evaluate(negative)); + TS_ASSERT(emptybag == NormalForm::evaluate(zero)); + TS_ASSERT(positive == NormalForm::evaluate(positive)); + } + + void testBagCount() + { + // Examples + // ------- + // (bag.count "x" (emptybag String)) = 0 + // (bag.count "x" (mkBag "y" 5)) = 0 + // (bag.count "x" (mkBag "x" 4)) = 4 + // (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4 + // (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0 + + Node zero = d_nm->mkConst(Rational(0)); + Node four = d_nm->mkConst(Rational(4)); + Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + Node z = d_nm->mkConst(String("z")); + Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4))); + Node y_5 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(5))); + Node z_5 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(5))); + + Node input1 = d_nm->mkNode(BAG_COUNT, x, empty); + Node output1 = zero; + TS_ASSERT(output1 == NormalForm::evaluate(input1)); + + Node input2 = d_nm->mkNode(BAG_COUNT, x, y_5); + Node output2 = zero; + TS_ASSERT(output2 == NormalForm::evaluate(input2)); + + Node input3 = d_nm->mkNode(BAG_COUNT, x, x_4); + Node output3 = four; + TS_ASSERT(output2 == NormalForm::evaluate(input2)); + + Node unionDisjointXY = d_nm->mkNode(UNION_DISJOINT, x_4, y_5); + Node input4 = d_nm->mkNode(BAG_COUNT, x, unionDisjointXY); + Node output4 = four; + TS_ASSERT(output3 == NormalForm::evaluate(input3)); + + Node unionDisjointYZ = d_nm->mkNode(UNION_DISJOINT, y_5, z_5); + Node input5 = d_nm->mkNode(BAG_COUNT, x, unionDisjointYZ); + Node output5 = zero; + TS_ASSERT(output4 == NormalForm::evaluate(input4)); + } + + void testUnionMax() + { + // Example + // ------- + // input: (union_max A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (union_disjoint A B) + // where A = (MK_BAG "x" 4) + // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2))) + + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + Node z = d_nm->mkConst(String("z")); + Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4))); + Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3))); + Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7))); + Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2))); + Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1))); + + Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2); + Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1); + Node input = d_nm->mkNode(UNION_MAX, A, B); + + // output + Node output = d_nm->mkNode( + UNION_DISJOINT, x_4, d_nm->mkNode(UNION_DISJOINT, y_1, z_2)); + + TS_ASSERT(output.isConst()); + TS_ASSERT(output == NormalForm::evaluate(input)); + } + + void testUnionDisjoint1() + { + vector elements = getNStrings(3); + Node emptybag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node A = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(2))); + Node B = d_nm->mkBag( + d_nm->stringType(), elements[1], d_nm->mkConst(Rational(3))); + Node C = d_nm->mkBag( + d_nm->stringType(), elements[2], d_nm->mkConst(Rational(4))); + + Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B); + // unionDisjointAB is already in a normal form + TS_ASSERT(unionDisjointAB.isConst()); + TS_ASSERT(unionDisjointAB == NormalForm::evaluate(unionDisjointAB)); + + Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A); + // unionDisjointAB is is the normal form of unionDisjointBA + TS_ASSERT(!unionDisjointBA.isConst()); + TS_ASSERT(unionDisjointAB == NormalForm::evaluate(unionDisjointBA)); + + Node unionDisjointAB_C = d_nm->mkNode(UNION_DISJOINT, unionDisjointAB, C); + Node unionDisjointBC = d_nm->mkNode(UNION_DISJOINT, B, C); + Node unionDisjointA_BC = d_nm->mkNode(UNION_DISJOINT, A, unionDisjointBC); + // unionDisjointA_BC is the normal form of unionDisjointAB_C + TS_ASSERT(!unionDisjointAB_C.isConst()); + TS_ASSERT(unionDisjointA_BC.isConst()); + TS_ASSERT(unionDisjointA_BC == NormalForm::evaluate(unionDisjointAB_C)); + + Node unionDisjointAA = d_nm->mkNode(UNION_DISJOINT, A, A); + Node AA = d_nm->mkBag( + d_nm->stringType(), elements[0], d_nm->mkConst(Rational(4))); + TS_ASSERT(!unionDisjointAA.isConst()); + TS_ASSERT(AA.isConst()); + TS_ASSERT(AA == NormalForm::evaluate(unionDisjointAA)); + } + + void testUnionDisjoint2() + { + // Example + // ------- + // input: (union_disjoint A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (union_disjoint A B) + // where A = (MK_BAG "x" 7) + // B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2))) + + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + Node z = d_nm->mkConst(String("z")); + Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4))); + Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3))); + Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7))); + Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2))); + Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1))); + + Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2); + Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1); + Node input = d_nm->mkNode(UNION_DISJOINT, A, B); + + // output + Node output = d_nm->mkNode( + UNION_DISJOINT, x_7, d_nm->mkNode(UNION_DISJOINT, y_1, z_2)); + + TS_ASSERT(output.isConst()); + TS_ASSERT(output == NormalForm::evaluate(input)); + } + + void testIntersectionMin() + { + // Example + // ------- + // input: (intersection_min A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (MK_BAG "x" 3) + + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + Node z = d_nm->mkConst(String("z")); + Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4))); + Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3))); + Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7))); + Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2))); + Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1))); + + Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2); + Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1); + Node input = d_nm->mkNode(INTERSECTION_MIN, A, B); + + // output + Node output = x_3; + + TS_ASSERT(output.isConst()); + TS_ASSERT(output == NormalForm::evaluate(input)); + } + + void testDifferenceSubtract() + { + // Example + // ------- + // input: (difference_subtract A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2)) + + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + Node z = d_nm->mkConst(String("z")); + Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1))); + Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4))); + Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3))); + Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7))); + Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2))); + Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1))); + + Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2); + Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1); + Node input = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, B); + + // output + Node output = d_nm->mkNode(UNION_DISJOINT, x_1, z_2); + + TS_ASSERT(output.isConst()); + TS_ASSERT(output == NormalForm::evaluate(input)); + } + + void testDifferenceRemove() + { + // Example + // ------- + // input: (difference_remove A B) + // where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2))) + // B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1))) + // output: + // (MK_BAG "z" 2) + + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + Node z = d_nm->mkConst(String("z")); + Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1))); + Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4))); + Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3))); + Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7))); + Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2))); + Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1))); + + Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2); + Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1); + Node input = d_nm->mkNode(DIFFERENCE_REMOVE, A, B); + + // output + Node output = z_2; + + TS_ASSERT(output.isConst()); + TS_ASSERT(output == NormalForm::evaluate(input)); + } + + void testChoose() + { + // Example + // ------- + // input: (choose (emptybag String)) + // output: "A"; the first element returned by the type enumerator + // input: (choose (MK_BAG "x" 4)) + // output: "x" + // input: (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) + // output: "x"; deterministically return the first element + Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + Node z = d_nm->mkConst(String("z")); + Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4))); + Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1))); + + Node input1 = d_nm->mkNode(BAG_CHOOSE, empty); + Node output1 = d_nm->mkConst(String("")); + + TS_ASSERT(output1 == NormalForm::evaluate(input1)); + + Node input2 = d_nm->mkNode(BAG_CHOOSE, x_4); + Node output2 = x; + TS_ASSERT(output2 == NormalForm::evaluate(input2)); + + Node union_disjoint = d_nm->mkNode(UNION_DISJOINT, x_4, y_1); + Node input3 = d_nm->mkNode(BAG_CHOOSE, union_disjoint); + Node output3 = x; + TS_ASSERT(output3 == NormalForm::evaluate(input3)); + } + + void testBagCard() + { + // Examples + // -------- + // - (card (emptybag String)) = 0 + // - (choose (MK_BAG "x" 4)) = 4 + // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5 + Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + Node z = d_nm->mkConst(String("z")); + Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4))); + Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1))); + + Node input1 = d_nm->mkNode(BAG_CARD, empty); + Node output1 = d_nm->mkConst(Rational(0)); + + TS_ASSERT(output1 == NormalForm::evaluate(input1)); + + Node input2 = d_nm->mkNode(BAG_CARD, x_4); + Node output2 = d_nm->mkConst(Rational(4)); + TS_ASSERT(output2 == NormalForm::evaluate(input2)); + + Node union_disjoint = d_nm->mkNode(UNION_DISJOINT, x_4, y_1); + Node input3 = d_nm->mkNode(BAG_CARD, union_disjoint); + Node output3 = d_nm->mkConst(Rational(5)); + TS_ASSERT(output3 == NormalForm::evaluate(input3)); + } + + void testIsSingleton() + { + // Examples + // -------- + // - (bag.is_singleton (emptybag String)) = false + // - (bag.is_singleton (MK_BAG "x" 1)) = true + // - (bag.is_singleton (MK_BAG "x" 4)) = false + // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = + // false + Node falseNode = d_nm->mkConst(false); + Node trueNode = d_nm->mkConst(true); + Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + Node z = d_nm->mkConst(String("z")); + Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1))); + Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4))); + Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1))); + + Node input1 = d_nm->mkNode(BAG_IS_SINGLETON, empty); + Node output1 = falseNode; + TS_ASSERT(output1 == NormalForm::evaluate(input1)); + + Node input2 = d_nm->mkNode(BAG_IS_SINGLETON, x_1); + Node output2 = trueNode; + TS_ASSERT(output2 == NormalForm::evaluate(input2)); + + Node input3 = d_nm->mkNode(BAG_IS_SINGLETON, x_4); + Node output3 = falseNode; + TS_ASSERT(output2 == NormalForm::evaluate(input2)); + + Node union_disjoint = d_nm->mkNode(UNION_DISJOINT, x_1, y_1); + Node input4 = d_nm->mkNode(BAG_IS_SINGLETON, union_disjoint); + Node output4 = falseNode; + TS_ASSERT(output3 == NormalForm::evaluate(input3)); + } + + void testFromSet() + { + // Examples + // -------- + // - (bag.from_set (emptyset String)) = (emptybag String) + // - (bag.from_set (singleton "x")) = (mkBag "x" 1) + // - (bag.to_set (union (singleton "x") (singleton "y"))) = + // (disjoint_union (mkBag "x" 1) (mkBag "y" 1)) + + Node emptyset = + d_nm->mkConst(EmptySet(d_nm->mkSetType(d_nm->stringType()))); + Node emptybag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node input1 = d_nm->mkNode(BAG_FROM_SET, emptyset); + Node output1 = emptybag; + TS_ASSERT(output1 == NormalForm::evaluate(input1)); + + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + + Node xSingleton = d_nm->mkSingleton(d_nm->stringType(), x); + Node ySingleton = d_nm->mkSingleton(d_nm->stringType(), y); + + Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1))); + Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1))); + + Node input2 = d_nm->mkNode(BAG_FROM_SET, xSingleton); + Node output2 = x_1; + TS_ASSERT(output2 == NormalForm::evaluate(input2)); + + // for normal sets, the first node is the largest, not smallest + Node normalSet = d_nm->mkNode(UNION, ySingleton, xSingleton); + Node input3 = d_nm->mkNode(BAG_FROM_SET, normalSet); + Node output3 = d_nm->mkNode(UNION_DISJOINT, x_1, y_1); + TS_ASSERT(output3 == NormalForm::evaluate(input3)); + } + + void testToSet() + { + // Examples + // -------- + // - (bag.to_set (emptybag String)) = (emptyset String) + // - (bag.to_set (mkBag "x" 4)) = (singleton "x") + // - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) = + // (union (singleton "x") (singleton "y"))) + + Node emptyset = + d_nm->mkConst(EmptySet(d_nm->mkSetType(d_nm->stringType()))); + Node emptybag = + d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType()))); + Node input1 = d_nm->mkNode(BAG_TO_SET, emptybag); + Node output1 = emptyset; + TS_ASSERT(output1 == NormalForm::evaluate(input1)); + + Node x = d_nm->mkConst(String("x")); + Node y = d_nm->mkConst(String("y")); + + Node xSingleton = d_nm->mkSingleton(d_nm->stringType(), x); + Node ySingleton = d_nm->mkSingleton(d_nm->stringType(), y); + + Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4))); + Node y_5 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(5))); + + Node input2 = d_nm->mkNode(BAG_TO_SET, x_4); + Node output2 = xSingleton; + TS_ASSERT(output2 == NormalForm::evaluate(input2)); + + // for normal sets, the first node is the largest, not smallest + Node normalBag = d_nm->mkNode(UNION_DISJOINT, x_4, y_5); + Node input3 = d_nm->mkNode(BAG_TO_SET, normalBag); + Node output3 = d_nm->mkNode(UNION, ySingleton, xSingleton); + TS_ASSERT(output3 == NormalForm::evaluate(input3)); + } + + private: + std::unique_ptr d_em; + std::unique_ptr d_smt; + std::unique_ptr d_nm; + std::unique_ptr d_rewriter; +}; /* class BagsTypeRuleBlack */ diff --git a/test/unit/theory/theory_bags_type_rules_white.h b/test/unit/theory/theory_bags_type_rules_white.h index dfe2d4cac..5622a3000 100644 --- a/test/unit/theory/theory_bags_type_rules_white.h +++ b/test/unit/theory/theory_bags_type_rules_white.h @@ -104,6 +104,13 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite Node bag = d_nm->mkBag(d_nm->stringType(), elements[0], d_nm->mkConst(Rational(10))); TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_TO_SET, bag)); TS_ASSERT(d_nm->mkNode(BAG_TO_SET, bag).getType().isSet()); + std::cout<<"Rational(4, 4).isIntegral() " << d_nm->mkConst(Rational(4,4)).getType()<< std::endl; + std::cout<<"Rational(8, 4).isIntegral() " << d_nm->mkConst(Rational(8,4)).getType()<< std::endl; + std::cout<<"Rational(1, 4).isIntegral() " << d_nm->mkConst(Rational(1,4)).getType()<< std::endl; + + std::cout<<"Rational(4, 4).isIntegral() " << d_nm->mkNode(TO_REAL, d_nm->mkConst(Rational(4,4))).getType()<< std::endl; + std::cout<<"Rational(8, 4).isIntegral() " << d_nm->mkNode(TO_REAL, d_nm->mkConst(Rational(8,4))).getType()<< std::endl; + std::cout<<"Rational(1, 4).isIntegral() " << d_nm->mkNode(TO_REAL, d_nm->mkConst(Rational(1,4))).getType()<< std::endl; } private: -- 2.30.2