Add bag.filter operator (#8006)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Tue, 1 Feb 2022 14:58:04 +0000 (08:58 -0600)
committerGitHub <noreply@github.com>
Tue, 1 Feb 2022 14:58:04 +0000 (14:58 +0000)
33 files changed:
src/CMakeLists.txt
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bag_solver.cpp
src/theory/bags/bag_solver.h
src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
src/theory/bags/bags_utils.cpp [new file with mode: 0644]
src/theory/bags/bags_utils.h [new file with mode: 0644]
src/theory/bags/card_solver.cpp
src/theory/bags/inference_generator.cpp
src/theory/bags/inference_generator.h
src/theory/bags/kinds
src/theory/bags/normal_form.cpp [deleted file]
src/theory/bags/normal_form.h [deleted file]
src/theory/bags/rewrites.cpp
src/theory/bags/rewrites.h
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags_type_enumerator.cpp
src/theory/bags/theory_bags_type_rules.cpp
src/theory/bags/theory_bags_type_rules.h
src/theory/inference_id.cpp
src/theory/inference_id.h
test/regress/CMakeLists.txt
test/regress/regress1/bags/filter1.smt2 [new file with mode: 0644]
test/regress/regress1/bags/filter2.smt2 [new file with mode: 0644]
test/regress/regress1/bags/filter3.smt2 [new file with mode: 0644]
test/regress/regress1/bags/filter4.smt2 [new file with mode: 0644]
test/regress/regress1/bags/filter5.smt2 [new file with mode: 0644]
test/regress/regress1/bags/map1.smt2
test/unit/theory/theory_bags_normal_form_white.cpp

index a1ad056f1eced3878641972f5c7807a36c28dca0..1715257f7ca05abdd48b962663823cb90fd8866a 100644 (file)
@@ -549,6 +549,8 @@ libcvc5_add_sources(
   theory/bags/bag_reduction.h
   theory/bags/bags_statistics.cpp
   theory/bags/bags_statistics.h
+  theory/bags/bags_utils.cpp
+  theory/bags/bags_utils.h
   theory/bags/card_solver.cpp
   theory/bags/card_solver.h
   theory/bags/infer_info.cpp
@@ -557,8 +559,6 @@ libcvc5_add_sources(
   theory/bags/inference_generator.h
   theory/bags/inference_manager.cpp
   theory/bags/inference_manager.h
-  theory/bags/normal_form.cpp
-  theory/bags/normal_form.h
   theory/bags/rewrites.cpp
   theory/bags/rewrites.h
   theory/bags/solver_state.cpp
index 458dd359aba6aed8dbd343f30b042ff8402b8dba..df9a8b8ae5f92ded22c09cabd641314bfc779da9 100644 (file)
@@ -312,6 +312,7 @@ const static std::unordered_map<Kind, cvc5::Kind> s_kinds{
     {BAG_FROM_SET, cvc5::Kind::BAG_FROM_SET},
     {BAG_TO_SET, cvc5::Kind::BAG_TO_SET},
     {BAG_MAP, cvc5::Kind::BAG_MAP},
+    {BAG_FILTER, cvc5::Kind::BAG_FILTER},
     {BAG_FOLD, cvc5::Kind::BAG_FOLD},
     /* Strings ------------------------------------------------------------- */
     {STRING_CONCAT, cvc5::Kind::STRING_CONCAT},
@@ -624,6 +625,7 @@ const static std::unordered_map<cvc5::Kind, Kind, cvc5::kind::KindHashFunction>
         {cvc5::Kind::BAG_FROM_SET, BAG_FROM_SET},
         {cvc5::Kind::BAG_TO_SET, BAG_TO_SET},
         {cvc5::Kind::BAG_MAP, BAG_MAP},
+        {cvc5::Kind::BAG_FILTER, BAG_FILTER},
         {cvc5::Kind::BAG_FOLD, BAG_FOLD},
         /* Strings --------------------------------------------------------- */
         {cvc5::Kind::STRING_CONCAT, STRING_CONCAT},
index 3bd89681406a57f34608bf2856233dab41ae8445..dba4df07fa96c0d015fe6667479caab5ef90373b 100644 (file)
@@ -2539,6 +2539,23 @@ enum Kind : int32_t
    *   - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
    */
   BAG_MAP,
+  /**
+    * bag.filter operator filters the elements of a bag.
+    * (bag.filter p B) takes a predicate p of type (-> T Bool) as a first
+    * argument, and a bag B of type (Bag T) as a second argument, and returns a
+    * subbag of type (Bag T) that includes all elements of B that satisfy p
+    * with the same multiplicity.
+    *
+    * Parameters:
+    *   - 1: a function of type (-> T Bool)
+    *   - 2: a bag of type (Bag T)
+    *
+    * Create with:
+    *   - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2)
+    * const`
+    *   - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
+    */
+   BAG_FILTER,
   /**
    * bag.fold operator combines elements of a bag into a single value.
    * (bag.fold f t B) folds the elements of bag B starting with term t and using
index 8eed51baac2dc2cd388b0acb4896dfc2e2a1ce45..a93596633b73536c77580e5a99f0e4880979fe52 100644 (file)
@@ -625,6 +625,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(api::BAG_FROM_SET, "bag.from_set");
     addOperator(api::BAG_TO_SET, "bag.to_set");
     addOperator(api::BAG_MAP, "bag.map");
+    addOperator(api::BAG_FILTER, "bag.filter");
     addOperator(api::BAG_FOLD, "bag.fold");
   }
   if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) {
index 067cf27fefb365824368c25a48179ad04974a0ac..cf85e6b0ef372a691d820761da33655cfd765285 100644 (file)
@@ -1123,6 +1123,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::BAG_FROM_SET: return "bag.from_set";
   case kind::BAG_TO_SET: return "bag.to_set";
   case kind::BAG_MAP: return "bag.map";
+  case kind::BAG_FILTER: return "bag.filter";
   case kind::BAG_FOLD: return "bag.fold";
 
     // fp theory
index 55367bb8902158e278017867f15e2aaa187b7c28..ed4b501f3564c270028acffadf286a8797d802b5 100644 (file)
@@ -16,9 +16,9 @@
 #include "theory/bags/bag_solver.h"
 
 #include "expr/emptybag.h"
+#include "theory/bags/bags_utils.h"
 #include "theory/bags/inference_generator.h"
 #include "theory/bags/inference_manager.h"
-#include "theory/bags/normal_form.h"
 #include "theory/bags/solver_state.h"
 #include "theory/bags/term_registry.h"
 #include "theory/uf/equality_engine_iterator.h"
@@ -76,6 +76,7 @@ void BagSolver::checkBasicOperations()
         case kind::BAG_DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
         case kind::BAG_DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
         case kind::BAG_DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
+        case kind::BAG_FILTER: checkFilter(n); break;
         case kind::BAG_MAP: checkMap(n); break;
         default: break;
       }
@@ -280,6 +281,28 @@ void BagSolver::checkMap(Node n)
   }
 }
 
+void BagSolver::checkFilter(Node n)
+{
+  Assert(n.getKind() == BAG_FILTER);
+
+  set<Node> elements;
+  const set<Node>& downwards = d_state.getElements(n);
+  const set<Node>& upwards = d_state.getElements(n[0]);
+  elements.insert(downwards.begin(), downwards.end());
+  elements.insert(upwards.begin(), upwards.end());
+
+  for (const Node& e : elements)
+  {
+    InferInfo i = d_ig.filterDownwards(n, d_state.getRepresentative(e));
+    d_im.lemmaTheoryInference(&i);
+  }
+  for (const Node& e : elements)
+  {
+    InferInfo i = d_ig.filterUpwards(n, d_state.getRepresentative(e));
+    d_im.lemmaTheoryInference(&i);
+  }
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index 499b7998db510af2f241811266267b35ce74d446..fca72b22e8046e9265ad237293e0d7afd66d11cb 100644 (file)
@@ -96,6 +96,8 @@ class BagSolver : protected EnvObj
   void checkDisequalBagTerms();
   /** apply inference rules for map operator */
   void checkMap(Node n);
+  /** apply inference rules for filter operator */
+  void checkFilter(Node n);
 
   /** The solver state object */
   SolverState& d_state;
index 40f8d6c95e9adfa4ba03b25b088c8a8769c2f6c7..24f313ad68cc630d31f9306bcb816c2f18b35880 100644 (file)
@@ -16,7 +16,7 @@
 #include "theory/bags/bags_rewriter.h"
 
 #include "expr/emptybag.h"
-#include "theory/bags/normal_form.h"
+#include "theory/bags/bags_utils.h"
 #include "util/rational.h"
 #include "util/statistics_registry.h"
 
@@ -65,9 +65,9 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
   {
     response = rewriteChoose(n);
   }
-  else if (NormalForm::areChildrenConstants(n))
+  else if (BagsUtils::areChildrenConstants(n))
   {
-    Node value = NormalForm::evaluate(n);
+    Node value = BagsUtils::evaluate(n);
     response = BagsRewriteResponse(value, Rewrite::CONSTANT_EVALUATION);
   }
   else
@@ -90,6 +90,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
       case BAG_FROM_SET: response = rewriteFromSet(n); break;
       case BAG_TO_SET: response = rewriteToSet(n); break;
       case BAG_MAP: response = postRewriteMap(n); break;
+      case BAG_FILTER: response = postRewriteFilter(n); break;
       case BAG_FOLD: response = postRewriteFold(n); break;
       default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
     }
@@ -533,7 +534,7 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
   {
     // (bag.map f (as bag.empty (Bag T1)) = (as bag.empty (Bag T2))
     // (bag.map f (bag "a" 3)) = (bag (f "a") 3)
-    std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
+    std::map<Node, Rational> elements = BagsUtils::getBagElements(n[1]);
     std::map<Node, Rational> mappedElements;
     std::map<Node, Rational>::iterator it = elements.begin();
     while (it != elements.end())
@@ -543,7 +544,7 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
       ++it;
     }
     TypeNode t = d_nm->mkBagType(n[0].getType().getRangeType());
-    Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements);
+    Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements);
     return BagsRewriteResponse(ret, Rewrite::MAP_CONST);
   }
   Kind k = n[1].getKind();
@@ -572,6 +573,49 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
   }
 }
 
