Implement bags evaluator (#5322)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Wed, 21 Oct 2020 22:33:57 +0000 (17:33 -0500)
committerGitHub <noreply@github.com>
Wed, 21 Oct 2020 22:33:57 +0000 (17:33 -0500)
This PR implements NormalForm::evaluate for bags

src/theory/bags/bags_rewriter.cpp
src/theory/bags/normal_form.cpp
src/theory/bags/normal_form.h
src/theory/bags/theory_bags_type_rules.h
test/unit/theory/CMakeLists.txt
test/unit/theory/theory_bags_normal_form_white.h [new file with mode: 0644]
test/unit/theory/theory_bags_type_rules_white.h

index c413a5e7e1eabfadefc565f644522a2eb17de3ad..26c54d4ecf6c430328556cecbb1f66d08dd553f8 100644 (file)
@@ -51,7 +51,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
     // no need to rewrite n if it is already in a normal form
     response = BagsRewriteResponse(n, Rewrite::NONE);
   }
-  else if (NormalForm::AreChildrenConstants(n))
+  else if (NormalForm::areChildrenConstants(n))
   {
     Node value = NormalForm::evaluate(n);
     response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
index facad3c927551e9b067e68020b55798e53bc8c0f..f2dea62d384538d4c83421091878b36f230603c7 100644 (file)
 
 #include "normal_form.h"
 
+#include "theory/sets/normal_form.h"
+#include "theory/type_enumerator.h"
+
+using namespace CVC4::kind;
+
 namespace CVC4 {
 namespace theory {
 namespace bags {
 
-bool NormalForm::checkNormalConstant(TNode n)
+bool NormalForm::isConstant(TNode n)
 {
-  // TODO(projects#223): complete this function
+  if (n.getKind() == EMPTYBAG)
+  {
+    // empty bags are already normalized
+    return true;
+  }
+  if (n.getKind() == MK_BAG)
+  {
+    // see the implementation in MkBagTypeRule::computeIsConst
+    return n.isConst();
+  }
+  if (n.getKind() == UNION_DISJOINT)
+  {
+    if (!(n[0].getKind() == kind::MK_BAG && n[0].isConst()))
+    {
+      // the first child is not a constant
+      return false;
+    }
+    // store the previous element to check the ordering of elements
+    Node previousElement = n[0][0];
+    Node current = n[1];
+    while (current.getKind() == UNION_DISJOINT)
+    {
+      if (!(current[0].getKind() == kind::MK_BAG && current[0].isConst()))
+      {
+        // the current element is not a constant
+        return false;
+      }
+      if (previousElement >= current[0][0])
+      {
+        // the ordering is violated
+        return false;
+      }
+      previousElement = current[0][0];
+      current = current[1];
+    }
+    // check last element
+    if (!(current.getKind() == kind::MK_BAG && current.isConst()))
+    {
+      // the last element is not a constant
+      return false;
+    }
+    if (previousElement >= current[0])
+    {
+      // the ordering is violated
+      return false;
+    }
+    return true;
+  }
+
+  // only nodes with kinds EMPTY_BAG, MK_BAG, and UNION_DISJOINT can be
+  // constants
   return false;
 }
 
-bool NormalForm::AreChildrenConstants(TNode n)
+bool NormalForm::areChildrenConstants(TNode n)
 {
   return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
 }
 
 Node NormalForm::evaluate(TNode n)
 {
-  // TODO(projects#223): complete this function
-  return CVC4::Node();
+  Assert(areChildrenConstants(n));
+  if (n.isConst())
+  {
+    // a constant node is already in a normal form
+    return n;
+  }
+  switch (n.getKind())
+  {
+    case MK_BAG: return evaluateMakeBag(n);
+    case BAG_COUNT: return evaluateBagCount(n);
+    case UNION_DISJOINT: return evaluateUnionDisjoint(n);
+    case UNION_MAX: return evaluateUnionMax(n);
+    case INTERSECTION_MIN: return evaluateIntersectionMin(n);
+    case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
+    case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
+    case BAG_CHOOSE: return evaluateChoose(n);
+    case BAG_CARD: return evaluateCard(n);
+    case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
+    case BAG_FROM_SET: return evaluateFromSet(n);
+    case BAG_TO_SET: return evaluateToSet(n);
+    default: break;
+  }
+  Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
+              << std::endl;
+}
+
+template <typename T1, typename T2, typename T3, typename T4, typename T5>
+Node NormalForm::evaluateBinaryOperation(const TNode& n,
+                                         T1&& equal,
+                                         T2&& less,
+                                         T3&& greaterOrEqual,
+                                         T4&& remainderOfA,
+                                         T5&& remainderOfB)
+{
+  std::map<Node, Rational> elementsA = getBagElements(n[0]);
+  std::map<Node, Rational> elementsB = getBagElements(n[1]);
+  std::map<Node, Rational> elements;
+
+  std::map<Node, Rational>::const_iterator itA = elementsA.begin();
+  std::map<Node, Rational>::const_iterator itB = elementsB.begin();
+
+  Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
+                         << n.getKind() << "] " << std::endl
+                         << "elements A: " << elementsA << std::endl
+                         << "elements B: " << elementsB << std::endl;
+
+  while (itA != elementsA.end() && itB != elementsB.end())
+  {
+    if (itA->first == itB->first)
+    {
+      equal(elements, itA, itB);
+      itA++;
+      itB++;
+    }
+    else if (itA->first < itB->first)
+    {
+      less(elements, itA, itB);
+      itA++;
+    }
+    else
+    {
+      greaterOrEqual(elements, itA, itB);
+      itB++;
+    }
+  }
+
+  // handle the remaining elements from A
+  remainderOfA(elements, elementsA, itA);
+  // handle the remaining elements from B
+  remainderOfA(elements, elementsB, itB);
+
+  Trace("bags-evaluate") << "elements: " << elements << std::endl;
+  Node bag = constructBagFromElements(n.getType(), elements);
+  Trace("bags-evaluate") << "bag: " << bag << std::endl;
+  return bag;
+}
+
+std::map<Node, Rational> NormalForm::getBagElements(TNode n)
+{
+  Assert(n.isConst()) << "node " << n << " is not in a normal form"
+                      << std::endl;
+  std::map<Node, Rational> elements;
+  if (n.getKind() == EMPTYBAG)
+  {
+    return elements;
+  }
+  while (n.getKind() == kind::UNION_DISJOINT)
+  {
+    Assert(n[0].getKind() == kind::MK_BAG);
+    Node element = n[0][0];
+    Rational count = n[0][1].getConst<Rational>();
+    elements[element] = count;
+    n = n[1];
+  }
+  Assert(n.getKind() == kind::MK_BAG);
+  Node lastElement = n[0];
+  Rational lastCount = n[1].getConst<Rational>();
+  elements[lastElement] = lastCount;
+  return elements;
+}
+
+Node NormalForm::constructBagFromElements(
+    TypeNode t, const std::map<Node, Rational>& elements)
+{
+  Assert(t.isBag());
+  NodeManager* nm = NodeManager::currentNM();
+  if (elements.empty())
+  {
+    return nm->mkConst(EmptyBag(t));
+  }
+  TypeNode elementType = t.getBagElementType();
+  std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
+  Node bag =
+      nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
+  while (++it != elements.rend())
+  {
+    Node n =
+        nm->mkBag(elementType, it->first, nm->mkConst<Rational>(it->second));
+    bag = nm->mkNode(UNION_DISJOINT, n, bag);
+  }
+  return bag;
+}
+
+Node NormalForm::evaluateMakeBag(TNode n)
+{
+  // the case where n is const should be handled earlier.
+  // here we handle the case where the multiplicity is zero or negative
+  Assert(n.getKind() == MK_BAG && !n.isConst()
+         && n[1].getConst<Rational>().sgn() < 1);
+  Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
+  return emptybag;
+}
+
+Node NormalForm::evaluateBagCount(TNode n)
+{
+  Assert(n.getKind() == BAG_COUNT);
+  // Examples
+  // --------
+  // - (bag.count "x" (emptybag String)) = 0
+  // - (bag.count "x" (mkBag "y" 5)) = 0
+  // - (bag.count "x" (mkBag "x" 4)) = 4
+  // - (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
+  // - (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0
+
+  std::map<Node, Rational> elements = getBagElements(n[1]);
+  std::map<Node, Rational>::iterator it = elements.find(n[0]);
+
+  NodeManager* nm = NodeManager::currentNM();
+  if (it != elements.end())
+  {
+    Node count = nm->mkConst(it->second);
+    return count;
+  }
+  return nm->mkConst(Rational(0));
+}
+
+Node NormalForm::evaluateUnionDisjoint(TNode n)
+{
+  Assert(n.getKind() == UNION_DISJOINT);
+  // Example
+  // -------
+  // input: (union_disjoint A B)
+  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+  // output:
+  //    (union_disjoint A B)
+  //        where A = (MK_BAG "x" 7)
+  //              B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
+
+  auto equal = [](std::map<Node, Rational>& elements,
+                  std::map<Node, Rational>::const_iterator& itA,
+                  std::map<Node, Rational>::const_iterator& itB) {
+    // compute the sum of the multiplicities
+    elements[itA->first] = itA->second + itB->second;
+  };
+
+  auto less = [](std::map<Node, Rational>& elements,
+                 std::map<Node, Rational>::const_iterator& itA,
+                 std::map<Node, Rational>::const_iterator& itB) {
+    // add the element to the result
+    elements[itA->first] = itA->second;
+  };
+
+  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+                           std::map<Node, Rational>::const_iterator& itA,
+                           std::map<Node, Rational>::const_iterator& itB) {
+    // add the element to the result
+    elements[itB->first] = itB->second;
+  };
+
+  auto remainderOfA = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsA,
+                         std::map<Node, Rational>::const_iterator& itA) {
+    // append the remainder of A
+    while (itA != elementsA.end())
+    {
+      elements[itA->first] = itA->second;
+      itA++;
+    }
+  };
+
+  auto remainderOfB = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsB,
+                         std::map<Node, Rational>::const_iterator& itB) {
+    // append the remainder of B
+    while (itB != elementsB.end())
+    {
+      elements[itB->first] = itB->second;
+      itB++;
+    }
+  };
+
+  return evaluateBinaryOperation(
+      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node NormalForm::evaluateUnionMax(TNode n)
+{
+  Assert(n.getKind() == UNION_MAX);
+  // Example
+  // -------
+  // input: (union_max A B)
+  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+  // output:
+  //    (union_disjoint A B)
+  //        where A = (MK_BAG "x" 4)
+  //              B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
+
+  auto equal = [](std::map<Node, Rational>& elements,
+                  std::map<Node, Rational>::const_iterator& itA,
+                  std::map<Node, Rational>::const_iterator& itB) {
+    // compute the maximum multiplicity
+    elements[itA->first] = std::max(itA->second, itB->second);
+  };
+
+  auto less = [](std::map<Node, Rational>& elements,
+                 std::map<Node, Rational>::const_iterator& itA,
+                 std::map<Node, Rational>::const_iterator& itB) {
+    // add to the result
+    elements[itA->first] = itA->second;
+  };
+
+  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+                           std::map<Node, Rational>::const_iterator& itA,
+                           std::map<Node, Rational>::const_iterator& itB) {
+    // add to the result
+    elements[itB->first] = itB->second;
+  };
+
+  auto remainderOfA = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsA,
+                         std::map<Node, Rational>::const_iterator& itA) {
+    // append the remainder of A
+    while (itA != elementsA.end())
+    {
+      elements[itA->first] = itA->second;
+      itA++;
+    }
+  };
+
+  auto remainderOfB = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsB,
+                         std::map<Node, Rational>::const_iterator& itB) {
+    // append the remainder of B
+    while (itB != elementsB.end())
+    {
+      elements[itB->first] = itB->second;
+      itB++;
+    }
+  };
+
+  return evaluateBinaryOperation(
+      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
 }
+
+Node NormalForm::evaluateIntersectionMin(TNode n)
+{
+  Assert(n.getKind() == INTERSECTION_MIN);
+  // Example
+  // -------
+  // input: (intersectionMin A B)
+  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+  // output:
+  //        (MK_BAG "x" 3)
+
+  auto equal = [](std::map<Node, Rational>& elements,
+                  std::map<Node, Rational>::const_iterator& itA,
+                  std::map<Node, Rational>::const_iterator& itB) {
+    // compute the minimum multiplicity
+    elements[itA->first] = std::min(itA->second, itB->second);
+  };
+
+  auto less = [](std::map<Node, Rational>& elements,
+                 std::map<Node, Rational>::const_iterator& itA,
+                 std::map<Node, Rational>::const_iterator& itB) {
+    // do nothing
+  };
+
+  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+                           std::map<Node, Rational>::const_iterator& itA,
+                           std::map<Node, Rational>::const_iterator& itB) {
+    // do nothing
+  };
+
+  auto remainderOfA = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsA,
+                         std::map<Node, Rational>::const_iterator& itA) {
+    // do nothing
+  };
+
+  auto remainderOfB = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsB,
+                         std::map<Node, Rational>::const_iterator& itB) {
+    // do nothing
+  };
+
+  return evaluateBinaryOperation(
+      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node NormalForm::evaluateDifferenceSubtract(TNode n)
+{
+  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
+  // Example
+  // -------
+  // input: (difference_subtract A B)
+  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+  // output:
+  //    (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))
+
+  auto equal = [](std::map<Node, Rational>& elements,
+                  std::map<Node, Rational>::const_iterator& itA,
+                  std::map<Node, Rational>::const_iterator& itB) {
+    // subtract the multiplicities
+    elements[itA->first] = itA->second - itB->second;
+  };
+
+  auto less = [](std::map<Node, Rational>& elements,
+                 std::map<Node, Rational>::const_iterator& itA,
+                 std::map<Node, Rational>::const_iterator& itB) {
+    // itA->first is not in B, so we add it to the difference subtract
+    elements[itA->first] = itA->second;
+  };
+
+  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+                           std::map<Node, Rational>::const_iterator& itA,
+                           std::map<Node, Rational>::const_iterator& itB) {
+    // itB->first is not in A, so we just skip it
+  };
+
+  auto remainderOfA = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsA,
+                         std::map<Node, Rational>::const_iterator& itA) {
+    // append the remainder of A
+    while (itA != elementsA.end())
+    {
+      elements[itA->first] = itA->second;
+      itA++;
+    }
+  };
+
+  auto remainderOfB = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsB,
+                         std::map<Node, Rational>::const_iterator& itB) {
+    // do nothing
+  };
+
+  return evaluateBinaryOperation(
+      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node NormalForm::evaluateDifferenceRemove(TNode n)
+{
+  Assert(n.getKind() == DIFFERENCE_REMOVE);
+  // Example
+  // -------
+  // input: (difference_subtract A B)
+  //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+  //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+  // output:
+  //    (MK_BAG "z" 2)
+
+  auto equal = [](std::map<Node, Rational>& elements,
+                  std::map<Node, Rational>::const_iterator& itA,
+                  std::map<Node, Rational>::const_iterator& itB) {
+    // skip the shared element by doing nothing
+  };
+
+  auto less = [](std::map<Node, Rational>& elements,
+                 std::map<Node, Rational>::const_iterator& itA,
+                 std::map<Node, Rational>::const_iterator& itB) {
+    // itA->first is not in B, so we add it to the difference remove
+    elements[itA->first] = itA->second;
+  };
+
+  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+                           std::map<Node, Rational>::const_iterator& itA,
+                           std::map<Node, Rational>::const_iterator& itB) {
+    // itB->first is not in A, so we just skip it
+  };
+
+  auto remainderOfA = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsA,
+                         std::map<Node, Rational>::const_iterator& itA) {
+    // append the remainder of A
+    while (itA != elementsA.end())
+    {
+      elements[itA->first] = itA->second;
+      itA++;
+    }
+  };
+
+  auto remainderOfB = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsB,
+                         std::map<Node, Rational>::const_iterator& itB) {
+    // do nothing
+  };
+
+  return evaluateBinaryOperation(
+      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node NormalForm::evaluateChoose(TNode n)
+{
+  Assert(n.getKind() == BAG_CHOOSE);
+  // Examples
+  // --------
+  // - (choose (emptyBag String)) = "" // the empty string which is the first
+  //   element returned by the type enumerator
+  // - (choose (MK_BAG "x" 4)) = "x"
+  // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x"
+  //     deterministically return the first element
+
+  if (n[0].getKind() == EMPTYBAG)
+  {
+    TypeNode elementType = n[0].getType().getBagElementType();
+    TypeEnumerator typeEnumerator(elementType);
+    // get the first value from the typeEnumerator
+    Node element = *typeEnumerator;
+    return element;
+  }
+
+  if (n[0].getKind() == MK_BAG)
+  {
+    return n[0][0];
+  }
+  Assert(n[0].getKind() == UNION_DISJOINT);
+  // return the first element
+  // e.g. (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1)))
+  return n[0][0][0];
+}
+
+Node NormalForm::evaluateCard(TNode n)
+{
+  Assert(n.getKind() == BAG_CARD);
+  // Examples
+  // --------
+  //  - (card (emptyBag String)) = 0
+  //  - (choose (MK_BAG "x" 4)) = 4
+  //  - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5
+
+  std::map<Node, Rational> elements = getBagElements(n[0]);
+  Rational sum(0);
+  for (std::pair<Node, Rational> element : elements)
+  {
+    sum += element.second;
+  }
+
+  NodeManager* nm = NodeManager::currentNM();
+  Node sumNode = nm->mkConst(sum);
+  return sumNode;
+}
+
+Node NormalForm::evaluateIsSingleton(TNode n)
+{
+  Assert(n.getKind() == BAG_IS_SINGLETON);
+  // Examples
+  // --------
+  // - (bag.is_singleton (emptyBag String)) = false
+  // - (bag.is_singleton (MK_BAG "x" 1)) = true
+  // - (bag.is_singleton (MK_BAG "x" 4)) = false
+  // - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) = false
+
+  if (n[0].getKind() == MK_BAG && n[0][1].getConst<Rational>().isOne())
+  {
+    return NodeManager::currentNM()->mkConst(true);
+  }
+  return NodeManager::currentNM()->mkConst(false);
+}
+
+Node NormalForm::evaluateFromSet(TNode n)
+{
+  Assert(n.getKind() == BAG_FROM_SET);
+
+  // Examples
+  // --------
+  //  - (bag.from_set (emptyset String)) = (emptybag String)
+  //  - (bag.from_set (singleton "x")) = (mkBag "x" 1)
+  //  - (bag.from_set (union (singleton "x") (singleton "y"))) =
+  //     (disjoint_union (mkBag "x" 1) (mkBag "y" 1))
+
+  NodeManager* nm = NodeManager::currentNM();
+  std::set<Node> setElements =
+      sets::NormalForm::getElementsFromNormalConstant(n[0]);
+  Rational one = Rational(1);
+  std::map<Node, Rational> bagElements;
+  for (const Node& element : setElements)
+  {
+    bagElements[element] = one;
+  }
+  TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
+  Node bag = constructBagFromElements(bagType, bagElements);
+  return bag;
+}
+
+Node NormalForm::evaluateToSet(TNode n)
+{
+  Assert(n.getKind() == BAG_TO_SET);
+
+  // Examples
+  // --------
+  //  - (bag.to_set (emptybag String)) = (emptyset String)
+  //  - (bag.to_set (mkBag "x" 4)) = (singleton "x")
+  //  - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
+  //     (union (singleton "x") (singleton "y")))
+
+  NodeManager* nm = NodeManager::currentNM();
+  std::map<Node, Rational> bagElements = getBagElements(n[0]);
+  std::set<Node> setElements;
+  std::map<Node, Rational>::const_reverse_iterator it;
+  for (it = bagElements.rbegin(); it != bagElements.rend(); it++)
+  {
+    setElements.insert(it->first);
+  }
+  TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
+  Node set = sets::NormalForm::elementsToSet(setElements, setType);
+  return set;
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace CVC4
\ No newline at end of file
index 8c719fe81c1838d7175326d329ce3525da563931..ef0edefff5e75674593b2d04f034a5ed8f78ce06 100644 (file)
@@ -29,22 +29,149 @@ class NormalForm
   /**
    * Returns true if n is considered a to be a (canonical) constant bag value.
    * A canonical bag value is one whose AST is:
-   *   (disjoint-union (mk-bag e1 n1) ...
-   *        (disjoint-union (mk-bag e_{n-1} n_{n-1}) (mk-bag e_n n_n))))
-   * where c1 ... cn are constants and the node identifier of these constants
-   * are such that:
-   *   c1 > ... > cn.
-   * Also handles the corner cases of empty bag and singleton bag.
+   *   (union_disjoint (mkBag e1 c1) ...
+   *      (union_disjoint (mkBag e_{n-1} c_{n-1}) (mkBag e_n c_n))))
+   * where c1 ... cn are positive integers, e1 ... en are constants, and the
+   * node identifier of these constants are such that: e1 < ... < en.
+   * Also handles the corner cases of empty bag and bag constructed by mkBag
    */
-  static bool checkNormalConstant(TNode n);
+  static bool isConstant(TNode n);
   /**
-   * check whether all children of the given node are in normal form
+   * check whether all children of the given node are constants
    */
-  static bool AreChildrenConstants(TNode n);
+  static bool areChildrenConstants(TNode n);
   /**
-   * evaluate the node n to a constant value
+   * evaluate the node n to a constant value.
+   * As a precondition, children of n should be constants.
    */
   static Node evaluate(TNode n);
+
+  /**
+   * get the elements along with their multiplicities in a given bag
+   * @param n a constant node whose type is a bag
+   * @return a map whose keys are constant elements and values are
+   * multiplicities
+   */
+  static std::map<Node, Rational> getBagElements(TNode n);
+
+  /**
+   * construct a constant bag from constant elements
+   * @param t the type of the returned bag
+   * @param elements a map whose keys are constant elements and values are
+   *        multiplicities
+   * @return a constant bag that contains
+   */
+  static Node constructBagFromElements(
+      TypeNode t, const std::map<Node, Rational>& elements);
+
+ private:
+  /**
+   * a high order helper function that return a constant bag that is the result
+   * of (op A B) where op is a binary operator and A, B are constant bags.
+   * The result is computed from the elements of A (elementsA with iterator itA)
+   * and elements of B (elementsB with iterator itB).
+   * The arguments below specify how these iterators are used to generate the
+   * elements of the result (elements).
+   * @param n a node whose kind is a binary operator (union_disjoint, union_max,
+   * intersection_min, difference_subtract, difference_remove) and whose
+   * children are constant bags.
+   * @param equal a lambda expression that receives (elements, itA, itB) and
+   * specify the action that needs to be taken when the elements of itA, itB are
+   * equal.
+   * @param less a lambda expression that receives (elements, itA, itB) and
+   * specify the action that needs to be taken when the element itA is less than
+   * the element of itB.
+   * @param greaterOrEqual less a lambda expression that receives (elements,
+   * itA, itB) and specify the action that needs to be taken when the element
+   * itA is greater than or equal than the element of itB.
+   * @param remainderOfA a lambda expression that receives (elements, elementsA,
+   * itA) and specify the action that needs to be taken to the remaining
+   * elements of A when all elements of B are visited.
+   * @param remainderOfB a lambda expression that receives (elements, elementsB,
+   * itB) and specify the action that needs to be taken to the remaining
+   * elements of B when all elements of A are visited.
+   * @return a constant bag that the result of (op n[0] n[1])
+   */
+  template <typename T1, typename T2, typename T3, typename T4, typename T5>
+  static Node evaluateBinaryOperation(const TNode& n,
+                                      T1&& equal,
+                                      T2&& less,
+                                      T3&& greaterOrEqual,
+                                      T4&& remainderOfA,
+                                      T5&& remainderOfB);
+  /**
+   * evaluate n as follows:
+   * - (mkBag a 0) = (emptybag T) where T is the type of the original bag
+   * - (mkBag a (-c)) = (emptybag T) where T is the type the original bag,
+   *                                and c > 0 is a constant
+   */
+  static Node evaluateMakeBag(TNode n);
+
+  /**
+   * returns the multiplicity in a constant bag
+   * @param n has the form (bag.count x A) where x, A are constants
+   * @return the multiplicity of element x in bag A.
+   */
+  static Node evaluateBagCount(TNode n);
+
+  /**
+   * evaluates union disjoint node such that the returned node is a canonical
+   * bag that has the form
+   * (union_disjoint (mkBag e1 c1) ...
+   *   (union_disjoint  * (mkBag e_{n-1} c_{n-1}) (mkBag e_n c_n)))) where
+   *   c1... cn are positive integers, e1 ... en are constants, and the node
+   * identifier of these constants are such that: e1 < ... < en.
+   * @param n has the form (union_disjoint A B) where A, B are constant bags
+   * @return the union disjoint of A and B
+   */
+  static Node evaluateUnionDisjoint(TNode n);
+  /**
+   * @param n has the form (union_max A B) where A, B are constant bags
+   * @return the union max of A and B
+   */
+  static Node evaluateUnionMax(TNode n);
+  /**
+   * @param n has the form (intersection_min A B) where A, B are constant bags
+   * @return the intersection min of A and B
+   */
+  static Node evaluateIntersectionMin(TNode n);
+  /**
+   * @param n has the form (difference_subtract A B) where A, B are constant
+   * bags
+   * @return the difference subtract of A and B
+   */
+  static Node evaluateDifferenceSubtract(TNode n);
+  /**
+   * @param n has the form (difference_remove A B) where A, B are constant bags
+   * @return the difference remove of A and B
+   */
+  static Node evaluateDifferenceRemove(TNode n);
+  /**
+   * @param n has the form (bag.choose A) where A is a constant bag
+   * @return the first element of A if A is not empty. Otherwise, it returns the
+   * first element returned by the type enumerator for the elements
+   */
+  static Node evaluateChoose(TNode n);
+  /**
+   * @param n has the form (bag.card A) where A is a constant bag
+   * @return the number of elements in bag A
+   */
+  static Node evaluateCard(TNode n);
+  /**
+   * @param n has the form (bag.is_singleton A) where A is a constant bag
+   * @return whether the bag A has cardinality one.
+   */
+  static Node evaluateIsSingleton(TNode n);
+  /**
+   * @param n has the form (bag.from_set A) where A is a constant set
+   * @return a constant bag that contains exactly the elements in A.
+   */
+  static Node evaluateFromSet(TNode n);
+  /**
+   * @param n has the form (bag.to_set A) where A is a constant bag
+   * @return a constant set constructed from the elements in A.
+   */
+  static Node evaluateToSet(TNode n);
 };
 }  // namespace bags
 }  // namespace theory
