add bag.fold operator (#7718)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Thu, 2 Dec 2021 02:12:30 +0000 (20:12 -0600)
committerGitHub <noreply@github.com>
Thu, 2 Dec 2021 02:12:30 +0000 (02:12 +0000)
26 files changed:
src/CMakeLists.txt
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5_kind.h
src/expr/skolem_manager.cpp
src/expr/skolem_manager.h
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bag_reduction.cpp [new file with mode: 0644]
src/theory/bags/bag_reduction.h [new file with mode: 0644]
src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
src/theory/bags/kinds
src/theory/bags/normal_form.cpp
src/theory/bags/normal_form.h
src/theory/bags/rewrites.cpp
src/theory/bags/rewrites.h
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags.h
src/theory/bags/theory_bags_type_rules.cpp
src/theory/bags/theory_bags_type_rules.h
src/theory/inference_id.cpp
src/theory/inference_id.h
test/regress/CMakeLists.txt
test/regress/regress1/bags/fold1.smt2 [new file with mode: 0644]
test/regress/regress1/bags/fold2.smt2 [new file with mode: 0644]
test/unit/theory/theory_bags_rewriter_white.cpp

index 025f499e6952f8bc69ae3b1cd93ee18b73f87f25..96de9afeb70de04a55fbe45ae1697b61dd86c79f 100644 (file)
@@ -535,6 +535,8 @@ libcvc5_add_sources(
   theory/bags/bags_rewriter.h
   theory/bags/bag_solver.cpp
   theory/bags/bag_solver.h
+  theory/bags/bag_reduction.cpp
+  theory/bags/bag_reduction.h
   theory/bags/bags_statistics.cpp
   theory/bags/bags_statistics.h
   theory/bags/infer_info.cpp
index 6129ff89130ca7e9f6656164a297a3abef18428e..c62dde511631fc6afc5c6644f57eb83b766704bb 100644 (file)
@@ -313,6 +313,7 @@ const static std::unordered_map<Kind, cvc5::Kind> s_kinds{
     {BAG_FROM_SET, cvc5::Kind::BAG_FROM_SET},
     {BAG_TO_SET, cvc5::Kind::BAG_TO_SET},
     {BAG_MAP, cvc5::Kind::BAG_MAP},
+    {BAG_FOLD, cvc5::Kind::BAG_FOLD},
     /* Strings ------------------------------------------------------------- */
     {STRING_CONCAT, cvc5::Kind::STRING_CONCAT},
     {STRING_IN_REGEXP, cvc5::Kind::STRING_IN_REGEXP},
@@ -624,6 +625,7 @@ const static std::unordered_map<cvc5::Kind, Kind, cvc5::kind::KindHashFunction>
         {cvc5::Kind::BAG_FROM_SET, BAG_FROM_SET},
         {cvc5::Kind::BAG_TO_SET, BAG_TO_SET},
         {cvc5::Kind::BAG_MAP, BAG_MAP},
+        {cvc5::Kind::BAG_FOLD, BAG_FOLD},
         /* Strings --------------------------------------------------------- */
         {cvc5::Kind::STRING_CONCAT, STRING_CONCAT},
         {cvc5::Kind::STRING_IN_REGEXP, STRING_IN_REGEXP},
index e6a03cbe445826ef6c583a40be106ad0d07cffcd..73843f9b54a64b75c49abbbabc77addb65b9589e 100644 (file)
@@ -2539,6 +2539,22 @@ enum Kind : int32_t
    *   - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
    */
   BAG_MAP,
+  /**
+   * bag.fold operator combines elements of a bag into a single value.
+   * (bag.fold f t B) folds the elements of bag B starting with term t and using
+   * the combining function f.
+   *
+   * Parameters:
+   *   - 1: a binary operation of type (-> T1 T2 T2)
+   *   - 2: an initial value of type T2
+   *   - 2: a bag of type (Bag T1)
+   *
+   * Create with:
+   *   - `Solver::mkTerm(Kind kind, const Term& child1, const Term& child2,
+   * const Term& child3) const`
+   *   - `Solver::mkTerm(Kind kind, const std::vector<Term>& children) const`
+   */
+  BAG_FOLD,
 
   /* Strings --------------------------------------------------------------- */
 
index db976559f12d137572bcf191535846163f1500a9..47651782084e4ffd3aca2c8d77c4037cfcd3e569 100644 (file)
@@ -68,6 +68,10 @@ const char* toString(SkolemFunId id)
     case SkolemFunId::SK_FIRST_MATCH_POST: return "SK_FIRST_MATCH_POST";
     case SkolemFunId::RE_UNFOLD_POS_COMPONENT: return "RE_UNFOLD_POS_COMPONENT";
     case SkolemFunId::BAGS_CHOOSE: return "BAGS_CHOOSE";
+    case SkolemFunId::BAGS_FOLD_CARD: return "BAGS_FOLD_CARD";
+    case SkolemFunId::BAGS_FOLD_COMBINE: return "BAGS_FOLD_COMBINE";
+    case SkolemFunId::BAGS_FOLD_ELEMENTS: return "BAGS_FOLD_ELEMENTS";
+    case SkolemFunId::BAGS_FOLD_UNION_DISJOINT: return "BAGS_FOLD_UNION_DISJOINT";
     case SkolemFunId::BAGS_MAP_PREIMAGE: return "BAGS_MAP_PREIMAGE";
     case SkolemFunId::BAGS_MAP_SUM: return "BAGS_MAP_SUM";
     case SkolemFunId::HO_TYPE_MATCH_PRED: return "HO_TYPE_MATCH_PRED";
index a18de8a2e3ad570b1edfc5cbf769365ae6dcc50c..780413d17976cf08244b043f024ed3f6e157cb7b 100644 (file)
@@ -112,6 +112,10 @@ enum class SkolemFunId
    * i = 0, ..., n.
    */
   RE_UNFOLD_POS_COMPONENT,
+  BAGS_FOLD_CARD,
+  BAGS_FOLD_COMBINE,
+  BAGS_FOLD_ELEMENTS,
+  BAGS_FOLD_UNION_DISJOINT,
   /** An interpreted function for bag.choose operator:
    * (bag.choose A) is expanded as
    * (witness ((x elementType))
index ad380a31c730eec8320963f885d4f3699c19d262..4e1a8aae8486d80fd19a8b0201ae42eb37231bf0 100644 (file)
@@ -629,6 +629,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(api::BAG_FROM_SET, "bag.from_set");
     addOperator(api::BAG_TO_SET, "bag.to_set");
     addOperator(api::BAG_MAP, "bag.map");
+    addOperator(api::BAG_FOLD, "bag.fold");
   }
   if(d_logic.isTheoryEnabled(theory::THEORY_STRINGS)) {
     defineType("String", d_solver->getStringSort(), true, true);
index 13477b7923449a9912242c8323f585c805d6962a..875ca7dc25acce93c99a00eed534ca8c69d1132f 100644 (file)
@@ -1098,6 +1098,7 @@ std::string Smt2Printer::smtKindString(Kind k, Variant v)
   case kind::BAG_FROM_SET: return "bag.from_set";
   case kind::BAG_TO_SET: return "bag.to_set";
   case kind::BAG_MAP: return "bag.map";
+  case kind::BAG_FOLD: return "bag.fold";
 
     // fp theory
   case kind::FLOATINGPOINT_FP: return "fp";
diff --git a/src/theory/bags/bag_reduction.cpp b/src/theory/bags/bag_reduction.cpp
new file mode 100644 (file)
index 0000000..9203a1c
--- /dev/null
@@ -0,0 +1,119 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * bag reduction.
+ */
+
+#include "theory/bags/bag_reduction.h"
+
+#include "expr/bound_var_manager.h"
+#include "expr/emptybag.h"
+#include "expr/skolem_manager.h"
+#include "theory/quantifiers/fmf/bounded_integers.h"
+#include "util/rational.h"
+
+using namespace cvc5;
+using namespace cvc5::kind;
+
+namespace cvc5 {
+namespace theory {
+namespace bags {
+
+BagReduction::BagReduction(Env& env) : EnvObj(env) {}
+
+BagReduction::~BagReduction() {}
+
+/**
+ * A bound variable corresponding to the universally quantified integer
+ * variable used to range over the distinct elements in a bag, used
+ * for axiomatizing the behavior of some term.
+ */
+struct IndexVarAttributeId
+{
+};
+typedef expr::Attribute<IndexVarAttributeId, Node> IndexVarAttribute;
+
+Node BagReduction::reduceFoldOperator(Node node, std::vector<Node>& asserts)
+{
+  Assert(node.getKind() == BAG_FOLD);
+  if (d_env.getLogicInfo().isHigherOrder())
+  {
+    NodeManager* nm = NodeManager::currentNM();
+    SkolemManager* sm = nm->getSkolemManager();
+    Node f = node[0];
+    Node t = node[1];
+    Node A = node[2];
+    Node zero = nm->mkConst(CONST_RATIONAL, Rational(0));
+    Node one = nm->mkConst(CONST_RATIONAL, Rational(1));
+    // types
+    TypeNode bagType = A.getType();
+    TypeNode elementType = A.getType().getBagElementType();
+    TypeNode integerType = nm->integerType();
+    TypeNode ufType = nm->mkFunctionType(integerType, elementType);
+    TypeNode resultType = t.getType();
+    TypeNode combineType = nm->mkFunctionType(integerType, resultType);
+    TypeNode unionDisjointType = nm->mkFunctionType(integerType, bagType);
+    // skolem functions
+    Node n = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_CARD, integerType, A);
+    Node uf = sm->mkSkolemFunction(SkolemFunId::BAGS_FOLD_ELEMENTS, ufType, A);
+    Node unionDisjoint = sm->mkSkolemFunction(
+        SkolemFunId::BAGS_FOLD_UNION_DISJOINT, unionDisjointType, A);
+    Node combine = sm->mkSkolemFunction(
+        SkolemFunId::BAGS_FOLD_COMBINE, combineType, {f, t, A});
+
+    BoundVarManager* bvm = nm->getBoundVarManager();
+    Node i = bvm->mkBoundVar<IndexVarAttribute>(node, "i", nm->integerType());
+    Node iList = nm->mkNode(BOUND_VAR_LIST, i);
+    Node iMinusOne = nm->mkNode(MINUS, i, one);
+    Node uf_i = nm->mkNode(APPLY_UF, uf, i);
+    Node combine_0 = nm->mkNode(APPLY_UF, combine, zero);
+    Node combine_iMinusOne = nm->mkNode(APPLY_UF, combine, iMinusOne);
+    Node combine_i = nm->mkNode(APPLY_UF, combine, i);
+    Node combine_n = nm->mkNode(APPLY_UF, combine, n);
+    Node unionDisjoint_0 = nm->mkNode(APPLY_UF, unionDisjoint, zero);
+    Node unionDisjoint_iMinusOne =
+        nm->mkNode(APPLY_UF, unionDisjoint, iMinusOne);
+    Node unionDisjoint_i = nm->mkNode(APPLY_UF, unionDisjoint, i);
+    Node unionDisjoint_n = nm->mkNode(APPLY_UF, unionDisjoint, n);
+    Node combine_0_equal = combine_0.eqNode(t);
+    Node combine_i_equal =
+        combine_i.eqNode(nm->mkNode(APPLY_UF, f, uf_i, combine_iMinusOne));
+    Node unionDisjoint_0_equal =
+        unionDisjoint_0.eqNode(nm->mkConst(EmptyBag(bagType)));
+    Node singleton = nm->mkBag(elementType, uf_i, one);
+
+    Node unionDisjoint_i_equal = unionDisjoint_i.eqNode(
+        nm->mkNode(BAG_UNION_DISJOINT, singleton, unionDisjoint_iMinusOne));
+    Node interval_i =
+        nm->mkNode(AND, nm->mkNode(GEQ, i, one), nm->mkNode(LEQ, i, n));
+
+    Node body_i =
+        nm->mkNode(IMPLIES,
+                   interval_i,
+                   nm->mkNode(AND, combine_i_equal, unionDisjoint_i_equal));
+    Node forAll_i =
+        quantifiers::BoundedIntegers::mkBoundedForall(iList, body_i);
+    Node nonNegative = nm->mkNode(GEQ, n, zero);
+    Node unionDisjoint_n_equal = A.eqNode(unionDisjoint_n);
+    asserts.push_back(forAll_i);
+    asserts.push_back(combine_0_equal);
+    asserts.push_back(unionDisjoint_0_equal);
+    asserts.push_back(unionDisjoint_n_equal);
+    asserts.push_back(nonNegative);
+    return combine_n;
+  }
+  return Node::null();
+}
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace cvc5
diff --git a/src/theory/bags/bag_reduction.h b/src/theory/bags/bag_reduction.h
new file mode 100644 (file)
index 0000000..11f091f
--- /dev/null
@@ -0,0 +1,77 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * bag reduction.
+ */
+
+#ifndef CVC5__BAG_REDUCTION_H
+#define CVC5__BAG_REDUCTION_H
+
+#include <vector>
+
+#include "cvc5_private.h"
+#include "smt/env_obj.h"
+#include "theory/bags/inference_manager.h"
+
+namespace cvc5 {
+namespace theory {
+namespace bags {
+
+/**
+ * class for bag reductions
+ */
+class BagReduction : EnvObj
+{
+ public:
+  BagReduction(Env& env);
+  ~BagReduction();
+
+  /**
+   * @param node a term of the form (bag.fold f t A) where
+   *        f: (-> T1 T2 T2) is a binary operation
+   *        t: T2 is the initial value
+   *        A: (Bag T1) is a bag
+   * @param asserts a list of assertions generated by this reduction
+   * @return the reduction term (combine n) such that
+   * (and
+   *  (forall ((i Int))
+   *    (let ((iMinusOne (- i 1)))
+   *      (let ((uf_i (uf i)))
+   *        (=>
+   *          (and (>= i 1) (<= i n))
+   *          (and
+   *            (= (combine i) (f uf_i (combine iMinusOne)))
+   *            (=
+   *              (unionDisjoint i)
+   *              (bag.union_disjoint
+   *                (bag uf_i 1)
+   *                (unionDisjoint iMinusOne))))))))
+   *   (= (combine 0) t)
+   *   (= (unionDisjoint 0) (as bag.empty (Bag T1)))
+   *   (= A (unionDisjoint n))
+   *   (>= n 0))
+   *   where
+   *   n: Int is the cardinality of bag A
+   *   uf:Int -> T1 is an uninterpreted function that represents elements of A
+   *   combine: Int -> T2 is an uninterpreted function
+   *   unionDisjoint: Int -> (Bag T1) is an uninterpreted function
+   */
+  Node reduceFoldOperator(Node node, std::vector<Node>& asserts);
+
+ private:
+};
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace cvc5
+
+#endif /* CVC5__BAG_REDUCTION_H */
index b8f3b80c99c5cd5a2007a1c5211336faa8dca300..7667318060c336ac24701ef44cd83a0dca6a0e9d 100644 (file)
@@ -90,6 +90,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
       case BAG_FROM_SET: response = rewriteFromSet(n); break;
       case BAG_TO_SET: response = rewriteToSet(n); break;
       case BAG_MAP: response = postRewriteMap(n); break;
+      case BAG_FOLD: response = postRewriteFold(n); break;
       default: response = BagsRewriteResponse(n, Rewrite::NONE); break;
     }
   }
@@ -560,6 +561,45 @@ BagsRewriteResponse BagsRewriter::postRewriteMap(const TNode& n) const
     default: return BagsRewriteResponse(n, Rewrite::NONE);
   }
 }
+
+BagsRewriteResponse BagsRewriter::postRewriteFold(const TNode& n) const
+{
+  Assert(n.getKind() == kind::BAG_FOLD);
+  Node f = n[0];
+  Node t = n[1];
+  Node bag = n[2];
+  if (bag.isConst())
+  {
+    Node value = NormalForm::evaluateBagFold(n);
+    return BagsRewriteResponse(value, Rewrite::FOLD_CONST);
+  }
+  Kind k = bag.getKind();
+  switch (k)
+  {
+    case BAG_MAKE:
+    {
+      if (bag[1].isConst() && bag[1].getConst<Rational>() > Rational(0))
+      {
+        // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, n > 0
+        Node value = NormalForm::evaluateBagFold(n);
+        return BagsRewriteResponse(value, Rewrite::FOLD_BAG);
+      }
+      break;
+    }
+    case BAG_UNION_DISJOINT:
+    {
+      // (bag.fold f t (bag.union_disjoint A B)) =
+      //       (bag.fold f (bag.fold f t A) B) where A < B to break symmetry
+      Node A = bag[0] < bag[1] ? bag[0] : bag[1];
+      Node B = bag[0] < bag[1] ? bag[1] : bag[0];
+      Node foldA = d_nm->mkNode(BAG_FOLD, f, t, A);
+      Node fold = d_nm->mkNode(BAG_FOLD, f, foldA, B);
+      return BagsRewriteResponse(fold, Rewrite::FOLD_UNION_DISJOINT);
+    }
+    default: return BagsRewriteResponse(n, Rewrite::NONE);
+  }
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index a938b3bd49473ec7877ecc47ba4c68142de583aa..d666982a7d08c967112e375adb101b1801d1e75b 100644 (file)
@@ -222,6 +222,16 @@ class BagsRewriter : public TheoryRewriter
    */
   BagsRewriteResponse postRewriteMap(const TNode& n) const;
 
+  /**
+   *  rewrites for n include:
+   *  - (bag.fold f t (as bag.empty (Bag T1))) = t
+   *  - (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times, where n > 0
+   *  - (bag.fold f t (bag.union_disjoint A B)) =
+   *       (bag.fold f (bag.fold f t A) B) where A < B to break symmetry
+   *  where f: T1 -> T2 -> T2
+   */
+  BagsRewriteResponse postRewriteFold(const TNode& n) const;
+
  private:
   /** Reference to the rewriter statistics. */
   NodeManager* d_nm;
index a5c6e75bf095d2c2f375d467c4ff6e4b0bff3d3c..5e4119fa19b05ddb257d3d844859cb23fc68277b 100644 (file)
@@ -76,6 +76,14 @@ operator BAG_CHOOSE        1  "return an element in the bag given as a parameter
 # of the second argument, a bag of type (Bag T1), and returns a bag of type (Bag T2).
 operator BAG_MAP           2  "bag map function"
 
+# bag.fold operator combines elements of a bag into a single value.
+# (bag.fold f t B) folds the elements of bag B starting with term t and using
+# the combining function f.
+#  f: a binary operation of type (-> T1 T2 T2)
+#  t: an initial value of type T2
+#  B: a bag of type (Bag T1)
+operator BAG_FOLD          3  "bag fold operator"
+
 typerule BAG_UNION_MAX           ::cvc5::theory::bags::BinaryOperatorTypeRule
 typerule BAG_UNION_DISJOINT      ::cvc5::theory::bags::BinaryOperatorTypeRule
 typerule BAG_INTER_MIN           ::cvc5::theory::bags::BinaryOperatorTypeRule
@@ -93,6 +101,7 @@ typerule BAG_IS_SINGLETON        ::cvc5::theory::bags::IsSingletonTypeRule
 typerule BAG_FROM_SET            ::cvc5::theory::bags::FromSetTypeRule
 typerule BAG_TO_SET              ::cvc5::theory::bags::ToSetTypeRule
 typerule BAG_MAP                 ::cvc5::theory::bags::BagMapTypeRule
+typerule BAG_FOLD                ::cvc5::theory::bags::BagFoldTypeRule
 
 construle BAG_UNION_DISJOINT     ::cvc5::theory::bags::BinaryOperatorTypeRule
 construle BAG_MAKE               ::cvc5::theory::bags::BagMakeTypeRule
index 12bf513b5101efb92095368e6d16bbd18f4dd7e3..9a510c6f5b938ab8e4461e16b9a436c144502f48 100644 (file)
@@ -110,6 +110,7 @@ Node NormalForm::evaluate(TNode n)
     case BAG_FROM_SET: return evaluateFromSet(n);
     case BAG_TO_SET: return evaluateToSet(n);
     case BAG_MAP: return evaluateBagMap(n);
+    case BAG_FOLD: return evaluateBagFold(n);
     default: break;
   }
   Unhandled() << "Unexpected bag kind '" << n.getKind() << "' in node " << n
@@ -169,8 +170,6 @@ Node NormalForm::evaluateBinaryOperation(const TNode& n,
 
 std::map<Node, Rational> NormalForm::getBagElements(TNode n)
 {
-  Assert(n.isConst()) << "node " << n << " is not in a normal form"
-                      << std::endl;
   std::map<Node, Rational> elements;
   if (n.getKind() == BAG_EMPTY)
   {
@@ -692,6 +691,41 @@ Node NormalForm::evaluateBagMap(TNode n)
   return ret;
 }
 
+Node NormalForm::evaluateBagFold(TNode n)
+{
+  Assert(n.getKind() == BAG_FOLD);
+
+  // Examples
+  // --------
+  // minimum string
+  // - (bag.fold
+  //     ((lambda ((x String) (y String)) (ite (str.< x y) x y))
+  //     ""
+  //     (bag.union_disjoint (bag "a" 2) (bag "b" 3))
+  //   = "a"
+
+  Node f = n[0];    // combining function
+  Node ret = n[1];  // initial value
+  Node A = n[2];    // bag
+  std::map<Node, Rational> elements = NormalForm::getBagElements(A);
+
+  std::map<Node, Rational>::iterator it = elements.begin();
+  NodeManager* nm = NodeManager::currentNM();
+  while (it != elements.end())
+  {
+    // apply the combination function n times, where n is the multiplicity
+    Rational count = it->second;
+    Assert(count.sgn() >= 0) << "negative multiplicity" << std::endl;
+    while (!count.isZero())
+    {
+      ret = nm->mkNode(APPLY_UF, f, it->first, ret);
+      count = count - 1;
+    }
+    ++it;
+  }
+  return ret;
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index 8ceee2881ec781646559e5f1da01c683ca9d53fc..5275678ffae1d392b3f400e84dabb39e005d4a3c 100644 (file)
@@ -75,6 +75,12 @@ class NormalForm
   static Node constructBagFromElements(TypeNode t,
                                        const std::map<Node, Node>& elements);
 
+  /**
+   * @param n has the form (bag.fold f t A) where A is a constant bag
+   * @return a single value which is the result of the fold
+   */
+  static Node evaluateBagFold(TNode n);
+
  private:
   /**
    * a high order helper function that return a constant bag that is the result
index 896c4f251f40c1432b649a6ce4c38c2dccdf94a5..1a8f8f8491ea6cd36ed80cdcfcf38e109e8feb87 100644 (file)
@@ -38,6 +38,9 @@ const char* toString(Rewrite r)
     case Rewrite::EQ_REFL: return "EQ_REFL";
     case Rewrite::EQ_SYM: return "EQ_SYM";
     case Rewrite::FROM_SINGLETON: return "FROM_SINGLETON";
+    case Rewrite::FOLD_BAG: return "FOLD_BAG";
+    case Rewrite::FOLD_CONST: return "FOLD_CONST";
+    case Rewrite::FOLD_UNION_DISJOINT: return "FOLD_UNION_DISJOINT";
     case Rewrite::IDENTICAL_NODES: return "IDENTICAL_NODES";
     case Rewrite::INTERSECTION_EMPTY_LEFT: return "INTERSECTION_EMPTY_LEFT";
     case Rewrite::INTERSECTION_EMPTY_RIGHT: return "INTERSECTION_EMPTY_RIGHT";
index c5050ea7230386e9683d99354530924ca2068f9a..0b71885992557bf7cd22e05dbd57f5f5a4c4f511 100644 (file)
@@ -42,6 +42,9 @@ enum class Rewrite : uint32_t
   EQ_REFL,
   EQ_SYM,
   FROM_SINGLETON,
+  FOLD_BAG,
+  FOLD_CONST,
+  FOLD_UNION_DISJOINT,
   IDENTICAL_NODES,
   INTERSECTION_EMPTY_LEFT,
   INTERSECTION_EMPTY_RIGHT,
index 4dffbdb00c8a846fe706af1a05148c03cb03e742..68bdb7b1bbc5109fab8e4bde929147f2a9f7b9ac 100644 (file)
@@ -20,6 +20,7 @@
 #include "proof/proof_checker.h"
 #include "smt/logic_exception.h"
 #include "theory/bags/normal_form.h"
+#include "theory/quantifiers/fmf/bounded_integers.h"
 #include "theory/rewriter.h"
 #include "theory/theory_model.h"
 #include "util/rational.h"
@@ -39,7 +40,8 @@ TheoryBags::TheoryBags(Env& env, OutputChannel& out, Valuation valuation)
       d_statistics(),
       d_rewriter(&d_statistics.d_rewrites),
       d_termReg(env, d_state, d_im),
-      d_solver(env, d_state, d_im, d_termReg)
+      d_solver(env, d_state, d_im, d_termReg),
+      d_bagReduction(env)
 {
   // use the official theory state and inference manager objects
   d_theoryState = &d_state;
@@ -87,6 +89,18 @@ TrustNode TheoryBags::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
   {
     case kind::BAG_CHOOSE: return expandChooseOperator(atom, lems);
     case kind::BAG_CARD: return expandCardOperator(atom, lems);
+    case kind::BAG_FOLD:
+    {
+      std::vector<Node> asserts;
+      Node ret = d_bagReduction.reduceFoldOperator(atom, asserts);
+      NodeManager* nm = NodeManager::currentNM();
+      Node andNode = nm->mkNode(AND, asserts);
+      d_im.lemma(andNode, InferenceId::BAGS_FOLD);
+      Trace("bags::ppr") << "reduce(" << atom << ") = " << ret
+                         << " such that:" << std::endl
+                         << asserts << std::endl;
+      return TrustNode::mkTrustRewrite(atom, ret, nullptr);
+    }
     default: return TrustNode::null();
   }
 }
@@ -131,9 +145,9 @@ TrustNode TheoryBags::expandChooseOperator(const Node& node,
   return TrustNode::mkTrustRewrite(node, ret, nullptr);
 }
 
-TrustNode TheoryBags::expandCardOperator(TNode n,
-                                         std::vector<SkolemLemma>& vector)
+TrustNode TheoryBags::expandCardOperator(TNode n, std::vector<SkolemLemma>&)
 {
+  Assert(n.getKind() == BAG_CARD);
   if (d_env.getLogicInfo().isHigherOrder())
   {
     // (bag.card A) = (bag.count 1 (bag.map (lambda ((x E)) 1) A)),
index fd28482d4119ec7524cebdfbd228fd3fdeaf0e20..1a8af780e83a7900ac0c114b928a11794f846c4b 100644 (file)
@@ -18,6 +18,7 @@
 #ifndef CVC5__THEORY__BAGS__THEORY_BAGS_H
 #define CVC5__THEORY__BAGS__THEORY_BAGS_H
 
+#include "theory/bags/bag_reduction.h"
 #include "theory/bags/bag_solver.h"
 #include "theory/bags/bags_rewriter.h"
 #include "theory/bags/bags_statistics.h"
@@ -112,6 +113,9 @@ class TheoryBags : public Theory
   /** the main solver for bags */
   BagSolver d_solver;
 
+  /** bag reduction */
+  BagReduction d_bagReduction;
+
   void eqNotifyNewClass(TNode n);
   void eqNotifyMerge(TNode n1, TNode n2);
   void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
index 2623f3ed7e79682e5e666940249421c384f484c4..fe81fadf5781b114607532e3d44f4109a6b07239 100644 (file)
@@ -327,6 +327,57 @@ TypeNode BagMapTypeRule::computeType(NodeManager* nodeManager,
   return retType;
 }
 
+TypeNode BagFoldTypeRule::computeType(NodeManager* nodeManager,
+                                      TNode n,
+                                      bool check)
+{
+  Assert(n.getKind() == kind::BAG_FOLD);
+  TypeNode functionType = n[0].getType(check);
+  TypeNode initialValueType = n[1].getType(check);
+  TypeNode bagType = n[2].getType(check);
+  if (check)
+  {
+    if (!bagType.isBag())
+    {
+      throw TypeCheckingExceptionPrivate(
+          n,
+          "bag.fold operator expects a bag in the third argument, "
+          "a non-bag is found");
+    }
+
+    TypeNode elementType = bagType.getBagElementType();
+
+    if (!(functionType.isFunction()))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " T2 T2) as a first argument. "
+         << "Found a term of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    std::vector<TypeNode> argTypes = functionType.getArgTypes();
+    TypeNode rangeType = functionType.getRangeType();
+    if (!(argTypes.size() == 2 && argTypes[0] == elementType
+          && argTypes[1] == rangeType))
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects a function of type  (-> "
+         << elementType << " T2 T2). "
+         << "Found a function of type '" << functionType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    if (rangeType != initialValueType)
+    {
+      std::stringstream ss;
+      ss << "Operator " << n.getKind() << " expects an initial value of type "
+         << rangeType << ". Found a term of type '" << initialValueType << "'.";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+  }
+  TypeNode retType = n[0].getType().getRangeType();
+  return retType;
+}
+
 Cardinality BagsProperties::computeCardinality(TypeNode type)
 {
   return Cardinality::INTEGERS;
index d7b8b27377782a0000268cf677d0db50fbb2edef..fa2f7831315d605356e6591a8f1adcf8938d4e71 100644 (file)
@@ -132,6 +132,15 @@ struct BagMapTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 }; /* struct BagMapTypeRule */
 
+/**
+ * Type rule for (bag.fold f t A) to make sure f is a binary operation of type
+ * (-> T1 T2 T2), t of type T2, and B is a bag of type (Bag T1)
+ */
+struct BagFoldTypeRule
+{
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+}; /* struct BagFoldTypeRule */
+
 struct BagsProperties
 {
   static Cardinality computeCardinality(TypeNode type);
index 82ae674e2e1b9642d85c62fb2163049d5d64cc27..56d2f0500ee2840b4bf54061e4a8ff584736e895 100644 (file)
@@ -118,6 +118,7 @@ const char* toString(InferenceId i)
     case InferenceId::BAGS_DIFFERENCE_REMOVE: return "BAGS_DIFFERENCE_REMOVE";
     case InferenceId::BAGS_DUPLICATE_REMOVAL: return "BAGS_DUPLICATE_REMOVAL";
     case InferenceId::BAGS_MAP: return "BAGS_MAP";
+    case InferenceId::BAGS_FOLD: return "BAGS_FOLD";
 
     case InferenceId::BV_BITBLAST_CONFLICT: return "BV_BITBLAST_CONFLICT";
     case InferenceId::BV_BITBLAST_INTERNAL_EAGER_LEMMA:
index ad879d7ab6860f491fe79c8fcab425d7c20c04f4..d98d3ff256aac910740304f4c01a878de47f3c77 100644 (file)
@@ -180,6 +180,7 @@ enum class InferenceId
   BAGS_DIFFERENCE_REMOVE,
   BAGS_DUPLICATE_REMOVAL,
   BAGS_MAP,
+  BAGS_FOLD,
   // ---------------------------------- end bags theory
 
   // ---------------------------------- bitvector theory
index cf114711adea0c84be9322a16bd1a7d491c29c42..4169036badac4fae7a84642a12f902ebe4ad61b4 100644 (file)
@@ -1606,6 +1606,7 @@ set(regress_1_tests
   regress1/bags/duplicate_removal1.smt2
   regress1/bags/duplicate_removal2.smt2
   regress1/bags/emptybag1.smt2
+  regress1/bags/fold1.smt2
   regress1/bags/fuzzy1.smt2
   regress1/bags/fuzzy2.smt2
   regress1/bags/fuzzy3.smt2
@@ -2820,6 +2821,8 @@ set(regression_disabled_tests
   regress0/tptp/SYN075+1.p
   regress0/uf/iso_icl_repgen004.smtv1.smt2
   ###
+  # takes around 30 sec
+  regress1/bags/fold2.smt2
   regress1/bug472.smt2
   regress1/datatypes/non-simple-rec-set.smt2
   # results in an assertion failure (see issue #1650).
diff --git a/test/regress/regress1/bags/fold1.smt2 b/test/regress/regress1/bags/fold1.smt2
new file mode 100644 (file)
index 0000000..73caeda
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic HO_ALL)
+(set-info :status sat)
+(set-option :fmf-bound true)
+(set-option :uf-lazy-ll true)
+(define-fun plus ((x Int) (y Int)) Int (+ x y))
+(declare-fun A () (Bag Int))
+(declare-fun sum () Int)
+(assert (= sum (bag.fold plus 1 A)))
+(assert (= sum 10))
+(check-sat)
diff --git a/test/regress/regress1/bags/fold2.smt2 b/test/regress/regress1/bags/fold2.smt2
new file mode 100644 (file)
index 0000000..9863a11
--- /dev/null
@@ -0,0 +1,15 @@
+(set-logic HO_ALL)
+(set-info :status sat)
+(set-option :fmf-bound true)
+(set-option :uf-lazy-ll true)
+(set-option :strings-exp true)
+(define-fun min ((x String) (y String)) String (ite (str.< x y) x y))
+(declare-fun A () (Bag String))
+(declare-fun x () String)
+(declare-fun minimum () String)
+(assert (= minimum (bag.fold min "zzz" A)))
+(assert (str.< "aaa" minimum ))
+(assert (str.< minimum "zzz"))
+(assert (distinct x minimum))
+(assert (= (bag.count x A) 2))
+(check-sat)
index ee1e894482974b53b8a665e45a61eae3df04bc33..ff98c308a70153050f156292a2c69e5190a5fe13 100644 (file)
@@ -750,7 +750,7 @@ TEST_F(TestTheoryWhiteBagsRewriter, map)
 
   Node empty = d_nodeManager->mkConst(String(""));
   Node xString = d_nodeManager->mkBoundVar("x", d_nodeManager->stringType());
-  Node bound = d_nodeManager->mkNode(kind::BOUND_VAR_LIST, xString);
+  Node bound = d_nodeManager->mkNode(BOUND_VAR_LIST, xString);
   Node lambda = d_nodeManager->mkNode(LAMBDA, bound, empty);
 
   // (bag.map (lambda ((x U))  t) (as bag.empty (Bag String)) =
@@ -800,5 +800,62 @@ TEST_F(TestTheoryWhiteBagsRewriter, map)
   ASSERT_TRUE(rewritten3 == unionDisjointMapK1K2);
 }
 
+TEST_F(TestTheoryWhiteBagsRewriter, fold)
+{
+  TypeNode bagIntegerType =
+      d_nodeManager->mkBagType(d_nodeManager->integerType());
+  Node emptybag = d_nodeManager->mkConst(EmptyBag(bagIntegerType));
+  Node zero = d_nodeManager->mkConst(CONST_RATIONAL, Rational(0));
+  Node one = d_nodeManager->mkConst(CONST_RATIONAL, Rational(1));
+  Node ten = d_nodeManager->mkConst(CONST_RATIONAL, Rational(10));
+  Node n = d_nodeManager->mkConst(CONST_RATIONAL, Rational(2));
+  Node x = d_nodeManager->mkBoundVar("x", d_nodeManager->integerType());
+  Node y = d_nodeManager->mkBoundVar("y", d_nodeManager->integerType());
+  Node xy = d_nodeManager->mkNode(BOUND_VAR_LIST, x, y);
+  Node sum = d_nodeManager->mkNode(PLUS, x, y);
+
+  // f(x,y) = 0 for all x, y
+  Node f = d_nodeManager->mkNode(LAMBDA, xy, zero);
+  Node node1 = d_nodeManager->mkNode(BAG_FOLD, f, one, emptybag);
+  RewriteResponse response1 = d_rewriter->postRewrite(node1);
+  ASSERT_TRUE(response1.d_node == one
+              && response1.d_status == REWRITE_AGAIN_FULL);
+
+  // (bag.fold f t (bag x n)) = (f t ... (f t (f t x))) n times,  where n > 0
+  f = d_nodeManager->mkNode(LAMBDA, xy, sum);
+  Node xSkolem = d_nodeManager->getSkolemManager()->mkDummySkolem(
+      "x", d_nodeManager->integerType());
+  Node bag = d_nodeManager->mkBag(d_nodeManager->integerType(), xSkolem, n);
+  Node node2 = d_nodeManager->mkNode(BAG_FOLD, f, one, bag);
+  Node apply_f_once = d_nodeManager->mkNode(APPLY_UF, f, xSkolem, one);
+  Node apply_f_twice =
+      d_nodeManager->mkNode(APPLY_UF, f, xSkolem, apply_f_once);
+  RewriteResponse response2 = d_rewriter->postRewrite(node2);
+  ASSERT_TRUE(response2.d_node == apply_f_twice
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+  // (bag.fold (lambda ((x Int)(y Int)) (+ x y)) 1 (bag 10 2)) = 21
+  bag = d_nodeManager->mkBag(d_nodeManager->integerType(), ten, n);
+  Node node3 = d_nodeManager->mkNode(BAG_FOLD, f, one, bag);
+  Node result3 = d_nodeManager->mkConst(CONST_RATIONAL, Rational(21));
+  Node response3 = Rewriter::rewrite(node3);
+  ASSERT_TRUE(response3 == result3);
+
+  // (bag.fold f t (bag.union_disjoint A B)) =
+  //       (bag.fold f (bag.fold f t A) B) where A < B to break symmetry
+
+  Node A =
+      d_nodeManager->getSkolemManager()->mkDummySkolem("A", bagIntegerType);
+  Node B =
+      d_nodeManager->getSkolemManager()->mkDummySkolem("B", bagIntegerType);
+  Node disjoint = d_nodeManager->mkNode(BAG_UNION_DISJOINT, A, B);
+  Node node4 = d_nodeManager->mkNode(BAG_FOLD, f, one, disjoint);
+  Node foldA = d_nodeManager->mkNode(BAG_FOLD, f, one, A);
+  Node fold = d_nodeManager->mkNode(BAG_FOLD, f, foldA, B);
+  RewriteResponse response4 = d_rewriter->postRewrite(node4);
+  ASSERT_TRUE(response4.d_node == fold
+              && response2.d_status == REWRITE_AGAIN_FULL);
+}
+
 }  // namespace test
 }  // namespace cvc5