+BagsRewriteResponse BagsRewriter::postRewriteFilter(const TNode& n) const
+{
+  Assert(n.getKind() == kind::BAG_FILTER);
+  Node P = n[0];
+  Node A = n[1];
+  TypeNode t = A.getType();
+  if (A.isConst())
+  {
+    // (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
+    // (bag.filter p (bag "a" 3) ((bag "b" 2))) =
+    //   (bag.union_disjoint
+    //     (ite (p "a") (bag "a" 3) (as bag.empty (Bag T)))
+    //     (ite (p "b") (bag "b" 2) (as bag.empty (Bag T)))
+
+    Node ret = BagsUtils::evaluateBagFilter(n);
+    return BagsRewriteResponse(ret, Rewrite::FILTER_CONST);
+  }
+  Kind k = A.getKind();
+  switch (k)
+  {
+    case BAG_MAKE:
+    {
+      // (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T)))
+      Node empty = d_nm->mkConst(EmptyBag(t));
+      Node pOfe = d_nm->mkNode(APPLY_UF, P, A[0]);
+      Node ret = d_nm->mkNode(ITE, pOfe, A, empty);
+      return BagsRewriteResponse(ret, Rewrite::FILTER_BAG_MAKE);
+    }
+
+    case BAG_UNION_DISJOINT:
+    {
+      // (bag.filter p (bag.union_disjoint A B)) =
+      //    (bag.union_disjoint (bag.filter p A) (bag.filter p B))
+      Node a = d_nm->mkNode(BAG_FILTER, n[0], n[1][0]);
+      Node b = d_nm->mkNode(BAG_FILTER, n[0], n[1][1]);
+      Node ret = d_nm->mkNode(BAG_UNION_DISJOINT, a, b);
+      return BagsRewriteResponse(ret, Rewrite::FILTER_UNION_DISJOINT);
+    }
+
+    default: return BagsRewriteResponse(n, Rewrite::NONE);
+  }
+}
+
 BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
 {
   Assert(n.getKind() == kind::BAG_FOLD);
@@ -580,7 +624,7 @@ BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
   Node bag = n[2];
   if (bag.isConst())
   {
-    Node value = NormalForm::evaluateBagFold(n);
+    Node value = BagsUtils::evaluateBagFold(n);
     return BagsRewriteResponse(value, Rewrite::FOLD_CONST);
   }
   Kind k = bag.getKind();
@@ -591,7 +635,7 @@ BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
       if (bag[1].isConst() && bag[1].getConst<Rational>() > Rational(0))
       {
         // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, n > 0
-        Node value = NormalForm::evaluateBagFold(n);
+        Node value = BagsUtils::evaluateBagFold(n);
         return BagsRewriteResponse(value, Rewrite::FOLD_BAG);
       }
       break;
index b4b1e90435db9cd7e0337594a530c385abef0b2c..3e5b69a1c62d527598c98d2fda8d3563711ce96d 100644 (file)
@@ -228,6 +228,16 @@ class BagsRewriter : public TheoryRewriter
    */
   BagsRewriteResponse postRewriteMap(const TNode& n) const;
 
+  /**
+   *  rewrites for n include:
+   *  - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
+   *  - (bag.filter p (bag x y)) = (ite (p x) (bag x y) (as bag.empty (Bag T)))
+   *  - (bag.filter p (bag.union_disjoint A B)) =
+   *       (bag.union_disjoint (bag.filter p A) (bag.filter p B))
+   *  where p: T -> Bool
+   */
+  BagsRewriteResponse postRewriteFilter(const TNode& n) const;
+
   /**
    *  rewrites for n include:
    *  - (bag.fold f t (as bag.empty (Bag T1))) = t
diff --git a/src/theory/bags/bags_utils.cpp b/src/theory/bags/bags_utils.cpp
new file mode 100644 (file)
index 0000000..39987ce
--- /dev/null
@@ -0,0 +1,783 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Mudathir Mohamed, Aina Niemetz
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Utility functions for bags.
+ */
+#include "bags_utils.h"
+
+#include "expr/emptybag.h"
+#include "smt/logic_exception.h"
+#include "theory/sets/normal_form.h"
+#include "theory/type_enumerator.h"
+#include "util/rational.h"
+
+using namespace cvc5::kind;
+
+namespace cvc5 {
+namespace theory {
+namespace bags {
+
+Node BagsUtils::computeDisjointUnion(TypeNode bagType,
+                                     const std::vector<Node>& bags)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  if (bags.empty())
+  {
+    return nm->mkConst(EmptyBag(bagType));
+  }
+  if (bags.size() == 1)
+  {
+    return bags[0];
+  }
+  Node unionDisjoint = bags[0];
+  for (size_t i = 1; i < bags.size(); i++)
+  {
+    if (bags[i].getKind() == BAG_EMPTY)
+    {
+      continue;
+    }
+    unionDisjoint = nm->mkNode(BAG_UNION_DISJOINT, unionDisjoint, bags[i]);
+  }
+  return unionDisjoint;
+}
+
+bool BagsUtils::isConstant(TNode n)
+{
+  if (n.getKind() == BAG_EMPTY)
+  {
+    // empty bags are already normalized
+    return true;
+  }
+  if (n.getKind() == BAG_MAKE)
+  {
+    // see the implementation in MkBagTypeRule::computeIsConst
+    return n.isConst();
+  }
+  if (n.getKind() == BAG_UNION_DISJOINT)
+  {
+    if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst()))
+    {
+      // the first child is not a constant
+      return false;
+    }
+    // store the previous element to check the ordering of elements
+    Node previousElement = n[0][0];
+    Node current = n[1];
+    while (current.getKind() == BAG_UNION_DISJOINT)
+    {
+      if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst()))
+      {
+        // the current element is not a constant
+        return false;
+      }
+      if (previousElement >= current[0][0])
+      {
+        // the ordering is violated
+        return false;
+      }
+      previousElement = current[0][0];
+      current = current[1];
+    }
+    // check last element
+    if (!(current.getKind() == kind::BAG_MAKE && current.isConst()))
+    {
+      // the last element is not a constant
+      return false;
+    }
+    if (previousElement >= current[0])
+    {
+      // the ordering is violated
+      return false;
+    }
+    return true;
+  }
+
+  // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be
+  // constants
+  return false;
+}
+
+bool BagsUtils::areChildrenConstants(TNode n)
+{
+  return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
+}
+
+Node BagsUtils::evaluate(TNode n)
+{
+  Assert(areChildrenConstants(n));
+  if (n.isConst())
+  {
+    // a constant node is already in a normal form
+    return n;
+  }
+  switch (n.getKind())
+  {
+    case BAG_MAKE: return evaluateMakeBag(n);
+    case BAG_COUNT: return evaluateBagCount(n);
+    case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
+    case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n);
+    case BAG_UNION_MAX: return evaluateUnionMax(n);
+    case BAG_INTER_MIN: return evaluateIntersectionMin(n);
+    case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
+    case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
+    case BAG_CARD: return evaluateCard(n);
+    case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
+    case BAG_FROM_SET: return evaluateFromSet(n);
+    case BAG_TO_SET: return evaluateToSet(n);
+    case BAG_MAP: return evaluateBagMap(n);
+    case BAG_FILTER: return evaluateBagFilter(n);
+    case BAG_FOLD: return evaluateBagFold(n);
+    default: break;
+  }
+  Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
+              << std::endl;
+}
+
+template <typename T1, typename T2, typename T3, typename T4, typename T5>
+Node BagsUtils::evaluateBinaryOperation(const TNode& n,
+                                        T1&& equal,
+                                        T2&& less,
+                                        T3&& greaterOrEqual,
+                                        T4&& remainderOfA,
+                                        T5&& remainderOfB)
+{
+  std::map<Node, Rational> elementsA = getBagElements(n[0]);
+  std::map<Node, Rational> elementsB = getBagElements(n[1]);
+  std::map<Node, Rational> elements;
+
+  std::map<Node, Rational>::const_iterator itA = elementsA.begin();
+  std::map<Node, Rational>::const_iterator itB = elementsB.begin();
+
+  Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
+                         << n.getKind() << "] " << std::endl
+                         << "elements A: " << elementsA << std::endl
+                         << "elements B: " << elementsB << std::endl;
+
+  while (itA != elementsA.end() && itB != elementsB.end())
+  {
+    if (itA->first == itB->first)
+    {
+      equal(elements, itA, itB);
+      itA++;
+      itB++;
+    }
+    else if (itA->first < itB->first)
+    {
+      less(elements, itA, itB);
+      itA++;
+    }
+    else
+    {
+      greaterOrEqual(elements, itA, itB);
+      itB++;
+    }
+  }
+
+  // handle the remaining elements from A
+  remainderOfA(elements, elementsA, itA);
+  // handle the remaining elements from B
+  remainderOfB(elements, elementsB, itB);
+
+  Trace("bags-evaluate") << "elements: " << elements << std::endl;
+  Node bag = constructConstantBagFromElements(n.getType(), elements);
+  Trace("bags-evaluate") << "bag: " << bag << std::endl;
+  return bag;
+}
+
+std::map<Node, Rational> BagsUtils::getBagElements(TNode n)
+{
+  std::map<Node, Rational> elements;
+  if (n.getKind() == BAG_EMPTY)
+  {
+    return elements;
+  }
+  while (n.getKind() == kind::BAG_UNION_DISJOINT)
+  {
+    Assert(n[0].getKind() == kind::BAG_MAKE);
+    Node element = n[0][0];
+    Rational count = n[0][1].getConst<Rational>();
+    elements[element] = count;
+    n = n[1];
+  }
+  Assert(n.getKind() == kind::BAG_MAKE);
+  Node lastElement = n[0];
+  Rational lastCount = n[1].getConst<Rational>();
+  elements[lastElement] = lastCount;
+  return elements;
+}
+
+Node BagsUtils::constructConstantBagFromElements(
+    TypeNode t, const std::map<Node, Rational>& elements)
+{
+  Assert(t.isBag());
+  NodeManager* nm = NodeManager::currentNM();
+  if (elements.empty())
+  {
+    return nm->mkConst(EmptyBag(t));
+  }
+  TypeNode elementType = t.getBagElementType();
+  std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
+  Node bag = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second));
+  while (++it != elements.rend())
+  {
+    Node n = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second));
+    bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
+  }
+  return bag;
+}
+
+Node BagsUtils::constructBagFromElements(TypeNode t,
+                                         const std::map<Node, Node>& elements)
+{
+  Assert(t.isBag());
+  NodeManager* nm = NodeManager::currentNM();
+  if (elements.empty())
+  {
+    return nm->mkConst(EmptyBag(t));
+  }
+  TypeNode elementType = t.getBagElementType();
+  std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
+  Node bag = nm->mkBag(elementType, it->first, it->second);
+  while (++it != elements.rend())
+  {
+    Node n = nm->mkBag(elementType, it->first, it->second);
+    bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
+  }
+  return bag;
+}
+
+Node BagsUtils::evaluateMakeBag(TNode n)
+{
+  // the case where n is const should be handled earlier.
+  // here we handle the case where the multiplicity is zero or negative
+  Assert(n.getKind() == BAG_MAKE && !n.isConst()
+         && n[1].getConst<Rational>().sgn() < 1);
+  Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
+  return emptybag;
+}
+
+Node BagsUtils::evaluateBagCount(TNode n)
+{
+  Assert(n.getKind() == BAG_COUNT);
+  // Examples
+  // --------
+  // - (bag.count "x" (as bag.empty (Bag String))) = 0
+  // - (bag.count "x" (bag "y" 5)) = 0
+  // - (bag.count "x" (bag "x" 4)) = 4
+  // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4
+  // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0
+
+  std::map<Node, Rational> elements = getBagElements(n[1]);
+  std::map<Node, Rational>::iterator it = elements.find(n[0]);
+
+  NodeManager* nm = NodeManager::currentNM();
+  if (it != elements.end())
+  {
+    Node count = nm->mkConstInt(it->second);
+    return count;
+  }
+  return nm->mkConstInt(Rational(0));
+}
+
+Node BagsUtils::evaluateDuplicateRemoval(TNode n)
+{
+  Assert(n.getKind() == BAG_DUPLICATE_REMOVAL);
+
+  // Examples
+  // --------
+  //  - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag
+  //  String))
+  //  - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1)
+  //  - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
+  //     (bag.disjoint_union (bag "x" 1) (bag "y" 1)
+
+  std::map<Node, Rational> oldElements = getBagElements(n[0]);
+  // copy elements from the old bag
+  std::map<Node, Rational> newElements(oldElements);
+  Rational one = Rational(1);
+  std::map<Node, Rational>::iterator it;
+  for (it = newElements.begin(); it != newElements.end(); it++)
+  {
+    it->second = one;
+  }
+  Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
+  return bag;
+}
+
+Node BagsUtils::evaluateUnionDisjoint(TNode n)
+{
+  Assert(n.getKind() == BAG_UNION_DISJOINT);
+  // Example
+  // -------
+  // input: (bag.union_disjoint A B)
+  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
+  // output:
+  //    (bag.union_disjoint A B)
+  //        where A = (bag "x" 7)
+  //              B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
+
+  auto equal = [](std::map<Node, Rational>& elements,
+                  std::map<Node, Rational>::const_iterator& itA,
+                  std::map<Node, Rational>::const_iterator& itB) {
+    // compute the sum of the multiplicities
+    elements[itA->first] = itA->second + itB->second;
+  };
+
+  auto less = [](std::map<Node, Rational>& elements,
+                 std::map<Node, Rational>::const_iterator& itA,
+                 std::map<Node, Rational>::const_iterator& itB) {
+    // add the element to the result
+    elements[itA->first] = itA->second;
+  };
+
+  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+                           std::map<Node, Rational>::const_iterator& itA,
+                           std::map<Node, Rational>::const_iterator& itB) {
+    // add the element to the result
+    elements[itB->first] = itB->second;
+  };
+
+  auto remainderOfA = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsA,
+                         std::map<Node, Rational>::const_iterator& itA) {
+    // append the remainder of A
+    while (itA != elementsA.end())
+    {
+      elements[itA->first] = itA->second;
+      itA++;
+    }
+  };
+
+  auto remainderOfB = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsB,
+                         std::map<Node, Rational>::const_iterator& itB) {
+    // append the remainder of B
+    while (itB != elementsB.end())
+    {
+      elements[itB->first] = itB->second;
+      itB++;
+    }
+  };
+
+  return evaluateBinaryOperation(
+      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node BagsUtils::evaluateUnionMax(TNode n)
+{
+  Assert(n.getKind() == BAG_UNION_MAX);
+  // Example
+  // -------
+  // input: (bag.union_max A B)
+  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
+  // output:
+  //    (bag.union_disjoint A B)
+  //        where A = (bag "x" 4)
+  //              B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
+
+  auto equal = [](std::map<Node, Rational>& elements,
+                  std::map<Node, Rational>::const_iterator& itA,
+                  std::map<Node, Rational>::const_iterator& itB) {
+    // compute the maximum multiplicity
+    elements[itA->first] = std::max(itA->second, itB->second);
+  };
+
+  auto less = [](std::map<Node, Rational>& elements,
+                 std::map<Node, Rational>::const_iterator& itA,
+                 std::map<Node, Rational>::const_iterator& itB) {
+    // add to the result
+    elements[itA->first] = itA->second;
+  };
+
+  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+                           std::map<Node, Rational>::const_iterator& itA,
+                           std::map<Node, Rational>::const_iterator& itB) {
+    // add to the result
+    elements[itB->first] = itB->second;
+  };
+
+  auto remainderOfA = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsA,
+                         std::map<Node, Rational>::const_iterator& itA) {
+    // append the remainder of A
+    while (itA != elementsA.end())
+    {
+      elements[itA->first] = itA->second;
+      itA++;
+    }
+  };
+
+  auto remainderOfB = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsB,
+                         std::map<Node, Rational>::const_iterator& itB) {
+    // append the remainder of B
+    while (itB != elementsB.end())
+    {
+      elements[itB->first] = itB->second;
+      itB++;
+    }
+  };
+
+  return evaluateBinaryOperation(
+      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node BagsUtils::evaluateIntersectionMin(TNode n)
+{
+  Assert(n.getKind() == BAG_INTER_MIN);
+  // Example
+  // -------
+  // input: (bag.inter_min A B)
+  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
+  // output:
+  //        (bag "x" 3)
+
+  auto equal = [](std::map<Node, Rational>& elements,
+                  std::map<Node, Rational>::const_iterator& itA,
+                  std::map<Node, Rational>::const_iterator& itB) {
+    // compute the minimum multiplicity
+    elements[itA->first] = std::min(itA->second, itB->second);
+  };
+
+  auto less = [](std::map<Node, Rational>& elements,
+                 std::map<Node, Rational>::const_iterator& itA,
+                 std::map<Node, Rational>::const_iterator& itB) {
+    // do nothing
+  };
+
+  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+                           std::map<Node, Rational>::const_iterator& itA,
+                           std::map<Node, Rational>::const_iterator& itB) {
+    // do nothing
+  };
+
+  auto remainderOfA = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsA,
+                         std::map<Node, Rational>::const_iterator& itA) {
+    // do nothing
+  };
+
+  auto remainderOfB = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsB,
+                         std::map<Node, Rational>::const_iterator& itB) {
+    // do nothing
+  };
+
+  return evaluateBinaryOperation(
+      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node BagsUtils::evaluateDifferenceSubtract(TNode n)
+{
+  Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT);
+  // Example
+  // -------
+  // input: (bag.difference_subtract A B)
+  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
+  // output:
+  //    (bag.union_disjoint (bag "x" 1) (bag "z" 2))
+
+  auto equal = [](std::map<Node, Rational>& elements,
+                  std::map<Node, Rational>::const_iterator& itA,
+                  std::map<Node, Rational>::const_iterator& itB) {
+    // subtract the multiplicities
+    elements[itA->first] = itA->second - itB->second;
+  };
+
+  auto less = [](std::map<Node, Rational>& elements,
+                 std::map<Node, Rational>::const_iterator& itA,
+                 std::map<Node, Rational>::const_iterator& itB) {
+    // itA->first is not in B, so we add it to the difference subtract
+    elements[itA->first] = itA->second;
+  };
+
+  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+                           std::map<Node, Rational>::const_iterator& itA,
+                           std::map<Node, Rational>::const_iterator& itB) {
+    // itB->first is not in A, so we just skip it
+  };
+
+  auto remainderOfA = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsA,
+                         std::map<Node, Rational>::const_iterator& itA) {
+    // append the remainder of A
+    while (itA != elementsA.end())
+    {
+      elements[itA->first] = itA->second;
+      itA++;
+    }
+  };
+
+  auto remainderOfB = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsB,
+                         std::map<Node, Rational>::const_iterator& itB) {
+    // do nothing
+  };
+
+  return evaluateBinaryOperation(
+      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node BagsUtils::evaluateDifferenceRemove(TNode n)
+{
+  Assert(n.getKind() == BAG_DIFFERENCE_REMOVE);
+  // Example
+  // -------
+  // input: (bag.difference_remove A B)
+  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
+  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
+  // output:
+  //    (bag "z" 2)
+
+  auto equal = [](std::map<Node, Rational>& elements,
+                  std::map<Node, Rational>::const_iterator& itA,
+                  std::map<Node, Rational>::const_iterator& itB) {
+    // skip the shared element by doing nothing
+  };
+
+  auto less = [](std::map<Node, Rational>& elements,
+                 std::map<Node, Rational>::const_iterator& itA,
+                 std::map<Node, Rational>::const_iterator& itB) {
+    // itA->first is not in B, so we add it to the difference remove
+    elements[itA->first] = itA->second;
+  };
+
+  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
+                           std::map<Node, Rational>::const_iterator& itA,
+                           std::map<Node, Rational>::const_iterator& itB) {
+    // itB->first is not in A, so we just skip it
+  };
+
+  auto remainderOfA = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsA,
+                         std::map<Node, Rational>::const_iterator& itA) {
+    // append the remainder of A
+    while (itA != elementsA.end())
+    {
+      elements[itA->first] = itA->second;
+      itA++;
+    }
+  };
+
+  auto remainderOfB = [](std::map<Node, Rational>& elements,
+                         std::map<Node, Rational>& elementsB,
+                         std::map<Node, Rational>::const_iterator& itB) {
+    // do nothing
+  };
+
+  return evaluateBinaryOperation(
+      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
+}
+
+Node BagsUtils::evaluateChoose(TNode n)
+{
+  Assert(n.getKind() == BAG_CHOOSE);
+  // Examples
+  // --------
+  // - (bag.choose (bag "x" 4)) = "x"
+
+  if (n[0].getKind() == BAG_MAKE)
+  {
+    return n[0][0];
+  }
+  throw LogicException("BAG_CHOOSE_TOTAL is not supported yet");
+}
+
+Node BagsUtils::evaluateCard(TNode n)
+{
+  Assert(n.getKind() == BAG_CARD);
+  // Examples
+  // --------
+  //  - (card (as bag.empty (Bag String))) = 0
+  //  - (bag.choose (bag "x" 4)) = 4
+  //  - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5
+
+  std::map<Node, Rational> elements = getBagElements(n[0]);
+  Rational sum(0);
+  for (std::pair<Node, Rational> element : elements)
+  {
+    sum += element.second;
+  }
+
+  NodeManager* nm = NodeManager::currentNM();
+  Node sumNode = nm->mkConstInt(sum);
+  return sumNode;
+}
+
+Node BagsUtils::evaluateIsSingleton(TNode n)
+{
+  Assert(n.getKind() == BAG_IS_SINGLETON);
+  // Examples
+  // --------
+  // - (bag.is_singleton (as bag.empty (Bag String))) = false
+  // - (bag.is_singleton (bag "x" 1)) = true
+  // - (bag.is_singleton (bag "x" 4)) = false
+  // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1)))
+  // = false
+
+  if (n[0].getKind() == BAG_MAKE && n[0][1].getConst<Rational>().isOne())
+  {
+    return NodeManager::currentNM()->mkConst(true);
+  }
+  return NodeManager::currentNM()->mkConst(false);
+}
+
+Node BagsUtils::evaluateFromSet(TNode n)
+{
+  Assert(n.getKind() == BAG_FROM_SET);
+
+  // Examples
+  // --------
+  //  - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String))
+  //  - (bag.from_set (set.singleton "x")) = (bag "x" 1)
+  //  - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) =
+  //     (bag.disjoint_union (bag "x" 1) (bag "y" 1))
+
+  NodeManager* nm = NodeManager::currentNM();
+  std::set<Node> setElements =
+      sets::NormalForm::getElementsFromNormalConstant(n[0]);
+  Rational one = Rational(1);
+  std::map<Node, Rational> bagElements;
+  for (const Node& element : setElements)
+  {
+    bagElements[element] = one;
+  }
+  TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
+  Node bag = constructConstantBagFromElements(bagType, bagElements);
+  return bag;
+}
+
+Node BagsUtils::evaluateToSet(TNode n)
+{
+  Assert(n.getKind() == BAG_TO_SET);
+
+  // Examples
+  // --------
+  //  - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String))
+  //  - (bag.to_set (bag "x" 4)) = (set.singleton "x")
+  //  - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
+  //     (set.union (set.singleton "x") (set.singleton "y")))
+
+  NodeManager* nm = NodeManager::currentNM();
+  std::map<Node, Rational> bagElements = getBagElements(n[0]);
+  std::set<Node> setElements;
+  std::map<Node, Rational>::const_reverse_iterator it;
+  for (it = bagElements.rbegin(); it != bagElements.rend(); it++)
+  {
+    setElements.insert(it->first);
+  }
+  TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
+  Node set = sets::NormalForm::elementsToSet(setElements, setType);
+  return set;
+}
+
+Node BagsUtils::evaluateBagMap(TNode n)
+{
+  Assert(n.getKind() == BAG_MAP);
+
+  // Examples
+  // --------
+  // - (bag.map ((lambda ((x String)) "z")
+  //            (bag.union_disjoint (bag "a" 2) (bag "b" 3)) =
+  //     (bag.union_disjoint
+  //       (bag ((lambda ((x String)) "z") "a") 2)
+  //       (bag ((lambda ((x String)) "z") "b") 3)) =
+  //     (bag "z" 5)
+
+  std::map<Node, Rational> elements = BagsUtils::getBagElements(n[1]);
+  std::map<Node, Rational> mappedElements;
+  std::map<Node, Rational>::iterator it = elements.begin();
+  NodeManager* nm = NodeManager::currentNM();
+  while (it != elements.end())
+  {
+    Node mappedElement = nm->mkNode(APPLY_UF, n[0], it->first);
+    mappedElements[mappedElement] = it->second;
+    ++it;
+  }
+  TypeNode t = nm->mkBagType(n[0].getType().getRangeType());
+  Node ret = BagsUtils::constructConstantBagFromElements(t, mappedElements);
+  return ret;
+}
+
+Node BagsUtils::evaluateBagFilter(TNode n)
+{
+  Assert(n.getKind() == BAG_FILTER);
+
+  // - (bag.filter p (as bag.empty (Bag T)) = (as bag.empty (Bag T))
+  // - (bag.filter p (bag.union_disjoint (bag "a" 3) (bag "b" 2))) =
+  //   (bag.union_disjoint
+  //     (ite (p "a") (bag "a" 3) (as bag.empty (Bag T)))
+  //     (ite (p "b") (bag "b" 2) (as bag.empty (Bag T)))
+
+  Node P = n[0];
+  Node A = n[1];
+  TypeNode bagType = A.getType();
+  NodeManager* nm = NodeManager::currentNM();
+  Node empty = nm->mkConst(EmptyBag(bagType));
+
+  std::map<Node, Rational> elements = getBagElements(n[1]);
+  std::vector<Node> bags;
+
+  for (const auto& [e, count] : elements)
+  {
+    Node multiplicity = nm->mkConst(CONST_RATIONAL, count);
+    Node bag = nm->mkBag(bagType.getBagElementType(), e, multiplicity);
+    Node pOfe = nm->mkNode(APPLY_UF, P, e);
+    Node ite = nm->mkNode(ITE, pOfe, bag, empty);
+    bags.push_back(ite);
+  }
+  Node ret = computeDisjointUnion(bagType, bags);
+  return ret;
+}
+
+Node BagsUtils::evaluateBagFold(TNode n)
+{
+  Assert(n.getKind() == BAG_FOLD);
+
+  // Examples
+  // --------
+  // minimum string
+  // - (bag.fold
+  //     ((lambda ((x String) (y String)) (ite (str.< x y) x y))
+  //     ""
+  //     (bag.union_disjoint (bag "a" 2) (bag "b" 3))
+  //   = "a"
+
+  Node f = n[0];    // combining function
+  Node ret = n[1];  // initial value
+  Node A = n[2];    // bag
+  std::map<Node, Rational> elements = BagsUtils::getBagElements(A);
+
+  std::map<Node, Rational>::iterator it = elements.begin();
+  NodeManager* nm = NodeManager::currentNM();
+  while (it != elements.end())
+  {
+    // apply the combination function n times, where n is the multiplicity
+    Rational count = it->second;
+    Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl;
+    while (!count.isZero())
+    {
+      ret = nm->mkNode(APPLY_UF, f, it->first, ret);
+      count = count - 1;
+    }
+    ++it;
+  }
+  return ret;
+}
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace cvc5
diff --git a/src/theory/bags/bags_utils.h b/src/theory/bags/bags_utils.h
new file mode 100644 (file)
index 0000000..61473a0
--- /dev/null
@@ -0,0 +1,223 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Utility functions for bags.
+ */
+
+#include <expr/node.h>
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__THEORY__BAGS__NORMAL_FORM_H
+#define CVC5__THEORY__BAGS__NORMAL_FORM_H
+
+namespace cvc5 {
+namespace theory {
+namespace bags {
+
+class BagsUtils
+{
+ public:
+  /**
+   * @param bagType type of bags
+   * @param bags a vector of bag nodes
+   * @return disjoint union of these bags
+   */
+  static Node computeDisjointUnion(TypeNode bagType,
+                                   const std::vector<Node>& bags);
+  /**
+   * Returns true if n is considered a to be a (canonical) constant bag value.
+   * A canonical bag value is one whose AST is:
+   *   (bag.union_disjoint (bag e1 c1) ...
+   *      (bag.union_disjoint (bag e_{n-1} c_{n-1}) (bag e_n c_n))))
+   * where c1 ... cn are positive integers, e1 ... en are constants, and the
+   * node identifier of these constants are such that: e1 < ... < en.
+   * Also handles the corner cases of empty bag and bag constructed by bag
+   */
+  static bool isConstant(TNode n);
+  /**
+   * check whether all children of the given node are constants
+   */
+  static bool areChildrenConstants(TNode n);
+  /**
+   * evaluate the node n to a constant value.
+   * As a precondition, children of n should be constants.
+   */
+  static Node evaluate(TNode n);
+
+  /**
+   * get the elements along with their multiplicities in a given bag
+   * @param n a constant node whose type is a bag
+   * @return a map whose keys are constant elements and values are
+   * multiplicities
+   */
+  static std::map<Node, Rational> getBagElements(TNode n);
+
+  /**
+   * construct a constant bag from constant elements
+   * @param t the type of the returned bag
+   * @param elements a map whose keys are constant elements and values are
+   *        multiplicities
+   * @return a constant bag that contains
+   */
+  static Node constructConstantBagFromElements(
+      TypeNode t, const std::map<Node, Rational>& elements);
+
+  /**
+   * construct a constant bag from node elements
+   * @param t the type of the returned bag
+   * @param elements a map whose keys are constant elements and values are
+   *        multiplicities
+   * @return a constant bag that contains
+   */
+  static Node constructBagFromElements(TypeNode t,
+                                       const std::map<Node, Node>& elements);
+
+  /**
+   * @param n has the form (bag.fold f t A) where A is a constant bag
+   * @return a single value which is the result of the fold
+   */
+  static Node evaluateBagFold(TNode n);
+
+  /**
+   * @param n has the form (bag.filter p A) where A is a constant bag
+   * @return A filtered with predicate p
+   */
+  static Node evaluateBagFilter(TNode n);
+
+ private:
+  /**
+   * a high order helper function that return a constant bag that is the result
+   * of (op A B) where op is a binary operator and A, B are constant bags.
+   * The result is computed from the elements of A (elementsA with iterator itA)
+   * and elements of B (elementsB with iterator itB).
+   * The arguments below specify how these iterators are used to generate the
+   * elements of the result (elements).
+   * @param n a node whose kind is a binary operator (bag.union_disjoint,
+   * union_max, intersection_min, difference_subtract, difference_remove) and
+   * whose children are constant bags.
+   * @param equal a lambda expression that receives (elements, itA, itB) and
+   * specify the action that needs to be taken when the elements of itA, itB are
+   * equal.
+   * @param less a lambda expression that receives (elements, itA, itB) and
+   * specify the action that needs to be taken when the element itA is less than
+   * the element of itB.
+   * @param greaterOrEqual less a lambda expression that receives (elements,
+   * itA, itB) and specify the action that needs to be taken when the element
+   * itA is greater than or equal than the element of itB.
+   * @param remainderOfA a lambda expression that receives (elements, elementsA,
+   * itA) and specify the action that needs to be taken to the remaining
+   * elements of A when all elements of B are visited.
+   * @param remainderOfB a lambda expression that receives (elements, elementsB,
+   * itB) and specify the action that needs to be taken to the remaining
+   * elements of B when all elements of A are visited.
+   * @return a constant bag that the result of (op n[0] n[1])
+   */
+  template <typename T1, typename T2, typename T3, typename T4, typename T5>
+  static Node evaluateBinaryOperation(const TNode& n,
+                                      T1&& equal,
+                                      T2&& less,
+                                      T3&& greaterOrEqual,
+                                      T4&& remainderOfA,
+                                      T5&& remainderOfB);
+  /**
+   * evaluate n as follows:
+   * - (bag a 0) = (as bag.empty T) where T is the type of the original bag
+   * - (bag a (-c)) = (as bag.empty T) where T is the type the original bag,
+   *                                and c > 0 is a constant
+   */
+  static Node evaluateMakeBag(TNode n);
+
+  /**
+   * returns the multiplicity in a constant bag
+   * @param n has the form (bag.count x A) where x, A are constants
+   * @return the multiplicity of element x in bag A.
+   */
+  static Node evaluateBagCount(TNode n);
+
+  /**
+   * @param n has the form (bag.duplicate_removal A) where A is a constant bag
+   * @return a constant bag constructed from the elements in A where each
+   * element has multiplicity one
+   */
+  static Node evaluateDuplicateRemoval(TNode n);
+
+  /**
+   * evaluates union disjoint node such that the returned node is a canonical
+   * bag that has the form
+   * (bag.union_disjoint (bag e1 c1) ...
+   *   (bag.union_disjoint  * (bag e_{n-1} c_{n-1}) (bag e_n c_n)))) where
+   *   c1... cn are positive integers, e1 ... en are constants, and the node
+   * identifier of these constants are such that: e1 < ... < en.
+   * @param n has the form (bag.union_disjoint A B) where A, B are constant bags
+   * @return the union disjoint of A and B
+   */
+  static Node evaluateUnionDisjoint(TNode n);
+  /**
+   * @param n has the form (bag.union_max A B) where A, B are constant bags
+   * @return the union max of A and B
+   */
+  static Node evaluateUnionMax(TNode n);
+  /**
+   * @param n has the form (bag.inter_min A B) where A, B are constant bags
+   * @return the intersection min of A and B
+   */
+  static Node evaluateIntersectionMin(TNode n);
+  /**
+   * @param n has the form (bag.difference_subtract A B) where A, B are constant
+   * bags
+   * @return the difference subtract of A and B
+   */
+  static Node evaluateDifferenceSubtract(TNode n);
+  /**
+   * @param n has the form (bag.difference_remove A B) where A, B are constant
+   * bags
+   * @return the difference remove of A and B
+   */
+  static Node evaluateDifferenceRemove(TNode n);
+  /**
+   * @param n has the form (bag.choose A) where A is a constant bag
+   * @return x if n has the form (bag.choose (bag x c)). Otherwise an error is
+   * thrown.
+   */
+  static Node evaluateChoose(TNode n);
+  /**
+   * @param n has the form (bag.card A) where A is a constant bag
+   * @return the number of elements in bag A
+   */
+  static Node evaluateCard(TNode n);
+  /**
+   * @param n has the form (bag.is_singleton A) where A is a constant bag
+   * @return whether the bag A has cardinality one.
+   */
+  static Node evaluateIsSingleton(TNode n);
+  /**
+   * @param n has the form (bag.from_set A) where A is a constant set
+   * @return a constant bag that contains exactly the elements in A.
+   */
+  static Node evaluateFromSet(TNode n);
+  /**
+   * @param n has the form (bag.to_set A) where A is a constant bag
+   * @return a constant set constructed from the elements in A.
+   */
+  static Node evaluateToSet(TNode n);
+  /**
+   * @param n has the form (bag.map f A) where A is a constant bag
+   * @return a constant bag constructed from the images of elements in A.
+   */
+  static Node evaluateBagMap(TNode n);
+};
+}  // namespace bags
+}  // namespace theory
+}  // namespace cvc5
+
+#endif /* CVC5__THEORY__BAGS__NORMAL_FORM_H */
index 4ec009c7d66e9bdd72a0b62a515752ae0e38eb22..2a35fb2bd5c833e2e28adfcbf6fce085bac53652 100644 (file)
@@ -17,9 +17,9 @@
 
 #include "expr/emptybag.h"
 #include "smt/logic_exception.h"