index 75f57ec885ddc007716950e9114a0e51d0b7576a..7767938edee5af85bf30f88f9ac88315e784782a 100644 (file)
@@ -57,7 +57,7 @@ struct BinaryOperatorTypeRule
     // only UNION_DISJOINT has a const rule in kinds.
     // Other binary operators do not have const rules in kinds
     Assert(n.getKind() == kind::UNION_DISJOINT);
-    return NormalForm::checkNormalConstant(n);
+    return NormalForm::isConstant(n);
   }
 }; /* struct BinaryOperatorTypeRule */
 
index 481c80f264812ed4b06a12c60f8a45f32e55a3e6..8cfd439897fcfaddf2310daa377d7442a6daabb6 100644 (file)
@@ -14,6 +14,7 @@ cvc4_add_unit_test_white(evaluator_white theory)
 cvc4_add_unit_test_white(logic_info_white theory)
 cvc4_add_unit_test_white(sequences_rewriter_white theory)
 cvc4_add_unit_test_white(theory_arith_white theory)
+cvc4_add_unit_test_white(theory_bags_normal_form_white theory)
 cvc4_add_unit_test_white(theory_bags_rewriter_white theory)
 cvc4_add_unit_test_white(theory_bags_type_rules_white theory)
 cvc4_add_unit_test_white(theory_bv_rewriter_white theory)