+#include "theory/bags/bags_utils.h"
 #include "theory/bags/inference_generator.h"
 #include "theory/bags/inference_manager.h"
-#include "theory/bags/normal_form.h"
 #include "theory/bags/solver_state.h"
 #include "theory/bags/term_registry.h"
 #include "theory/uf/equality_engine_iterator.h"
index 92aa1a0ea7f801da89b8cd8e10ab85c8b90f7c6a..3247548c5635b54defb8041ebcb1cb68181f7b41 100644 (file)
@@ -517,6 +517,52 @@ InferInfo InferenceGenerator::mapUpwards(
   return inferInfo;
 }
 
+InferInfo InferenceGenerator::filterDownwards(Node n, Node e)
+{
+  Assert(n.getKind() == BAG_FILTER && n[1].getType().isBag());
+  Assert(e.getType().isSubtypeOf(n[1].getType().getBagElementType()));
+
+  Node P = n[0];
+  Node A = n[1];
+  InferInfo inferInfo(d_im, InferenceId::BAGS_FILTER_DOWN);
+
+  Node countA = getMultiplicityTerm(e, A);
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(e, skolem);
+
+  Node member = d_nm->mkNode(GEQ, count, d_one);
+  Node pOfe = d_nm->mkNode(APPLY_UF, P, e);
+  Node equal = count.eqNode(countA);
+
+  inferInfo.d_conclusion = pOfe.andNode(equal);
+  inferInfo.d_premises.push_back(member);
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::filterUpwards(Node n, Node e)
+{
+  Assert(n.getKind() == BAG_FILTER && n[1].getType().isBag());
+  Assert(e.getType().isSubtypeOf(n[1].getType().getBagElementType()));
+
+  Node P = n[0];
+  Node A = n[1];
+  InferInfo inferInfo(d_im, InferenceId::BAGS_FILTER_UP);
+
+  Node countA = getMultiplicityTerm(e, A);
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(e, skolem);
+
+  Node member = d_nm->mkNode(GEQ, countA, d_one);
+  Node pOfe = d_nm->mkNode(APPLY_UF, P, e);
+  Node equal = count.eqNode(countA);
+  Node included = pOfe.andNode(equal);
+  Node equalZero = count.eqNode(d_zero);
+  Node excluded = pOfe.notNode().andNode(equalZero);
+  inferInfo.d_conclusion = included.orNode(excluded);
+  inferInfo.d_premises.push_back(member);
+  return inferInfo;
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index 2815058b263ea68eab9a9a4df65487482f994321..3d74dbaa23148b97c72cde5816d55697ef721c15 100644 (file)
@@ -262,6 +262,34 @@ class InferenceGenerator
    */
   InferInfo mapUpwards(Node n, Node uf, Node preImageSize, Node y, Node x);
 
+  /**
+   * @param n is (bag.filter p A) where p is a function (-> E Bool),
+   * A a bag of type (Bag E)
+   * @param e is an element of type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   (bag.member e skolem)
+   *   (and
+   *     (p e)
+   *     (= (bag.count e skolem) (bag.count A)))
+   * where skolem is a variable equals (bag.filter p A)
+   */
+  InferInfo filterDownwards(Node n, Node e);
+
+  /**
+   * @param n is (bag.filter p A) where p is a function (-> E Bool),
+   * A a bag of type (Bag E)
+   * @param e is an element of type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   (bag.member e A)
+   *   (or
+   *     (and (p e) (= (bag.count e skolem) (bag.count A)))
+   *     (and (not (p e)) (= (bag.count e skolem) 0)))
+   * where skolem is a variable equals (bag.filter p A)
+   */
+  InferInfo filterUpwards(Node n, Node e);
+
   /**
    * @param element of type T
    * @param bag of type (bag T)
index d83be5e211b913ea28d544b1dc4b79b77bf37654..7d995dd7bd0d64c40c84ae3ff0a87deaf323f241 100644 (file)
@@ -77,6 +77,10 @@ operator BAG_CHOOSE        1  "return an element in the bag given as a parameter
 # of the second argument, a bag of type (Bag T1), and returns a bag of type (Bag T2).
 operator BAG_MAP           2  "bag map function"
 
+# The bag.filter operator takes a predicate of type (-> T Bool) and a bag of type (Bag T)
+# and return the same bag excluding those elements that do not satisfy the predicate
+operator BAG_FILTER        2  "bag filter operator"
+
 # bag.fold operator combines elements of a bag into a single value.
 # (bag.fold f t B) folds the elements of bag B starting with term t and using
 # the combining function f.
@@ -103,6 +107,7 @@ typerule BAG_IS_SINGLETON        ::cvc5::theory::bags::IsSingletonTypeRule
 typerule BAG_FROM_SET            ::cvc5::theory::bags::FromSetTypeRule
 typerule BAG_TO_SET              ::cvc5::theory::bags::ToSetTypeRule
 typerule BAG_MAP                 ::cvc5::theory::bags::BagMapTypeRule
+typerule BAG_FILTER              ::cvc5::theory::bags::BagFilterTypeRule
 typerule BAG_FOLD                ::cvc5::theory::bags::BagFoldTypeRule
 
 construle BAG_UNION_DISJOINT     ::cvc5::theory::bags::BinaryOperatorTypeRule
diff --git a/src/theory/bags/normal_form.cpp b/src/theory/bags/normal_form.cpp
deleted file mode 100644 (file)
index 6cf26d3..0000000
+++ /dev/null
@@ -1,727 +0,0 @@
-/******************************************************************************
- * Top contributors (to current version):
- *   Mudathir Mohamed, Aina Niemetz
- *
- * This file is part of the cvc5 project.
- *
- * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
- * in the top-level source directory and their institutional affiliations.
- * All rights reserved.  See the file COPYING in the top-level source
- * directory for licensing information.
- * ****************************************************************************
- *
- * Normal form for bag constants.
- */
-#include "normal_form.h"
-
-#include "expr/emptybag.h"
-#include "smt/logic_exception.h"
-#include "theory/sets/normal_form.h"
-#include "theory/type_enumerator.h"
-#include "util/rational.h"
-
-using namespace cvc5::kind;
-
-namespace cvc5 {
-namespace theory {
-namespace bags {
-
-bool NormalForm::isConstant(TNode n)
-{
-  if (n.getKind() == BAG_EMPTY)
-  {
-    // empty bags are already normalized
-    return true;
-  }
-  if (n.getKind() == BAG_MAKE)
-  {
-    // see the implementation in MkBagTypeRule::computeIsConst
-    return n.isConst();
-  }
-  if (n.getKind() == BAG_UNION_DISJOINT)
-  {
-    if (!(n[0].getKind() == kind::BAG_MAKE && n[0].isConst()))
-    {
-      // the first child is not a constant
-      return false;
-    }
-    // store the previous element to check the ordering of elements
-    Node previousElement = n[0][0];
-    Node current = n[1];
-    while (current.getKind() == BAG_UNION_DISJOINT)
-    {
-      if (!(current[0].getKind() == kind::BAG_MAKE && current[0].isConst()))
-      {
-        // the current element is not a constant
-        return false;
-      }
-      if (previousElement >= current[0][0])
-      {
-        // the ordering is violated
-        return false;
-      }
-      previousElement = current[0][0];
-      current = current[1];
-    }
-    // check last element
-    if (!(current.getKind() == kind::BAG_MAKE && current.isConst()))
-    {
-      // the last element is not a constant
-      return false;
-    }
-    if (previousElement >= current[0])
-    {
-      // the ordering is violated
-      return false;
-    }
-    return true;
-  }
-
-  // only nodes with kinds EMPTY_BAG, BAG_MAKE, and BAG_UNION_DISJOINT can be
-  // constants
-  return false;
-}
-
-bool NormalForm::areChildrenConstants(TNode n)
-{
-  return std::all_of(n.begin(), n.end(), [](Node c) { return c.isConst(); });
-}
-
-Node NormalForm::evaluate(TNode n)
-{
-  Assert(areChildrenConstants(n));
-  if (n.isConst())
-  {
-    // a constant node is already in a normal form
-    return n;
-  }
-  switch (n.getKind())
-  {
-    case BAG_MAKE: return evaluateMakeBag(n);
-    case BAG_COUNT: return evaluateBagCount(n);
-    case BAG_DUPLICATE_REMOVAL: return evaluateDuplicateRemoval(n);
-    case BAG_UNION_DISJOINT: return evaluateUnionDisjoint(n);
-    case BAG_UNION_MAX: return evaluateUnionMax(n);
-    case BAG_INTER_MIN: return evaluateIntersectionMin(n);
-    case BAG_DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
-    case BAG_DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
-    case BAG_CARD: return evaluateCard(n);
-    case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
-    case BAG_FROM_SET: return evaluateFromSet(n);
-    case BAG_TO_SET: return evaluateToSet(n);
-    case BAG_MAP: return evaluateBagMap(n);
-    case BAG_FOLD: return evaluateBagFold(n);
-    default: break;
-  }
-  Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
-              << std::endl;
-}
-
-template <typename T1, typename T2, typename T3, typename T4, typename T5>
-Node NormalForm::evaluateBinaryOperation(const TNode& n,
-                                         T1&& equal,
-                                         T2&& less,
-                                         T3&& greaterOrEqual,
-                                         T4&& remainderOfA,
-                                         T5&& remainderOfB)
-{
-  std::map<Node, Rational> elementsA = getBagElements(n[0]);
-  std::map<Node, Rational> elementsB = getBagElements(n[1]);
-  std::map<Node, Rational> elements;
-
-  std::map<Node, Rational>::const_iterator itA = elementsA.begin();
-  std::map<Node, Rational>::const_iterator itB = elementsB.begin();
-
-  Trace("bags-evaluate") << "[NormalForm::evaluateBinaryOperation "
-                         << n.getKind() << "] " << std::endl
-                         << "elements A: " << elementsA << std::endl
-                         << "elements B: " << elementsB << std::endl;
-
-  while (itA != elementsA.end() && itB != elementsB.end())
-  {
-    if (itA->first == itB->first)
-    {
-      equal(elements, itA, itB);
-      itA++;
-      itB++;
-    }
-    else if (itA->first < itB->first)
-    {
-      less(elements, itA, itB);
-      itA++;
-    }
-    else
-    {
-      greaterOrEqual(elements, itA, itB);
-      itB++;
-    }
-  }
-
-  // handle the remaining elements from A
-  remainderOfA(elements, elementsA, itA);
-  // handle the remaining elements from B
-  remainderOfB(elements, elementsB, itB);
-
-  Trace("bags-evaluate") << "elements: " << elements << std::endl;
-  Node bag = constructConstantBagFromElements(n.getType(), elements);
-  Trace("bags-evaluate") << "bag: " << bag << std::endl;
-  return bag;
-}
-
-std::map<Node, Rational> NormalForm::getBagElements(TNode n)
-{
-  std::map<Node, Rational> elements;
-  if (n.getKind() == BAG_EMPTY)
-  {
-    return elements;
-  }
-  while (n.getKind() == kind::BAG_UNION_DISJOINT)
-  {
-    Assert(n[0].getKind() == kind::BAG_MAKE);
-    Node element = n[0][0];
-    Rational count = n[0][1].getConst<Rational>();
-    elements[element] = count;
-    n = n[1];
-  }
-  Assert(n.getKind() == kind::BAG_MAKE);
-  Node lastElement = n[0];
-  Rational lastCount = n[1].getConst<Rational>();
-  elements[lastElement] = lastCount;
-  return elements;
-}
-
-Node NormalForm::constructConstantBagFromElements(
-    TypeNode t, const std::map<Node, Rational>& elements)
-{
-  Assert(t.isBag());
-  NodeManager* nm = NodeManager::currentNM();
-  if (elements.empty())
-  {
-    return nm->mkConst(EmptyBag(t));
-  }
-  TypeNode elementType = t.getBagElementType();
-  std::map<Node, Rational>::const_reverse_iterator it = elements.rbegin();
-  Node bag = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second));
-  while (++it != elements.rend())
-  {
-    Node n = nm->mkBag(elementType, it->first, nm->mkConstInt(it->second));
-    bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
-  }
-  return bag;
-}
-
-Node NormalForm::constructBagFromElements(TypeNode t,
-                                          const std::map<Node, Node>& elements)
-{
-  Assert(t.isBag());
-  NodeManager* nm = NodeManager::currentNM();
-  if (elements.empty())
-  {
-    return nm->mkConst(EmptyBag(t));
-  }
-  TypeNode elementType = t.getBagElementType();
-  std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
-  Node bag = nm->mkBag(elementType, it->first, it->second);
-  while (++it != elements.rend())
-  {
-    Node n = nm->mkBag(elementType, it->first, it->second);
-    bag = nm->mkNode(BAG_UNION_DISJOINT, n, bag);
-  }
-  return bag;
-}
-
-Node NormalForm::evaluateMakeBag(TNode n)
-{
-  // the case where n is const should be handled earlier.
-  // here we handle the case where the multiplicity is zero or negative
-  Assert(n.getKind() == BAG_MAKE && !n.isConst()
-         && n[1].getConst<Rational>().sgn() < 1);
-  Node emptybag = NodeManager::currentNM()->mkConst(EmptyBag(n.getType()));
-  return emptybag;
-}
-
-Node NormalForm::evaluateBagCount(TNode n)
-{
-  Assert(n.getKind() == BAG_COUNT);
-  // Examples
-  // --------
-  // - (bag.count "x" (as bag.empty (Bag String))) = 0
-  // - (bag.count "x" (bag "y" 5)) = 0
-  // - (bag.count "x" (bag "x" 4)) = 4
-  // - (bag.count "x" (bag.union_disjoint (bag "x" 4) (bag "y" 5)) = 4
-  // - (bag.count "x" (bag.union_disjoint (bag "y" 5) (bag "z" 5)) = 0
-
-  std::map<Node, Rational> elements = getBagElements(n[1]);
-  std::map<Node, Rational>::iterator it = elements.find(n[0]);
-
-  NodeManager* nm = NodeManager::currentNM();
-  if (it != elements.end())
-  {
-    Node count = nm->mkConstInt(it->second);
-    return count;
-  }
-  return nm->mkConstInt(Rational(0));
-}
-
-Node NormalForm::evaluateDuplicateRemoval(TNode n)
-{
-  Assert(n.getKind() == BAG_DUPLICATE_REMOVAL);
-
-  // Examples
-  // --------
-  //  - (bag.duplicate_removal (as bag.empty (Bag String))) = (as bag.empty (Bag
-  //  String))
-  //  - (bag.duplicate_removal (bag "x" 4)) = (bag "x" 1)
-  //  - (bag.duplicate_removal (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
-  //     (bag.disjoint_union (bag "x" 1) (bag "y" 1)
-
-  std::map<Node, Rational> oldElements = getBagElements(n[0]);
-  // copy elements from the old bag
-  std::map<Node, Rational> newElements(oldElements);
-  Rational one = Rational(1);
-  std::map<Node, Rational>::iterator it;
-  for (it = newElements.begin(); it != newElements.end(); it++)
-  {
-    it->second = one;
-  }
-  Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
-  return bag;
-}
-
-Node NormalForm::evaluateUnionDisjoint(TNode n)
-{
-  Assert(n.getKind() == BAG_UNION_DISJOINT);
-  // Example
-  // -------
-  // input: (bag.union_disjoint A B)
-  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
-  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
-  // output:
-  //    (bag.union_disjoint A B)
-  //        where A = (bag "x" 7)
-  //              B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
-
-  auto equal = [](std::map<Node, Rational>& elements,
-                  std::map<Node, Rational>::const_iterator& itA,
-                  std::map<Node, Rational>::const_iterator& itB) {
-    // compute the sum of the multiplicities
-    elements[itA->first] = itA->second + itB->second;
-  };
-
-  auto less = [](std::map<Node, Rational>& elements,
-                 std::map<Node, Rational>::const_iterator& itA,
-                 std::map<Node, Rational>::const_iterator& itB) {
-    // add the element to the result
-    elements[itA->first] = itA->second;
-  };
-
-  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
-                           std::map<Node, Rational>::const_iterator& itA,
-                           std::map<Node, Rational>::const_iterator& itB) {
-    // add the element to the result
-    elements[itB->first] = itB->second;
-  };
-
-  auto remainderOfA = [](std::map<Node, Rational>& elements,
-                         std::map<Node, Rational>& elementsA,
-                         std::map<Node, Rational>::const_iterator& itA) {
-    // append the remainder of A
-    while (itA != elementsA.end())
-    {
-      elements[itA->first] = itA->second;
-      itA++;
-    }
-  };
-
-  auto remainderOfB = [](std::map<Node, Rational>& elements,
-                         std::map<Node, Rational>& elementsB,
-                         std::map<Node, Rational>::const_iterator& itB) {
-    // append the remainder of B
-    while (itB != elementsB.end())
-    {
-      elements[itB->first] = itB->second;
-      itB++;
-    }
-  };
-
-  return evaluateBinaryOperation(
-      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
-}
-
-Node NormalForm::evaluateUnionMax(TNode n)
-{
-  Assert(n.getKind() == BAG_UNION_MAX);
-  // Example
-  // -------
-  // input: (bag.union_max A B)
-  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
-  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
-  // output:
-  //    (bag.union_disjoint A B)
-  //        where A = (bag "x" 4)
-  //              B = (bag.union_disjoint (bag "y" 1) (bag "z" 2)))
-
-  auto equal = [](std::map<Node, Rational>& elements,
-                  std::map<Node, Rational>::const_iterator& itA,
-                  std::map<Node, Rational>::const_iterator& itB) {
-    // compute the maximum multiplicity
-    elements[itA->first] = std::max(itA->second, itB->second);
-  };
-
-  auto less = [](std::map<Node, Rational>& elements,
-                 std::map<Node, Rational>::const_iterator& itA,
-                 std::map<Node, Rational>::const_iterator& itB) {
-    // add to the result
-    elements[itA->first] = itA->second;
-  };
-
-  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
-                           std::map<Node, Rational>::const_iterator& itA,
-                           std::map<Node, Rational>::const_iterator& itB) {
-    // add to the result
-    elements[itB->first] = itB->second;
-  };
-
-  auto remainderOfA = [](std::map<Node, Rational>& elements,
-                         std::map<Node, Rational>& elementsA,
-                         std::map<Node, Rational>::const_iterator& itA) {
-    // append the remainder of A
-    while (itA != elementsA.end())
-    {
-      elements[itA->first] = itA->second;
-      itA++;
-    }
-  };
-
-  auto remainderOfB = [](std::map<Node, Rational>& elements,
-                         std::map<Node, Rational>& elementsB,
-                         std::map<Node, Rational>::const_iterator& itB) {
-    // append the remainder of B
-    while (itB != elementsB.end())
-    {
-      elements[itB->first] = itB->second;
-      itB++;
-    }
-  };
-
-  return evaluateBinaryOperation(
-      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
-}
-
-Node NormalForm::evaluateIntersectionMin(TNode n)
-{
-  Assert(n.getKind() == BAG_INTER_MIN);
-  // Example
-  // -------
-  // input: (bag.inter_min A B)
-  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
-  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
-  // output:
-  //        (bag "x" 3)
-
-  auto equal = [](std::map<Node, Rational>& elements,
-                  std::map<Node, Rational>::const_iterator& itA,
-                  std::map<Node, Rational>::const_iterator& itB) {
-    // compute the minimum multiplicity
-    elements[itA->first] = std::min(itA->second, itB->second);
-  };
-
-  auto less = [](std::map<Node, Rational>& elements,
-                 std::map<Node, Rational>::const_iterator& itA,
-                 std::map<Node, Rational>::const_iterator& itB) {
-    // do nothing
-  };
-
-  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
-                           std::map<Node, Rational>::const_iterator& itA,
-                           std::map<Node, Rational>::const_iterator& itB) {
-    // do nothing
-  };
-
-  auto remainderOfA = [](std::map<Node, Rational>& elements,
-                         std::map<Node, Rational>& elementsA,
-                         std::map<Node, Rational>::const_iterator& itA) {
-    // do nothing
-  };
-
-  auto remainderOfB = [](std::map<Node, Rational>& elements,
-                         std::map<Node, Rational>& elementsB,
-                         std::map<Node, Rational>::const_iterator& itB) {
-    // do nothing
-  };
-
-  return evaluateBinaryOperation(
-      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
-}
-
-Node NormalForm::evaluateDifferenceSubtract(TNode n)
-{
-  Assert(n.getKind() == BAG_DIFFERENCE_SUBTRACT);
-  // Example
-  // -------
-  // input: (bag.difference_subtract A B)
-  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
-  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
-  // output:
-  //    (bag.union_disjoint (bag "x" 1) (bag "z" 2))
-
-  auto equal = [](std::map<Node, Rational>& elements,
-                  std::map<Node, Rational>::const_iterator& itA,
-                  std::map<Node, Rational>::const_iterator& itB) {
-    // subtract the multiplicities
-    elements[itA->first] = itA->second - itB->second;
-  };
-
-  auto less = [](std::map<Node, Rational>& elements,
-                 std::map<Node, Rational>::const_iterator& itA,
-                 std::map<Node, Rational>::const_iterator& itB) {
-    // itA->first is not in B, so we add it to the difference subtract
-    elements[itA->first] = itA->second;
-  };
-
-  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
-                           std::map<Node, Rational>::const_iterator& itA,
-                           std::map<Node, Rational>::const_iterator& itB) {
-    // itB->first is not in A, so we just skip it
-  };
-
-  auto remainderOfA = [](std::map<Node, Rational>& elements,
-                         std::map<Node, Rational>& elementsA,
-                         std::map<Node, Rational>::const_iterator& itA) {
-    // append the remainder of A
-    while (itA != elementsA.end())
-    {
-      elements[itA->first] = itA->second;
-      itA++;
-    }
-  };
-
-  auto remainderOfB = [](std::map<Node, Rational>& elements,
-                         std::map<Node, Rational>& elementsB,
-                         std::map<Node, Rational>::const_iterator& itB) {
-    // do nothing
-  };
-
-  return evaluateBinaryOperation(
-      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
-}
-
-Node NormalForm::evaluateDifferenceRemove(TNode n)
-{
-  Assert(n.getKind() == BAG_DIFFERENCE_REMOVE);
-  // Example
-  // -------
-  // input: (bag.difference_remove A B)
-  //    where A = (bag.union_disjoint (bag "x" 4) (bag "z" 2)))
-  //          B = (bag.union_disjoint (bag "x" 3) (bag "y" 1)))
-  // output:
-  //    (bag "z" 2)
-
-  auto equal = [](std::map<Node, Rational>& elements,
-                  std::map<Node, Rational>::const_iterator& itA,
-                  std::map<Node, Rational>::const_iterator& itB) {
-    // skip the shared element by doing nothing
-  };
-
-  auto less = [](std::map<Node, Rational>& elements,
-                 std::map<Node, Rational>::const_iterator& itA,
-                 std::map<Node, Rational>::const_iterator& itB) {
-    // itA->first is not in B, so we add it to the difference remove
-    elements[itA->first] = itA->second;
-  };
-
-  auto greaterOrEqual = [](std::map<Node, Rational>& elements,
-                           std::map<Node, Rational>::const_iterator& itA,
-                           std::map<Node, Rational>::const_iterator& itB) {
-    // itB->first is not in A, so we just skip it
-  };
-
-  auto remainderOfA = [](std::map<Node, Rational>& elements,
-                         std::map<Node, Rational>& elementsA,
-                         std::map<Node, Rational>::const_iterator& itA) {
-    // append the remainder of A
-    while (itA != elementsA.end())
-    {
-      elements[itA->first] = itA->second;
-      itA++;
-    }
-  };
-
-  auto remainderOfB = [](std::map<Node, Rational>& elements,
-                         std::map<Node, Rational>& elementsB,
-                         std::map<Node, Rational>::const_iterator& itB) {
-    // do nothing
-  };
-
-  return evaluateBinaryOperation(
-      n, equal, less, greaterOrEqual, remainderOfA, remainderOfB);
-}
-
-Node NormalForm::evaluateChoose(TNode n)
-{
-  Assert(n.getKind() == BAG_CHOOSE);
-  // Examples
-  // --------
-  // - (bag.choose (bag "x" 4)) = "x"
-
-  if (n[0].getKind() == BAG_MAKE)
-  {
-    return n[0][0];
-  }
-  throw LogicException("BAG_CHOOSE_TOTAL is not supported yet");
-}
-
-Node NormalForm::evaluateCard(TNode n)
-{
-  Assert(n.getKind() == BAG_CARD);
-  // Examples
-  // --------
-  //  - (card (as bag.empty (Bag String))) = 0
-  //  - (bag.choose (bag "x" 4)) = 4
-  //  - (bag.choose (bag.union_disjoint (bag "x" 4) (bag "y" 1))) = 5
-
-  std::map<Node, Rational> elements = getBagElements(n[0]);
-  Rational sum(0);
-  for (std::pair<Node, Rational> element : elements)
-  {
-    sum += element.second;
-  }
-
-  NodeManager* nm = NodeManager::currentNM();
-  Node sumNode = nm->mkConstInt(sum);
-  return sumNode;
-}
-
-Node NormalForm::evaluateIsSingleton(TNode n)
-{
-  Assert(n.getKind() == BAG_IS_SINGLETON);
-  // Examples
-  // --------
-  // - (bag.is_singleton (as bag.empty (Bag String))) = false
-  // - (bag.is_singleton (bag "x" 1)) = true
-  // - (bag.is_singleton (bag "x" 4)) = false
-  // - (bag.is_singleton (bag.union_disjoint (bag "x" 1) (bag "y" 1)))
-  // = false
-
-  if (n[0].getKind() == BAG_MAKE && n[0][1].getConst<Rational>().isOne())
-  {
-    return NodeManager::currentNM()->mkConst(true);
-  }
-  return NodeManager::currentNM()->mkConst(false);
-}
-
-Node NormalForm::evaluateFromSet(TNode n)
-{
-  Assert(n.getKind() == BAG_FROM_SET);
-
-  // Examples
-  // --------
-  //  - (bag.from_set (as set.empty (Set String))) = (as bag.empty (Bag String))
-  //  - (bag.from_set (set.singleton "x")) = (bag "x" 1)
-  //  - (bag.from_set (set.union (set.singleton "x") (set.singleton "y"))) =
-  //     (bag.disjoint_union (bag "x" 1) (bag "y" 1))
-
-  NodeManager* nm = NodeManager::currentNM();
-  std::set<Node> setElements =
-      sets::NormalForm::getElementsFromNormalConstant(n[0]);
-  Rational one = Rational(1);
-  std::map<Node, Rational> bagElements;
-  for (const Node& element : setElements)
-  {
-    bagElements[element] = one;
-  }
-  TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
-  Node bag = constructConstantBagFromElements(bagType, bagElements);
-  return bag;
-}
-
-Node NormalForm::evaluateToSet(TNode n)
-{
-  Assert(n.getKind() == BAG_TO_SET);
-
-  // Examples
-  // --------
-  //  - (bag.to_set (as bag.empty (Bag String))) = (as set.empty (Set String))
-  //  - (bag.to_set (bag "x" 4)) = (set.singleton "x")
-  //  - (bag.to_set (bag.disjoint_union (bag "x" 3) (bag "y" 5)) =
-  //     (set.union (set.singleton "x") (set.singleton "y")))
-
-  NodeManager* nm = NodeManager::currentNM();
-  std::map<Node, Rational> bagElements = getBagElements(n[0]);
-  std::set<Node> setElements;
-  std::map<Node, Rational>::const_reverse_iterator it;
-  for (it = bagElements.rbegin(); it != bagElements.rend(); it++)
-  {
-    setElements.insert(it->first);
-  }
-  TypeNode setType = nm->mkSetType(n[0].getType().getBagElementType());
-  Node set = sets::NormalForm::elementsToSet(setElements, setType);
-  return set;
-}
-
-Node NormalForm::evaluateBagMap(TNode n)
-{
-  Assert(n.getKind() == BAG_MAP);
-
-  // Examples
-  // --------
-  // - (bag.map ((lambda ((x String)) "z")
-  //            (bag.union_disjoint (bag "a" 2) (bag "b" 3)) =
-  //     (bag.union_disjoint
-  //       (bag ((lambda ((x String)) "z") "a") 2)
-  //       (bag ((lambda ((x String)) "z") "b") 3)) =
-  //     (bag "z" 5)
-
-  std::map<Node, Rational> elements = NormalForm::getBagElements(n[1]);
-  std::map<Node, Rational> mappedElements;
-  std::map<Node, Rational>::iterator it = elements.begin();
-  NodeManager* nm = NodeManager::currentNM();
-  while (it != elements.end())
-  {
-    Node mappedElement = nm->mkNode(APPLY_UF, n[0], it->first);
-    mappedElements[mappedElement] = it->second;
-    ++it;
-  }
-  TypeNode t = nm->mkBagType(n[0].getType().getRangeType());
-  Node ret = NormalForm::constructConstantBagFromElements(t, mappedElements);
-  return ret;
-}
-
-Node NormalForm::evaluateBagFold(TNode n)
-{
-  Assert(n.getKind() == BAG_FOLD);
-
-  // Examples
-  // --------
-  // minimum string
-  // - (bag.fold
-  //     ((lambda ((x String) (y String)) (ite (str.< x y) x y))
-  //     ""
-  //     (bag.union_disjoint (bag "a" 2) (bag "b" 3))
-  //   = "a"
-
-  Node f = n[0];    // combining function
-  Node ret = n[1];  // initial value
-  Node A = n[2];    // bag
-  std::map<Node, Rational> elements = NormalForm::getBagElements(A);
-
-  std::map<Node, Rational>::iterator it = elements.begin();
-  NodeManager* nm = NodeManager::currentNM();
-  while (it != elements.end())
-  {
-    // apply the combination function n times, where n is the multiplicity
-    Rational count = it->second;
-    Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl;
-    while (!count.isZero())
-    {
-      ret = nm->mkNode(APPLY_UF, f, it->first, ret);
-      count = count - 1;
-    }
-    ++it;
-  }
-  return ret;
-}
-
-}  // namespace bags
-}  // namespace theory
-}  // namespace cvc5
diff --git a/src/theory/bags/normal_form.h b/src/theory/bags/normal_form.h
deleted file mode 100644 (file)
index 5275678..0000000
+++ /dev/null
@@ -1,210 +0,0 @@
-/******************************************************************************
- * Top contributors (to current version):
- *   Mudathir Mohamed
- *
- * This file is part of the cvc5 project.
- *
- * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
- * in the top-level source directory and their institutional affiliations.
- * All rights reserved.  See the file COPYING in the top-level source
- * directory for licensing information.
- * ****************************************************************************
- *
- * Normal form for bag constants.
- */
-
-#include <expr/node.h>
-
-#include "cvc5_private.h"
-
-#ifndef CVC5__THEORY__BAGS__NORMAL_FORM_H
-#define CVC5__THEORY__BAGS__NORMAL_FORM_H
-
-namespace cvc5 {
-namespace theory {
-namespace bags {
-
-class NormalForm
-{
- public:
-  /**
-   * Returns true if n is considered a to be a (canonical) constant bag value.
-   * A canonical bag value is one whose AST is:
-   *   (bag.union_disjoint (bag e1 c1) ...
-   *      (bag.union_disjoint (bag e_{n-1} c_{n-1}) (bag e_n c_n))))
-   * where c1 ... cn are positive integers, e1 ... en are constants, and the
-   * node identifier of these constants are such that: e1 < ... < en.
-   * Also handles the corner cases of empty bag and bag constructed by bag
-   */
-  static bool isConstant(TNode n);
-  /**
-   * check whether all children of the given node are constants
-   */
-  static bool areChildrenConstants(TNode n);
-  /**
-   * evaluate the node n to a constant value.
-   * As a precondition, children of n should be constants.
-   */
-  static Node evaluate(TNode n);
-
-  /**
-   * get the elements along with their multiplicities in a given bag
-   * @param n a constant node whose type is a bag
-   * @return a map whose keys are constant elements and values are
-   * multiplicities
-   */
-  static std::map<Node, Rational> getBagElements(TNode n);
-
-  /**
-   * construct a constant bag from constant elements
-   * @param t the type of the returned bag
-   * @param elements a map whose keys are constant elements and values are
-   *        multiplicities
-   * @return a constant bag that contains
-   */
-  static Node constructConstantBagFromElements(
-      TypeNode t, const std::map<Node, Rational>& elements);
-
-  /**
-   * construct a constant bag from node elements
-   * @param t the type of the returned bag
-   * @param elements a map whose keys are constant elements and values are
-   *        multiplicities
-   * @return a constant bag that contains
-   */
-  static Node constructBagFromElements(TypeNode t,
-                                       const std::map<Node, Node>& elements);
-
-  /**
-   * @param n has the form (bag.fold f t A) where A is a constant bag
-   * @return a single value which is the result of the fold
-   */
-  static Node evaluateBagFold(TNode n);
-
- private:
-  /**
-   * a high order helper function that return a constant bag that is the result
-   * of (op A B) where op is a binary operator and A, B are constant bags.
-   * The result is computed from the elements of A (elementsA with iterator itA)
-   * and elements of B (elementsB with iterator itB).
-   * The arguments below specify how these iterators are used to generate the
-   * elements of the result (elements).
-   * @param n a node whose kind is a binary operator (bag.union_disjoint,
-   * union_max, intersection_min, difference_subtract, difference_remove) and
-   * whose children are constant bags.
-   * @param equal a lambda expression that receives (elements, itA, itB) and
-   * specify the action that needs to be taken when the elements of itA, itB are
-   * equal.
-   * @param less a lambda expression that receives (elements, itA, itB) and
-   * specify the action that needs to be taken when the element itA is less than
-   * the element of itB.
-   * @param greaterOrEqual less a lambda expression that receives (elements,
-   * itA, itB) and specify the action that needs to be taken when the element
-   * itA is greater than or equal than the element of itB.
-   * @param remainderOfA a lambda expression that receives (elements, elementsA,
-   * itA) and specify the action that needs to be taken to the remaining
-   * elements of A when all elements of B are visited.
-   * @param remainderOfB a lambda expression that receives (elements, elementsB,
-   * itB) and specify the action that needs to be taken to the remaining
-   * elements of B when all elements of A are visited.
-   * @return a constant bag that the result of (op n[0] n[1])
-   */
-  template <typename T1, typename T2, typename T3, typename T4, typename T5>
-  static Node evaluateBinaryOperation(const TNode& n,
-                                      T1&& equal,
-                                      T2&& less,
-                                      T3&& greaterOrEqual,
-                                      T4&& remainderOfA,
-                                      T5&& remainderOfB);
-  /**
-   * evaluate n as follows:
-   * - (bag a 0) = (as bag.empty T) where T is the type of the original bag
-   * - (bag a (-c)) = (as bag.empty T) where T is the type the original bag,
-   *                                and c > 0 is a constant
-   */
-  static Node evaluateMakeBag(TNode n);
-
-  /**
-   * returns the multiplicity in a constant bag
-   * @param n has the form (bag.count x A) where x, A are constants
-   * @return the multiplicity of element x in bag A.
-   */
-  static Node evaluateBagCount(TNode n);
-
-  /**
-   * @param n has the form (bag.duplicate_removal A) where A is a constant bag
-   * @return a constant bag constructed from the elements in A where each
-   * element has multiplicity one
-   */
-  static Node evaluateDuplicateRemoval(TNode n);
-
-  /**
-   * evaluates union disjoint node such that the returned node is a canonical
-   * bag that has the form
-   * (bag.union_disjoint (bag e1 c1) ...
-   *   (bag.union_disjoint  * (bag e_{n-1} c_{n-1}) (bag e_n c_n)))) where
-   *   c1... cn are positive integers, e1 ... en are constants, and the node
-   * identifier of these constants are such that: e1 < ... < en.
-   * @param n has the form (bag.union_disjoint A B) where A, B are constant bags
-   * @return the union disjoint of A and B
-   */
-  static Node evaluateUnionDisjoint(TNode n);
-  /**
-   * @param n has the form (bag.union_max A B) where A, B are constant bags
-   * @return the union max of A and B
-   */
-  static Node evaluateUnionMax(TNode n);
-  /**
-   * @param n has the form (bag.inter_min A B) where A, B are constant bags
-   * @return the intersection min of A and B
-   */
-  static Node evaluateIntersectionMin(TNode n);
-  /**
-   * @param n has the form (bag.difference_subtract A B) where A, B are constant
-   * bags
-   * @return the difference subtract of A and B
-   */
-  static Node evaluateDifferenceSubtract(TNode n);
-  /**
-   * @param n has the form (bag.difference_remove A B) where A, B are constant
-   * bags
-   * @return the difference remove of A and B
-   */
-  static Node evaluateDifferenceRemove(TNode n);
-  /**
-   * @param n has the form (bag.choose A) where A is a constant bag
-   * @return x if n has the form (bag.choose (bag x c)). Otherwise an error is
-   * thrown.
-   */
-  static Node evaluateChoose(TNode n);
-  /**
-   * @param n has the form (bag.card A) where A is a constant bag
-   * @return the number of elements in bag A
-   */
-  static Node evaluateCard(TNode n);
-  /**
-   * @param n has the form (bag.is_singleton A) where A is a constant bag
-   * @return whether the bag A has cardinality one.
-   */
-  static Node evaluateIsSingleton(TNode n);
-  /**
-   * @param n has the form (bag.from_set A) where A is a constant set
-   * @return a constant bag that contains exactly the elements in A.
-   */
-  static Node evaluateFromSet(TNode n);
-  /**
-   * @param n has the form (bag.to_set A) where A is a constant bag
-   * @return a constant set constructed from the elements in A.
-   */
-  static Node evaluateToSet(TNode n);
-  /**
-   * @param n has the form (bag.map f A) where A is a constant bag
-   * @return a constant bag constructed from the images of elements in A.
-   */
-  static Node evaluateBagMap(TNode n);
-};
-}  // namespace bags
-}  // namespace theory
-}  // namespace cvc5
-
-#endif /* CVC5__THEORY__BAGS__NORMAL_FORM_H */
index d8ed9fb959df82e514587a4f880ac7aa5627c599..9bd0c3a86ca7afe27786448d10a23b4ff2128b29 100644 (file)
@@ -38,6 +38,9 @@ const char* toString(Rewrite r)
     case Rewrite::EQ_CONST_FALSE: return "EQ_CONST_FALSE";
     case Rewrite::EQ_REFL: return "EQ_REFL";
     case Rewrite::EQ_SYM: return "EQ_SYM";
+    case Rewrite::FILTER_CONST: return "FILTER_CONST";
+    case Rewrite::FILTER_BAG_MAKE: return "FILTER_BAG_MAKE";
+    case Rewrite::FILTER_UNION_DISJOINT: return "FILTER_UNION_DISJOINT";
     case Rewrite::FROM_SINGLETON: return "FROM_SINGLETON";
     case Rewrite::FOLD_BAG: return "FOLD_BAG";
     case Rewrite::FOLD_CONST: return "FOLD_CONST";
index 57f10621114aa41ec5172014ab08abca5af8798f..e1ef38c4b14e9e695282cc38a8931460f3f1727c 100644 (file)
@@ -42,6 +42,9 @@ enum class Rewrite : uint32_t
   EQ_CONST_FALSE,
   EQ_REFL,
   EQ_SYM,
+  FILTER_CONST,
+  FILTER_BAG_MAKE,
+  FILTER_UNION_DISJOINT,
   FROM_SINGLETON,
   FOLD_BAG,
   FOLD_CONST,
index 720e97c25dfff90573a36b5444908991e288bbde..37b6415e0dca1d9c670ab971f2ef834e9f5537ac 100644 (file)
@@ -19,7 +19,7 @@
 #include "expr/skolem_manager.h"
 #include "proof/proof_checker.h"
 #include "smt/logic_exception.h"
-#include "theory/bags/normal_form.h"
+#include "theory/bags/bags_utils.h"
 #include "theory/quantifiers/fmf/bounded_integers.h"
 #include "theory/rewriter.h"
 #include "theory/theory_model.h"
@@ -321,7 +321,7 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
       Node value = m->getRepresentative(countSkolem);
       elementReps[key] = value;
     }
-    Node constructedBag = NormalForm::constructBagFromElements(tn, elementReps);
+    Node constructedBag = BagsUtils::constructBagFromElements(tn, elementReps);
     constructedBag = rewrite(constructedBag);
     Trace("bags-model") << "constructed bag for " << n
                         << " is: " << constructedBag << std::endl;
@@ -352,7 +352,8 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
           if (constructedRational < rCardRational
               && !d_env.isFiniteType(elementType))
           {
-            Node newElement = nm->getSkolemManager()->mkDummySkolem("slack", elementType);
+            Node newElement =
+                nm->getSkolemManager()->mkDummySkolem("slack", elementType);
             Trace("bags-model") << "newElement is " << newElement << std::endl;
             Rational difference = rCardRational - constructedRational;
             Node multiplicity = nm->mkConst(CONST_RATIONAL, difference);
index 14fca3297dc0df6bdb4ac3ac40ffbb372f205d50..a24981934bc96467715f2a0db3d3c155487852bf 100644 (file)
@@ -16,7 +16,7 @@
 #include "theory/bags/theory_bags_type_enumerator.h"
 
 #include "expr/emptybag.h"
-#include "theory/bags/normal_form.h"
+#include "theory/bags/bags_utils.h"
 #include "theory_bags_type_enumerator.h"
 #include "util/rational.h"
 
@@ -67,11 +67,10 @@ BagEnumerator& BagEnumerator::operator++()
   else
   {
     // increase the multiplicity of one of the elements in the current bag
-    std::map<Node, Rational> elements =
-        NormalForm::getBagElements(d_currentBag);
+    std::map<Node, Rational> elements = BagsUtils::getBagElements(d_currentBag);
     Node element = elements.begin()->first;
     elements[element] = elements[element] + Rational(1);
-    d_currentBag = NormalForm::constructConstantBagFromElements(
+    d_currentBag = BagsUtils::constructConstantBagFromElements(
         d_currentBag.getType(), elements);
   }
 
index 2d218f8218af1665f4aeec8c86c93b9c72e61c82..689b0e208cd2644dfdb15c4a5ccdd059c07af520 100644 (file)
@@ -20,7 +20,7 @@
 #include "base/check.h"
 #include "expr/emptybag.h"
 #include "theory/bags/bag_make_op.h"
-#include "theory/bags/normal_form.h"
+#include "theory/bags/bags_utils.h"
 #include "util/cardinality.h"
 #include "util/rational.h"
 
@@ -63,7 +63,7 @@ bool BinaryOperatorTypeRule::computeIsConst(NodeManager* nodeManager, TNode n)
   // only UNION_DISJOINT has a const rule in kinds.
   // Other binary operators do not have const rules in kinds
   Assert(n.getKind() == kind::BAG_UNION_DISJOINT);
-  return NormalForm::isConstant(n);
+  return BagsUtils::isConstant(n);
 }
 
 TypeNode SubBagTypeRule::computeType(NodeManager* nodeManager,
@@ -356,6 +356,48 @@ TypeNode BagMapTypeRule::computeType(NodeManager* nodeManager,
   return retType;
 }
 
+TypeNode BagFilterTypeRule::computeType(NodeManager* nodeManager,
+                                        TNode n,
+                                        bool check)
+{
+  Assert(n.getKind() == kind::BAG_FILTER);
+  TypeNode functionType = n[0].getType(check);
+  TypeNode bagType = n[1].getType(check);
+  if (check)
+  {
+    if (!bagType.isBag())
+    {
+      throw TypeCheckingExceptionPrivate(
+          n,
+          "bag.filter operator expects a bag in the second argument, "
+          "a non-bag is found");
+    }
+
+    TypeNode elementType = bagType.getBagElementType();
+
+    if (!(functionType.isFunction()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " Bool) as a first argument. "
+         << "Found a term of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    std::vector<TypeNode> argTypes = functionType.getArgTypes();
+    NodeManager* nm = NodeManager::currentNM();
+    if (!(argTypes.size() == 1 && argTypes[0] == elementType
+          && functionType.getRangeType() == nm->booleanType()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " Bool). "
+         << "Found a function of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+  }
+  return bagType;
+}
+
 TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager,
                                       TNode n,
                                       bool check)
index da9ea75bfe3c221bf9e146cf9863f0583fcd0816..76c179a62fb9dd9a4062c8ee659f146efe9b82f6 100644 (file)
@@ -141,6 +141,15 @@ struct BagMapTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct BagMapTypeRule */
 
+/**
+ * Type rule for (bag.filter p B) to make sure p is a unary predicate of type
+ * (-> T Bool) where B is a bag of type (Bag T)
+ */
+struct BagFilterTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct BagFilterTypeRule */
+
 /**
  * Type rule for (bag.fold f t A) to make sure f is a binary operation of type
  * (-> T1 T2 T2), t of type T2, and B is a bag of type (Bag T1)
index d7e5fccbe3ee978177d2e3267084a3b2c4104a31..240d6e29374f3997e6752d689c646546d443f44b 100644 (file)
@@ -120,6 +120,8 @@ const char* toString(InferenceId i)
     case InferenceId::BAGS_DIFFERENCE_REMOVE: return "BAGS_DIFFERENCE_REMOVE";
     case InferenceId::BAGS_DUPLICATE_REMOVAL: return "BAGS_DUPLICATE_REMOVAL";
     case InferenceId::BAGS_MAP: return "BAGS_MAP";
+    case InferenceId::BAGS_FILTER_DOWN: return "BAGS_FILTER_DOWN";
+    case InferenceId::BAGS_FILTER_UP: return "BAGS_FILTER_UP";
     case InferenceId::BAGS_FOLD: return "BAGS_FOLD";
     case InferenceId::BAGS_CARD: return "BAGS_CARD";
 
index 4970c1ceedb884934a01e15210620e105692a436..2fb3ae003b380e207aff2a7f4e3d41ee4fd6d9fd 100644 (file)
@@ -182,6 +182,8 @@ enum class InferenceId
   BAGS_DIFFERENCE_REMOVE,
   BAGS_DUPLICATE_REMOVAL,
   BAGS_MAP,
+  BAGS_FILTER_DOWN,
+  BAGS_FILTER_UP,
   BAGS_FOLD,
   BAGS_CARD,
   // ---------------------------------- end bags theory
index c86be3e7663817cafe4826de3cea4686c5f8ec47..ec3b13caa0f9497119c7692a8d07d06fd164ec4d 100644 (file)
@@ -1653,6 +1653,11 @@ set(regress_1_tests
   regress1/bags/duplicate_removal1.smt2
   regress1/bags/duplicate_removal2.smt2
   regress1/bags/emptybag1.smt2
+  regress1/bags/filter1.smt2
+  regress1/bags/filter2.smt2
+  regress1/bags/filter3.smt2
+  regress1/bags/filter4.smt2
+  regress1/bags/filter5.smt2
   regress1/bags/fol_0000119.smt2
   regress1/bags/fold1.smt2
   regress1/bags/fuzzy1.smt2
diff --git a/test/regress/regress1/bags/filter1.smt2 b/test/regress/regress1/bags/filter1.smt2
new file mode 100644 (file)
index 0000000..65e87c1
--- /dev/null
@@ -0,0 +1,11 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(declare-fun p (Int) Bool)
+(assert (= A (bag.union_max (bag x 1) (bag y 2))))
+(assert (= B (bag.filter p A)))
+(assert (distinct (p x) (p y)))
+(check-sat)
diff --git a/test/regress/regress1/bags/filter2.smt2 b/test/regress/regress1/bags/filter2.smt2
new file mode 100644 (file)
index 0000000..62b6403
--- /dev/null
@@ -0,0 +1,9 @@
+(set-logic HO_ALL)
+(set-info :status sat)
+(set-option :fmf-bound true)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun p (Int) Bool)
+(assert (= B (bag.filter p A)))
+(assert (= (bag.count (- 2) B) 57))
+(check-sat)
diff --git a/test/regress/regress1/bags/filter3.smt2 b/test/regress/regress1/bags/filter3.smt2
new file mode 100644 (file)
index 0000000..10f6370
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic HO_ALL)
+(set-info :status unsat)
+(set-option :fmf-bound true)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(define-fun p ((x Int)) Bool (> x 1))
+(assert (= B (bag.filter p A)))
+(assert (= (bag.count 3 B) 57))
+(assert (= (bag.count 3 B) 58))
+(check-sat)
diff --git a/test/regress/regress1/bags/filter4.smt2 b/test/regress/regress1/bags/filter4.smt2
new file mode 100644 (file)
index 0000000..9be6952
--- /dev/null
@@ -0,0 +1,11 @@
+(set-logic HO_ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun element () Int)
+(declare-fun p (Int) Bool)
+(assert (= B (bag.filter p A)))
+(assert (p element))
+(assert (not (bag.member element B)))
+(assert (bag.member element A))
+(check-sat)
diff --git a/test/regress/regress1/bags/filter5.smt2 b/test/regress/regress1/bags/filter5.smt2
new file mode 100644 (file)
index 0000000..74ca054
--- /dev/null
@@ -0,0 +1,11 @@
+(set-logic HO_ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun element () Int)
+(declare-fun p (Int) Bool)
+(assert (= B (bag.filter p A)))
+(assert (p element))
+(assert (not (bag.member element A)))
+(assert (bag.member element B))
+(check-sat)
index c7dc3d63666ab45942e5b1b930231ac58e099967..748d327ddc44a10981ae5d1a2caf46809a9d11cc 100644 (file)
@@ -6,7 +6,6 @@
 (declare-fun y () Int)
 (declare-fun f (Int) Int)
 (assert (= A (bag.union_max (bag x 1) (bag y 2))))
-(assert (= A (bag.union_max (bag x 1) (bag y 2))))
 (assert (= B (bag.map f A)))
 (assert (distinct (f x) (f y) x y))
 (check-sat)
index 5f3abfceecb8e005a265fdc74d0a5f58317af951..4c8c41f0bed8b33936c5c6bf143d43b829c18081 100644 (file)
@@ -18,7 +18,7 @@
 #include "expr/emptyset.h"
 #include "test_smt.h"
 #include "theory/bags/bags_rewriter.h"
-#include "theory/bags/normal_form.h"
+#include "theory/bags/bags_utils.h"
 #include "theory/strings/type_enumerator.h"
 #include "util/rational.h"
 #include "util/string.h"
@@ -65,7 +65,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, empty_bag_normal_form)
   Node emptybag = d_nodeManager->mkConst(EmptyBag(d_nodeManager->stringType()));
   // empty bags are in normal form
   ASSERT_TRUE(emptybag.isConst());
-  Node n = NormalForm::evaluate(emptybag);
+  Node n = BagsUtils::evaluate(emptybag);
   ASSERT_EQ(emptybag, n);
 }
 
@@ -89,9 +89,9 @@ TEST_F(TestTheoryWhiteBagsNormalForm, mkBag_constant_element)
 
   ASSERT_FALSE(negative.isConst());
   ASSERT_FALSE(zero.isConst());
-  ASSERT_EQ(emptybag, NormalForm::evaluate(negative));
-  ASSERT_EQ(emptybag, NormalForm::evaluate(zero));
-  ASSERT_EQ(positive, NormalForm::evaluate(positive));
+  ASSERT_EQ(emptybag, BagsUtils::evaluate(negative));
+  ASSERT_EQ(emptybag, BagsUtils::evaluate(zero));
+  ASSERT_EQ(positive, BagsUtils::evaluate(positive));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, bag_count)
@@ -126,25 +126,25 @@ TEST_F(TestTheoryWhiteBagsNormalForm, bag_count)
 
   Node input1 = d_nodeManager->mkNode(BAG_COUNT, x, empty);
   Node output1 = zero;
-  ASSERT_EQ(output1, NormalForm::evaluate(input1));
+  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
 
   Node input2 = d_nodeManager->mkNode(BAG_COUNT, x, y_5);
   Node output2 = zero;
-  ASSERT_EQ(output2, NormalForm::evaluate(input2));
+  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
 
   Node input3 = d_nodeManager->mkNode(BAG_COUNT, x, x_4);
   Node output3 = four;
-  ASSERT_EQ(output2, NormalForm::evaluate(input2));
+  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
 
   Node unionDisjointXY = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
   Node input4 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointXY);
   Node output4 = four;
-  ASSERT_EQ(output3, NormalForm::evaluate(input3));
+  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
 
   Node unionDisjointYZ = d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_5, z_5);
   Node input5 = d_nodeManager->mkNode(BAG_COUNT, x, unionDisjointYZ);
   Node output5 = zero;
-  ASSERT_EQ(output4, NormalForm::evaluate(input4));
+  ASSERT_EQ(output4, BagsUtils::evaluate(input4));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal)
@@ -161,7 +161,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal)
       EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
   Node input1 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, emptybag);
   Node output1 = emptybag;
-  ASSERT_EQ(output1, NormalForm::evaluate(input1));
+  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
 
   Node x = d_nodeManager->mkConst(String("x"));
   Node y = d_nodeManager->mkConst(String("y"));
@@ -186,12 +186,12 @@ TEST_F(TestTheoryWhiteBagsNormalForm, duplicate_removal)
 
   Node input2 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, x_4);
   Node output2 = x_1;
-  ASSERT_EQ(output2, NormalForm::evaluate(input2));
+  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
 
   Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
   Node input3 = d_nodeManager->mkNode(BAG_DUPLICATE_REMOVAL, normalBag);
   Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
-  ASSERT_EQ(output3, NormalForm::evaluate(input3));
+  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, union_max)
@@ -241,7 +241,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_max)
       d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2));
 
   ASSERT_TRUE(output.isConst());
-  ASSERT_EQ(output, NormalForm::evaluate(input));
+  ASSERT_EQ(output, BagsUtils::evaluate(input));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1)
@@ -265,12 +265,12 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1)
   Node unionDisjointAB = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, B);
   // unionDisjointAB is already in a normal form
   ASSERT_TRUE(unionDisjointAB.isConst());
-  ASSERT_EQ(unionDisjointAB, NormalForm::evaluate(unionDisjointAB));
+  ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointAB));
 
   Node unionDisjointBA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, B, A);
   // unionDisjointAB is the normal form of unionDisjointBA
   ASSERT_FALSE(unionDisjointBA.isConst());
-  ASSERT_EQ(unionDisjointAB, NormalForm::evaluate(unionDisjointBA));
+  ASSERT_EQ(unionDisjointAB, BagsUtils::evaluate(unionDisjointBA));
 
   Node unionDisjointAB_C =
       d_nodeManager->mkNode(BAG_UNION_DISJOINT, unionDisjointAB, C);
@@ -280,7 +280,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1)
   // unionDisjointA_BC is the normal form of unionDisjointAB_C
   ASSERT_FALSE(unionDisjointAB_C.isConst());
   ASSERT_TRUE(unionDisjointA_BC.isConst());
-  ASSERT_EQ(unionDisjointA_BC, NormalForm::evaluate(unionDisjointAB_C));
+  ASSERT_EQ(unionDisjointA_BC, BagsUtils::evaluate(unionDisjointAB_C));
 
   Node unionDisjointAA = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, A);
   Node AA =
@@ -289,7 +289,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint1)
                            d_nodeManager->mkConst(CONST_RATIONAL, Rational(4)));
   ASSERT_FALSE(unionDisjointAA.isConst());
   ASSERT_TRUE(AA.isConst());
-  ASSERT_EQ(AA, NormalForm::evaluate(unionDisjointAA));
+  ASSERT_EQ(AA, BagsUtils::evaluate(unionDisjointAA));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2)
@@ -339,7 +339,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, union_disjoint2)
       d_nodeManager->mkNode(BAG_UNION_DISJOINT, y_1, z_2));
 
   ASSERT_TRUE(output.isConst());
-  ASSERT_EQ(output, NormalForm::evaluate(input));
+  ASSERT_EQ(output, BagsUtils::evaluate(input));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min)
@@ -384,7 +384,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, intersection_min)
   Node output = x_3;
 
   ASSERT_TRUE(output.isConst());
-  ASSERT_EQ(output, NormalForm::evaluate(input));
+  ASSERT_EQ(output, BagsUtils::evaluate(input));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract)
@@ -433,7 +433,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_subtract)
   Node output = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, z_2);
 
   ASSERT_TRUE(output.isConst());
-  ASSERT_EQ(output, NormalForm::evaluate(input));
+  ASSERT_EQ(output, BagsUtils::evaluate(input));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove)
@@ -482,7 +482,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove)
   Node output = z_2;
 
   ASSERT_TRUE(output.isConst());
-  ASSERT_EQ(output, NormalForm::evaluate(input));
+  ASSERT_EQ(output, BagsUtils::evaluate(input));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, bag_card)
@@ -509,16 +509,16 @@ TEST_F(TestTheoryWhiteBagsNormalForm, bag_card)
   Node input1 = d_nodeManager->mkNode(BAG_CARD, empty);
   Node output1 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(0));
 
-  ASSERT_EQ(output1, NormalForm::evaluate(input1));
+  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
 
   Node input2 = d_nodeManager->mkNode(BAG_CARD, x_4);
   Node output2 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(4));
-  ASSERT_EQ(output2, NormalForm::evaluate(input2));
+  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
 
   Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_1);
   Node input3 = d_nodeManager->mkNode(BAG_CARD, union_disjoint);
   Node output3 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(5));
-  ASSERT_EQ(output3, NormalForm::evaluate(input3));
+  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton)
@@ -552,20 +552,20 @@ TEST_F(TestTheoryWhiteBagsNormalForm, is_singleton)
 
   Node input1 = d_nodeManager->mkNode(BAG_IS_SINGLETON, empty);
   Node output1 = falseNode;
-  ASSERT_EQ(output1, NormalForm::evaluate(input1));
+  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
 
   Node input2 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_1);
   Node output2 = trueNode;
-  ASSERT_EQ(output2, NormalForm::evaluate(input2));
+  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
 
   Node input3 = d_nodeManager->mkNode(BAG_IS_SINGLETON, x_4);
   Node output3 = falseNode;
-  ASSERT_EQ(output2, NormalForm::evaluate(input2));
+  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
 
   Node union_disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
   Node input4 = d_nodeManager->mkNode(BAG_IS_SINGLETON, union_disjoint);
   Node output4 = falseNode;
-  ASSERT_EQ(output3, NormalForm::evaluate(input3));
+  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, from_set)
@@ -583,7 +583,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, from_set)
       EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
   Node input1 = d_nodeManager->mkNode(BAG_FROM_SET, emptyset);
   Node output1 = emptybag;
-  ASSERT_EQ(output1, NormalForm::evaluate(input1));
+  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
 
   Node x = d_nodeManager->mkConst(String("x"));
   Node y = d_nodeManager->mkConst(String("y"));
@@ -602,13 +602,13 @@ TEST_F(TestTheoryWhiteBagsNormalForm, from_set)
 
   Node input2 = d_nodeManager->mkNode(BAG_FROM_SET, xSingleton);
   Node output2 = x_1;
-  ASSERT_EQ(output2, NormalForm::evaluate(input2));
+  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
 
   // for normal sets, the first node is the largest, not smallest
   Node normalSet = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton);
   Node input3 = d_nodeManager->mkNode(BAG_FROM_SET, normalSet);
   Node output3 = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_1, y_1);
-  ASSERT_EQ(output3, NormalForm::evaluate(input3));
+  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
 }
 
 TEST_F(TestTheoryWhiteBagsNormalForm, to_set)
@@ -626,7 +626,7 @@ TEST_F(TestTheoryWhiteBagsNormalForm, to_set)
       EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
   Node input1 = d_nodeManager->mkNode(BAG_TO_SET, emptybag);
   Node output1 = emptyset;
-  ASSERT_EQ(output1, NormalForm::evaluate(input1));
+  ASSERT_EQ(output1, BagsUtils::evaluate(input1));
 
   Node x = d_nodeManager->mkConst(String("x"));
   Node y = d_nodeManager->mkConst(String("y"));
@@ -645,13 +645,13 @@ TEST_F(TestTheoryWhiteBagsNormalForm, to_set)
 
   Node input2 = d_nodeManager->mkNode(BAG_TO_SET, x_4);
   Node output2 = xSingleton;
-  ASSERT_EQ(output2, NormalForm::evaluate(input2));
+  ASSERT_EQ(output2, BagsUtils::evaluate(input2));
 
   // for normal sets, the first node is the largest, not smallest
   Node normalBag = d_nodeManager->mkNode(BAG_UNION_DISJOINT, x_4, y_5);
   Node input3 = d_nodeManager->mkNode(BAG_TO_SET, normalBag);
   Node output3 = d_nodeManager->mkNode(SET_UNION, ySingleton, xSingleton);
-  ASSERT_EQ(output3, NormalForm::evaluate(input3));
+  ASSERT_EQ(output3, BagsUtils::evaluate(input3));
 }
 }  // namespace test
 }  // namespace cvc5