diff --git a/test/unit/theory/theory_bags_normal_form_white.h b/test/unit/theory/theory_bags_normal_form_white.h
new file mode 100644 (file)
index 0000000..6f7d5bd
--- /dev/null
@@ -0,0 +1,512 @@
+/*********************                                                        */
+/*! \file theory_bags_normal_form_white.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief White box testing of bags normal form
+ **/
+
+#include <cxxtest/TestSuite.h>
+
+#include "expr/dtype.h"
+#include "smt/smt_engine.h"
+#include "theory/bags/bags_rewriter.h"
+#include "theory/bags/normal_form.h"
+#include "theory/strings/type_enumerator.h"
+
+using namespace CVC4;
+using namespace CVC4::smt;
+using namespace CVC4::theory;
+using namespace CVC4::kind;
+using namespace CVC4::theory::bags;
+using namespace std;
+
+typedef expr::Attribute<Node, Node> attribute;
+
+class BagsNormalFormWhite : public CxxTest::TestSuite
+{
+ public:
+  void setUp() override
+  {
+    d_em.reset(new ExprManager());
+    d_smt.reset(new SmtEngine(d_em.get()));
+    d_nm.reset(NodeManager::fromExprManager(d_em.get()));
+    d_smt->finishInit();
+    d_rewriter.reset(new BagsRewriter(nullptr));
+  }
+
+  void tearDown() override
+  {
+    d_rewriter.reset();
+    d_smt.reset();
+    d_nm.release();
+    d_em.reset();
+  }
+
+  std::vector<Node> getNStrings(size_t n)
+  {
+    std::vector<Node> elements(n);
+    CVC4::theory::strings::StringEnumerator enumerator(d_nm->stringType());
+
+    for (size_t i = 0; i < n; i++)
+    {
+      ++enumerator;
+      elements[i] = *enumerator;
+    }
+
+    return elements;
+  }
+
+  void testEmptyBagNormalForm()
+  {
+    Node emptybag = d_nm->mkConst(EmptyBag(d_nm->stringType()));
+    // empty bags are in normal form
+    TS_ASSERT(emptybag.isConst());
+    Node n = NormalForm::evaluate(emptybag);
+    TS_ASSERT(emptybag == n);
+  }
+
+  void testBagEquality() {}
+
+  void testMkBagConstantElement()
+  {
+    vector<Node> elements = getNStrings(1);
+    Node negative = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(-1)));
+    Node zero = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(0)));
+    Node positive = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(1)));
+    Node emptybag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+
+    TS_ASSERT(!negative.isConst());
+    TS_ASSERT(!zero.isConst());
+    TS_ASSERT(emptybag == NormalForm::evaluate(negative));
+    TS_ASSERT(emptybag == NormalForm::evaluate(zero));
+    TS_ASSERT(positive == NormalForm::evaluate(positive));
+  }
+
+  void testBagCount()
+  {
+    // Examples
+    // -------
+    // (bag.count "x" (emptybag String)) = 0
+    // (bag.count "x" (mkBag "y" 5)) = 0
+    // (bag.count "x" (mkBag "x" 4)) = 4
+    // (bag.count "x" (union_disjoint (mkBag "x" 4) (mkBag "y" 5)) = 4
+    // (bag.count "x" (union_disjoint (mkBag "y" 5) (mkBag "z" 5)) = 0
+
+    Node zero = d_nm->mkConst(Rational(0));
+    Node four = d_nm->mkConst(Rational(4));
+    Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+    Node z = d_nm->mkConst(String("z"));
+    Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+    Node y_5 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(5)));
+    Node z_5 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(5)));
+
+    Node input1 = d_nm->mkNode(BAG_COUNT, x, empty);
+    Node output1 = zero;
+    TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+    Node input2 = d_nm->mkNode(BAG_COUNT, x, y_5);
+    Node output2 = zero;
+    TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+    Node input3 = d_nm->mkNode(BAG_COUNT, x, x_4);
+    Node output3 = four;
+    TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+    Node unionDisjointXY = d_nm->mkNode(UNION_DISJOINT, x_4, y_5);
+    Node input4 = d_nm->mkNode(BAG_COUNT, x, unionDisjointXY);
+    Node output4 = four;
+    TS_ASSERT(output3 == NormalForm::evaluate(input3));
+
+    Node unionDisjointYZ = d_nm->mkNode(UNION_DISJOINT, y_5, z_5);
+    Node input5 = d_nm->mkNode(BAG_COUNT, x, unionDisjointYZ);
+    Node output5 = zero;
+    TS_ASSERT(output4 == NormalForm::evaluate(input4));
+  }
+
+  void testUnionMax()
+  {
+    // Example
+    // -------
+    // input: (union_max A B)
+    //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+    //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+    // output:
+    //    (union_disjoint A B)
+    //        where A = (MK_BAG "x" 4)
+    //              B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
+
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+    Node z = d_nm->mkConst(String("z"));
+    Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+    Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3)));
+    Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7)));
+    Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2)));
+    Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+    Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2);
+    Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1);
+    Node input = d_nm->mkNode(UNION_MAX, A, B);
+
+    // output
+    Node output = d_nm->mkNode(
+        UNION_DISJOINT, x_4, d_nm->mkNode(UNION_DISJOINT, y_1, z_2));
+
+    TS_ASSERT(output.isConst());
+    TS_ASSERT(output == NormalForm::evaluate(input));
+  }
+
+  void testUnionDisjoint1()
+  {
+    vector<Node> elements = getNStrings(3);
+    Node emptybag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node A = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(2)));
+    Node B = d_nm->mkBag(
+        d_nm->stringType(), elements[1], d_nm->mkConst(Rational(3)));
+    Node C = d_nm->mkBag(
+        d_nm->stringType(), elements[2], d_nm->mkConst(Rational(4)));
+
+    Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
+    // unionDisjointAB is already in a normal form
+    TS_ASSERT(unionDisjointAB.isConst());
+    TS_ASSERT(unionDisjointAB == NormalForm::evaluate(unionDisjointAB));
+
+    Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
+    // unionDisjointAB is is the normal form of unionDisjointBA
+    TS_ASSERT(!unionDisjointBA.isConst());
+    TS_ASSERT(unionDisjointAB == NormalForm::evaluate(unionDisjointBA));
+
+    Node unionDisjointAB_C = d_nm->mkNode(UNION_DISJOINT, unionDisjointAB, C);
+    Node unionDisjointBC = d_nm->mkNode(UNION_DISJOINT, B, C);
+    Node unionDisjointA_BC = d_nm->mkNode(UNION_DISJOINT, A, unionDisjointBC);
+    // unionDisjointA_BC is the normal form of unionDisjointAB_C
+    TS_ASSERT(!unionDisjointAB_C.isConst());
+    TS_ASSERT(unionDisjointA_BC.isConst());
+    TS_ASSERT(unionDisjointA_BC == NormalForm::evaluate(unionDisjointAB_C));
+
+    Node unionDisjointAA = d_nm->mkNode(UNION_DISJOINT, A, A);
+    Node AA = d_nm->mkBag(
+        d_nm->stringType(), elements[0], d_nm->mkConst(Rational(4)));
+    TS_ASSERT(!unionDisjointAA.isConst());
+    TS_ASSERT(AA.isConst());
+    TS_ASSERT(AA == NormalForm::evaluate(unionDisjointAA));
+  }
+
+  void testUnionDisjoint2()
+  {
+    // Example
+    // -------
+    // input: (union_disjoint A B)
+    //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+    //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+    // output:
+    //    (union_disjoint A B)
+    //        where A = (MK_BAG "x" 7)
+    //              B = (union_disjoint (MK_BAG "y" 1) (MK_BAG "z" 2)))
+
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+    Node z = d_nm->mkConst(String("z"));
+    Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+    Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3)));
+    Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7)));
+    Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2)));
+    Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+    Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2);
+    Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1);
+    Node input = d_nm->mkNode(UNION_DISJOINT, A, B);
+
+    // output
+    Node output = d_nm->mkNode(
+        UNION_DISJOINT, x_7, d_nm->mkNode(UNION_DISJOINT, y_1, z_2));
+
+    TS_ASSERT(output.isConst());
+    TS_ASSERT(output == NormalForm::evaluate(input));
+  }
+
+  void testIntersectionMin()
+  {
+    // Example
+    // -------
+    // input: (intersection_min A B)
+    //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+    //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+    // output:
+    //    (MK_BAG "x" 3)
+
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+    Node z = d_nm->mkConst(String("z"));
+    Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+    Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3)));
+    Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7)));
+    Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2)));
+    Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+    Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2);
+    Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1);
+    Node input = d_nm->mkNode(INTERSECTION_MIN, A, B);
+
+    // output
+    Node output = x_3;
+
+    TS_ASSERT(output.isConst());
+    TS_ASSERT(output == NormalForm::evaluate(input));
+  }
+
+  void testDifferenceSubtract()
+  {
+    // Example
+    // -------
+    // input: (difference_subtract A B)
+    //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+    //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+    // output:
+    //    (union_disjoint (MK_BAG "x" 1) (MK_BAG "z" 2))
+
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+    Node z = d_nm->mkConst(String("z"));
+    Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1)));
+    Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+    Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3)));
+    Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7)));
+    Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2)));
+    Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+    Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2);
+    Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1);
+    Node input = d_nm->mkNode(DIFFERENCE_SUBTRACT, A, B);
+
+    // output
+    Node output = d_nm->mkNode(UNION_DISJOINT, x_1, z_2);
+
+    TS_ASSERT(output.isConst());
+    TS_ASSERT(output == NormalForm::evaluate(input));
+  }
+
+  void testDifferenceRemove()
+  {
+    // Example
+    // -------
+    // input: (difference_remove A B)
+    //    where A = (union_disjoint (MK_BAG "x" 4) (MK_BAG "z" 2)))
+    //          B = (union_disjoint (MK_BAG "x" 3) (MK_BAG "y" 1)))
+    // output:
+    //    (MK_BAG "z" 2)
+
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+    Node z = d_nm->mkConst(String("z"));
+    Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1)));
+    Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+    Node x_3 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(3)));
+    Node x_7 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(7)));
+    Node z_2 = d_nm->mkBag(d_nm->stringType(), z, d_nm->mkConst(Rational(2)));
+    Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+    Node A = d_nm->mkNode(UNION_DISJOINT, x_4, z_2);
+    Node B = d_nm->mkNode(UNION_DISJOINT, x_3, y_1);
+    Node input = d_nm->mkNode(DIFFERENCE_REMOVE, A, B);
+
+    // output
+    Node output = z_2;
+
+    TS_ASSERT(output.isConst());
+    TS_ASSERT(output == NormalForm::evaluate(input));
+  }
+
+  void testChoose()
+  {
+    // Example
+    // -------
+    // input:  (choose (emptybag String))
+    // output: "A"; the first element returned by the type enumerator
+    // input:  (choose (MK_BAG "x" 4))
+    // output: "x"
+    // input:  (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1)))
+    // output: "x"; deterministically return the first element
+    Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+    Node z = d_nm->mkConst(String("z"));
+    Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+    Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+    Node input1 = d_nm->mkNode(BAG_CHOOSE, empty);
+    Node output1 = d_nm->mkConst(String(""));
+
+    TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+    Node input2 = d_nm->mkNode(BAG_CHOOSE, x_4);
+    Node output2 = x;
+    TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+    Node union_disjoint = d_nm->mkNode(UNION_DISJOINT, x_4, y_1);
+    Node input3 = d_nm->mkNode(BAG_CHOOSE, union_disjoint);
+    Node output3 = x;
+    TS_ASSERT(output3 == NormalForm::evaluate(input3));
+  }
+
+  void testBagCard()
+  {
+    // Examples
+    // --------
+    //  - (card (emptybag String)) = 0
+    //  - (choose (MK_BAG "x" 4)) = 4
+    //  - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = 5
+    Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+    Node z = d_nm->mkConst(String("z"));
+    Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+    Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+    Node input1 = d_nm->mkNode(BAG_CARD, empty);
+    Node output1 = d_nm->mkConst(Rational(0));
+
+    TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+    Node input2 = d_nm->mkNode(BAG_CARD, x_4);
+    Node output2 = d_nm->mkConst(Rational(4));
+    TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+    Node union_disjoint = d_nm->mkNode(UNION_DISJOINT, x_4, y_1);
+    Node input3 = d_nm->mkNode(BAG_CARD, union_disjoint);
+    Node output3 = d_nm->mkConst(Rational(5));
+    TS_ASSERT(output3 == NormalForm::evaluate(input3));
+  }
+
+  void testIsSingleton()
+  {
+    // Examples
+    // --------
+    //  - (bag.is_singleton (emptybag String)) = false
+    //  - (bag.is_singleton (MK_BAG "x" 1)) = true
+    //  - (bag.is_singleton (MK_BAG "x" 4)) = false
+    //  - (bag.is_singleton (union_disjoint (MK_BAG "x" 1) (MK_BAG "y" 1))) =
+    //     false
+    Node falseNode = d_nm->mkConst(false);
+    Node trueNode = d_nm->mkConst(true);
+    Node empty = d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+    Node z = d_nm->mkConst(String("z"));
+    Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1)));
+    Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+    Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+    Node input1 = d_nm->mkNode(BAG_IS_SINGLETON, empty);
+    Node output1 = falseNode;
+    TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+    Node input2 = d_nm->mkNode(BAG_IS_SINGLETON, x_1);
+    Node output2 = trueNode;
+    TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+    Node input3 = d_nm->mkNode(BAG_IS_SINGLETON, x_4);
+    Node output3 = falseNode;
+    TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+    Node union_disjoint = d_nm->mkNode(UNION_DISJOINT, x_1, y_1);
+    Node input4 = d_nm->mkNode(BAG_IS_SINGLETON, union_disjoint);
+    Node output4 = falseNode;
+    TS_ASSERT(output3 == NormalForm::evaluate(input3));
+  }
+
+  void testFromSet()
+  {
+    // Examples
+    // --------
+    //  - (bag.from_set (emptyset String)) = (emptybag String)
+    //  - (bag.from_set (singleton "x")) = (mkBag "x" 1)
+    //  - (bag.to_set (union (singleton "x") (singleton "y"))) =
+    //     (disjoint_union (mkBag "x" 1) (mkBag "y" 1))
+
+    Node emptyset =
+        d_nm->mkConst(EmptySet(d_nm->mkSetType(d_nm->stringType())));
+    Node emptybag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node input1 = d_nm->mkNode(BAG_FROM_SET, emptyset);
+    Node output1 = emptybag;
+    TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+
+    Node xSingleton = d_nm->mkSingleton(d_nm->stringType(), x);
+    Node ySingleton = d_nm->mkSingleton(d_nm->stringType(), y);
+
+    Node x_1 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(1)));
+    Node y_1 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(1)));
+
+    Node input2 = d_nm->mkNode(BAG_FROM_SET, xSingleton);
+    Node output2 = x_1;
+    TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+    // for normal sets, the first node is the largest, not smallest
+    Node normalSet = d_nm->mkNode(UNION, ySingleton, xSingleton);
+    Node input3 = d_nm->mkNode(BAG_FROM_SET, normalSet);
+    Node output3 = d_nm->mkNode(UNION_DISJOINT, x_1, y_1);
+    TS_ASSERT(output3 == NormalForm::evaluate(input3));
+  }
+
+  void testToSet()
+  {
+    // Examples
+    // --------
+    //  - (bag.to_set (emptybag String)) = (emptyset String)
+    //  - (bag.to_set (mkBag "x" 4)) = (singleton "x")
+    //  - (bag.to_set (disjoint_union (mkBag "x" 3) (mkBag "y" 5)) =
+    //     (union (singleton "x") (singleton "y")))
+
+    Node emptyset =
+        d_nm->mkConst(EmptySet(d_nm->mkSetType(d_nm->stringType())));
+    Node emptybag =
+        d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node input1 = d_nm->mkNode(BAG_TO_SET, emptybag);
+    Node output1 = emptyset;
+    TS_ASSERT(output1 == NormalForm::evaluate(input1));
+
+    Node x = d_nm->mkConst(String("x"));
+    Node y = d_nm->mkConst(String("y"));
+
+    Node xSingleton = d_nm->mkSingleton(d_nm->stringType(), x);
+    Node ySingleton = d_nm->mkSingleton(d_nm->stringType(), y);
+
+    Node x_4 = d_nm->mkBag(d_nm->stringType(), x, d_nm->mkConst(Rational(4)));
+    Node y_5 = d_nm->mkBag(d_nm->stringType(), y, d_nm->mkConst(Rational(5)));
+
+    Node input2 = d_nm->mkNode(BAG_TO_SET, x_4);
+    Node output2 = xSingleton;
+    TS_ASSERT(output2 == NormalForm::evaluate(input2));
+
+    // for normal sets, the first node is the largest, not smallest
+    Node normalBag = d_nm->mkNode(UNION_DISJOINT, x_4, y_5);
+    Node input3 = d_nm->mkNode(BAG_TO_SET, normalBag);
+    Node output3 = d_nm->mkNode(UNION, ySingleton, xSingleton);
+    TS_ASSERT(output3 == NormalForm::evaluate(input3));
+  }
+
+ private:
+  std::unique_ptr<ExprManager> d_em;
+  std::unique_ptr<SmtEngine> d_smt;
+  std::unique_ptr<NodeManager> d_nm;
+  std::unique_ptr<BagsRewriter> d_rewriter;
+}; /* class BagsTypeRuleBlack */
index dfe2d4cacae9c2a7f40c05c3632c2ac84d4c2424..5622a30007b2c7c20178e544b20019d5eb5eeb23 100644 (file)
@@ -104,6 +104,13 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite
     Node bag = d_nm->mkBag(d_nm->stringType(), elements[0], d_nm->mkConst(Rational(10)));
     TS_ASSERT_THROWS_NOTHING(d_nm->mkNode(BAG_TO_SET, bag));
     TS_ASSERT(d_nm->mkNode(BAG_TO_SET, bag).getType().isSet());
+    std::cout<<"Rational(4, 4).isIntegral() " << d_nm->mkConst(Rational(4,4)).getType()<<  std::endl;
+    std::cout<<"Rational(8, 4).isIntegral() " << d_nm->mkConst(Rational(8,4)).getType()<<  std::endl;
+    std::cout<<"Rational(1, 4).isIntegral() " << d_nm->mkConst(Rational(1,4)).getType()<<  std::endl;
+
+    std::cout<<"Rational(4, 4).isIntegral() " << d_nm->mkNode(TO_REAL, d_nm->mkConst(Rational(4,4))).getType()<<  std::endl;
+    std::cout<<"Rational(8, 4).isIntegral() " << d_nm->mkNode(TO_REAL, d_nm->mkConst(Rational(8,4))).getType()<<  std::endl;
+    std::cout<<"Rational(1, 4).isIntegral() " << d_nm->mkNode(TO_REAL, d_nm->mkConst(Rational(1,4))).getType()<<  std::endl;
   }
 
  private: