expand bag.choose operator (#7481)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Mon, 8 Nov 2021 23:13:27 +0000 (17:13 -0600)
committerGitHub <noreply@github.com>
Mon, 8 Nov 2021 23:13:27 +0000 (23:13 +0000)
This PR expands bag.choose operator as a preprocessing step.
It also refactors the implementation of choose operator for sets

16 files changed:
src/expr/skolem_manager.cpp
src/expr/skolem_manager.h
src/theory/bags/bags_rewriter.cpp
src/theory/bags/inference_generator.cpp
src/theory/bags/normal_form.cpp
src/theory/bags/normal_form.h
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags.h
src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_private.h
test/regress/CMakeLists.txt
test/regress/regress1/bags/choose1.smt2 [new file with mode: 0644]
test/regress/regress1/bags/choose2.smt2 [new file with mode: 0644]
test/regress/regress1/bags/choose3.smt2 [new file with mode: 0644]
test/regress/regress1/bags/choose4.smt2 [new file with mode: 0644]
test/unit/theory/theory_bags_normal_form_white.cpp

index c1741beac5c91b9d00aef9309a8e791673fa27ab..db976559f12d137572bcf191535846163f1500a9 100644 (file)
@@ -67,6 +67,7 @@ const char* toString(SkolemFunId id)
     case SkolemFunId::SK_FIRST_MATCH: return "SK_FIRST_MATCH";
     case SkolemFunId::SK_FIRST_MATCH_POST: return "SK_FIRST_MATCH_POST";
     case SkolemFunId::RE_UNFOLD_POS_COMPONENT: return "RE_UNFOLD_POS_COMPONENT";
+    case SkolemFunId::BAGS_CHOOSE: return "BAGS_CHOOSE";
     case SkolemFunId::BAGS_MAP_PREIMAGE: return "BAGS_MAP_PREIMAGE";
     case SkolemFunId::BAGS_MAP_SUM: return "BAGS_MAP_SUM";
     case SkolemFunId::HO_TYPE_MATCH_PRED: return "HO_TYPE_MATCH_PRED";
index 0556185df1af30866f765e50b364a486ae9d29ba..a18de8a2e3ad570b1edfc5cbf769365ae6dcc50c 100644 (file)
@@ -112,6 +112,17 @@ enum class SkolemFunId
    * i = 0, ..., n.
    */
   RE_UNFOLD_POS_COMPONENT,
+  /** An interpreted function for bag.choose operator:
+   * (bag.choose A) is expanded as
+   * (witness ((x elementType))
+   *    (ite
+   *      (= A (as emptybag (Bag E)))
+   *      (= x (uf A))
+   *      (and (>= (bag.count x A) 1) (= x (uf A)))
+   * where uf: (Bag E) -> E is a skolem function, and E is the type of elements
+   * of A
+   */
+  BAGS_CHOOSE,
   /** An uninterpreted function for bag.map operator:
    * To compute (bag.count y (map f A)), we need to find the distinct
    * elements in A that are mapped to y by function f (i.e., preimage of {y}).
@@ -128,6 +139,17 @@ enum class SkolemFunId
    * sum(i) = sum (i-1) + (bag.count (uf i) A)
    */
   BAGS_MAP_SUM,
+  /** An interpreted function for bag.choose operator:
+   * (choose A) is expanded as
+   * (witness ((x elementType))
+   *    (ite
+   *      (= A (as emptyset (Set E)))
+   *      (= x (uf A))
+   *      (and (member x A) (= x uf(A)))
+   * where uf: (Set E) -> E is a skolem function, and E is the type of elements
+   * of A
+   */
+  SETS_CHOOSE,
   /** Higher-order type match predicate, see HoTermDb */
   HO_TYPE_MATCH_PRED,
 };
index 7f430ed6375080ea25d2dee5bf406d16e3188779..593145f6fd4c36b15e8031400685b2306b4ffb62 100644 (file)
@@ -61,6 +61,10 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
   {
     response = postRewriteEqual(n);
   }
+  else if (n.getKind() == BAG_CHOOSE)
+  {
+    response = rewriteChoose(n);
+  }
   else if (NormalForm::areChildrenConstants(n))
   {
     Node value = NormalForm::evaluate(n);
@@ -79,7 +83,6 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
       case INTERSECTION_MIN: response = rewriteIntersectionMin(n); break;
       case DIFFERENCE_SUBTRACT: response = rewriteDifferenceSubtract(n); break;
       case DIFFERENCE_REMOVE: response = rewriteDifferenceRemove(n); break;
-      case BAG_CHOOSE: response = rewriteChoose(n); break;
       case BAG_CARD: response = rewriteCard(n); break;
       case BAG_IS_SINGLETON: response = rewriteIsSingleton(n); break;
       case BAG_FROM_SET: response = rewriteFromSet(n); break;
@@ -417,7 +420,8 @@ BagsRewriteResponse BagsRewriter::rewriteDifferenceRemove(const TNode& n) const
 BagsRewriteResponse BagsRewriter::rewriteChoose(const TNode& n) const
 {
   Assert(n.getKind() == BAG_CHOOSE);
-  if (n[0].getKind() == MK_BAG && n[0][1].isConst())
+  if (n[0].getKind() == MK_BAG && n[0][1].isConst()
+      && n[0][1].getConst<Rational>() > 0)
   {
     // (bag.choose (mkBag x c)) = x where c is a constant > 0
     return BagsRewriteResponse(n[0][0], Rewrite::CHOOSE_MK_BAG);
index e88a7e0ca13a1cc3cfdbf081ca8a878c776066bf..66c38ab869561149ad2f04b6e409fbaeb33d9211 100644 (file)
@@ -400,6 +400,8 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDownwards(Node n,
 
   std::map<Node, Node> m;
   m[e] = conclusion;
+  Trace("bags::InferenceGenerator::mapDownwards")
+      << "conclusion: " << inferInfo.d_conclusion << std::endl;
   return std::tuple(inferInfo, uf, preImageSize);
 }
 
@@ -426,8 +428,8 @@ InferInfo InferenceGenerator::mapUpwards(
   Node orNode = d_nm->mkNode(OR, notEqual, andNode);
   Node implies = d_nm->mkNode(IMPLIES, xInA, orNode);
   inferInfo.d_conclusion = implies;
-  std::cout << "Upwards conclusion: " << inferInfo.d_conclusion << std::endl
-            << std::endl;
+  Trace("bags::InferenceGenerator::mapUpwards")
+      << "conclusion: " << inferInfo.d_conclusion << std::endl;
   return inferInfo;
 }
 
index 60d223a711b29b0970f69a3503cd305fc00269bb..59344cf0bc281310cbd44b95ab249724d67a4fc0 100644 (file)
@@ -15,6 +15,7 @@
 #include "normal_form.h"
 
 #include "expr/emptybag.h"
+#include "smt/logic_exception.h"
 #include "theory/sets/normal_form.h"
 #include "theory/type_enumerator.h"
 #include "util/rational.h"
@@ -104,7 +105,6 @@ Node NormalForm::evaluate(TNode n)
     case INTERSECTION_MIN: return evaluateIntersectionMin(n);
     case DIFFERENCE_SUBTRACT: return evaluateDifferenceSubtract(n);
     case DIFFERENCE_REMOVE: return evaluateDifferenceRemove(n);
-    case BAG_CHOOSE: return evaluateChoose(n);
     case BAG_CARD: return evaluateCard(n);
     case BAG_IS_SINGLETON: return evaluateIsSingleton(n);
     case BAG_FROM_SET: return evaluateFromSet(n);
@@ -564,29 +564,13 @@ Node NormalForm::evaluateChoose(TNode n)
   Assert(n.getKind() == BAG_CHOOSE);
   // Examples
   // --------
-  // - (choose (emptyBag String)) = "" // the empty string which is the first
-  //   element returned by the type enumerator
-  // - (choose (MK_BAG "x" 4)) = "x"
-  // - (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1))) = "x"
-  //     deterministically return the first element
-
-  if (n[0].getKind() == EMPTYBAG)
-  {
-    TypeNode elementType = n[0].getType().getBagElementType();
-    TypeEnumerator typeEnumerator(elementType);
-    // get the first value from the typeEnumerator
-    Node element = *typeEnumerator;
-    return element;
-  }
+  // - (bag.choose (MK_BAG "x" 4)) = "x"
 
   if (n[0].getKind() == MK_BAG)
   {
     return n[0][0];
   }
-  Assert(n[0].getKind() == UNION_DISJOINT);
-  // return the first element
-  // e.g. (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1)))
-  return n[0][0][0];
+  throw LogicException("BAG_CHOOSE_TOTAL is not supported yet");
 }
 
 Node NormalForm::evaluateCard(TNode n)
@@ -676,7 +660,6 @@ Node NormalForm::evaluateToSet(TNode n)
   return set;
 }
 
-
 Node NormalForm::evaluateBagMap(TNode n)
 {
   Assert(n.getKind() == BAG_MAP);
index f104e0381c03d824c7c915e8f36268e0b61bb138..bf96a1fbac311b35c3938f2fdb53c8338dcda96a 100644 (file)
@@ -166,8 +166,8 @@ class NormalForm
   static Node evaluateDifferenceRemove(TNode n);
   /**
    * @param n has the form (bag.choose A) where A is a constant bag
-   * @return the first element of A if A is not empty. Otherwise, it returns the
-   * first element returned by the type enumerator for the elements
+   * @return x if n has the form (bag.choose (bag x c)). Otherwise an error is
+   * thrown.
    */
   static Node evaluateChoose(TNode n);
   /**
index 9db6149ef863c0e105c3a2607b8e3f16cb4b323e..813f45669e4743a8c6c535292b35dfcfb53d3842 100644 (file)
 
 #include "theory/bags/theory_bags.h"
 
+#include "expr/emptybag.h"
+#include "expr/skolem_manager.h"
 #include "proof/proof_checker.h"
 #include "smt/logic_exception.h"
 #include "theory/bags/normal_form.h"
 #include "theory/rewriter.h"
 #include "theory/theory_model.h"
+#include "util/rational.h"
 
 using namespace cvc5::kind;
 
@@ -60,7 +63,6 @@ void TheoryBags::finishInit()
 {
   Assert(d_equalityEngine != nullptr);
 
-  // choice is used to eliminate witness
   d_valuation.setUnevaluatedKind(WITNESS);
 
   // functions we are doing congruence over
@@ -77,6 +79,55 @@ void TheoryBags::finishInit()
   d_equalityEngine->addFunctionKind(BAG_TO_SET);
 }
 
+TrustNode TheoryBags::ppRewrite(TNode atom, std::vector<SkolemLemma>& lems)
+{
+  Trace("bags-ppr") << "TheoryBags::ppRewrite " << atom << std::endl;
+
+  switch (atom.getKind())
+  {
+    case kind::BAG_CHOOSE: return expandChooseOperator(atom, lems);
+    default: return TrustNode::null();
+  }
+}
+
+TrustNode TheoryBags::expandChooseOperator(const Node& node,
+                                           std::vector<SkolemLemma>& lems)
+{
+  Assert(node.getKind() == BAG_CHOOSE);
+
+  // (bag.choose A) is expanded as
+  // (witness ((x elementType))
+  //    (ite
+  //      (= A (as emptybag (Bag E)))
+  //      (= x (uf A))
+  //      (and (>= (bag.count x A) 1) (= x (uf A)))
+  // where uf: (Bag E) -> E is a skolem function, and E is the type of elements
+  // of A
+
+  NodeManager* nm = NodeManager::currentNM();
+  SkolemManager* sm = nm->getSkolemManager();
+  Node A = node[0];
+  TypeNode bagType = A.getType();
+  TypeNode ufType = nm->mkFunctionType(bagType, bagType.getBagElementType());
+  // a Null node is used here to get a unique skolem function per bag type
+  Node uf = sm->mkSkolemFunction(SkolemFunId::BAGS_CHOOSE, ufType, Node());
+  Node ufA = NodeManager::currentNM()->mkNode(APPLY_UF, uf, A);
+
+  Node x = nm->mkBoundVar(bagType.getBagElementType());
+
+  Node equal = x.eqNode(ufA);
+  Node emptyBag = nm->mkConst(EmptyBag(bagType));
+  Node isEmpty = A.eqNode(emptyBag);
+  Node count = nm->mkNode(BAG_COUNT, x, A);
+  Node one = nm->mkConst(Rational(1));
+  Node geqOne = nm->mkNode(GEQ, count, one);
+  Node geqOneAndEqual = geqOne.andNode(equal);
+  Node ite = nm->mkNode(ITE, isEmpty, equal, geqOneAndEqual);
+  Node ret = sm->mkSkolem(x, ite, "kBagChoose");
+  lems.push_back(SkolemLemma(ret, nullptr));
+  return TrustNode::mkTrustRewrite(node, ret, nullptr);
+}
+
 void TheoryBags::postCheck(Effort effort)
 {
   d_im.doPendingFacts();
index 671623d05a7e9ccdf7f41c136a926212eb23722b..b4888f3b458e4da2b24abe6153f58e1ada6fe637 100644 (file)
@@ -52,6 +52,8 @@ class TheoryBags : public Theory
   bool needsEqualityEngine(EeSetupInfo& esi) override;
   /** finish initialization */
   void finishInit() override;
+  /** preprocess rewrite */
+  TrustNode ppRewrite(TNode atom, std::vector<SkolemLemma>& lems) override;
   //--------------------------------- end initialization
 
   //--------------------------------- standard check
@@ -87,6 +89,10 @@ class TheoryBags : public Theory
     TheoryBags& d_theory;
   };
 
+  /** expand the definition of the bag.choose operator */
+  TrustNode expandChooseOperator(const Node& node,
+                                 std::vector<SkolemLemma>& lems);
+
   /** The state of the bags solver at full effort */
   SolverState d_state;
   /** The inference manager */
index 2032d3ba5bfed3817f116709f58418202d6180ac..7b596be86086fafd601d523759931bd1adc9dfce 100644 (file)
@@ -1332,26 +1332,30 @@ TrustNode TheorySetsPrivate::expandChooseOperator(
   // (choose A) is expanded as
   // (witness ((x elementType))
   //    (ite
-  //      (= A (as emptyset setType))
-  //      (= x chooseUf(A))
-  //      (and (member x A) (= x chooseUf(A)))
+  //      (= A (as emptyset (Set E)))
+  //      (= x (uf A))
+  //      (and (member x A) (= x uf(A)))
+  // where uf: (Set E) -> E is a skolem function, and E is the type of elements
+  // of A
 
   NodeManager* nm = NodeManager::currentNM();
-  Node set = node[0];
-  TypeNode setType = set.getType();
-  Node chooseSkolem = getChooseFunction(setType);
-  Node apply = NodeManager::currentNM()->mkNode(APPLY_UF, chooseSkolem, set);
+  SkolemManager* sm = nm->getSkolemManager();
+  Node A = node[0];
+  TypeNode setType = A.getType();
+  TypeNode ufType = nm->mkFunctionType(setType, setType.getSetElementType());
+  // a Null node is used here to get a unique skolem function per set type
+  Node uf = sm->mkSkolemFunction(SkolemFunId::SETS_CHOOSE, ufType, Node());
+  Node ufA = NodeManager::currentNM()->mkNode(APPLY_UF, uf, A);
 
-  Node witnessVariable = nm->mkBoundVar(setType.getSetElementType());
+  Node x = nm->mkBoundVar(setType.getSetElementType());
 
-  Node equal = witnessVariable.eqNode(apply);
+  Node equal = x.eqNode(ufA);
   Node emptySet = nm->mkConst(EmptySet(setType));
-  Node isEmpty = set.eqNode(emptySet);
-  Node member = nm->mkNode(SET_MEMBER, witnessVariable, set);
+  Node isEmpty = A.eqNode(emptySet);
+  Node member = nm->mkNode(SET_MEMBER, x, A);
   Node memberAndEqual = member.andNode(equal);
   Node ite = nm->mkNode(ITE, isEmpty, equal, memberAndEqual);
-  SkolemManager* sm = nm->getSkolemManager();
-  Node ret = sm->mkSkolem(witnessVariable, ite, "kSetChoose");
+  Node ret = sm->mkSkolem(x, ite, "kSetChoose");
   lems.push_back(SkolemLemma(ret, nullptr));
   return TrustNode::mkTrustRewrite(node, ret, nullptr);
 }
@@ -1394,25 +1398,6 @@ TrustNode TheorySetsPrivate::expandIsSingletonOperator(const Node& node)
   return TrustNode::mkTrustRewrite(node, exists, nullptr);
 }
 
-Node TheorySetsPrivate::getChooseFunction(const TypeNode& setType)
-{
-  std::map<TypeNode, Node>::iterator it = d_chooseFunctions.find(setType);
-  if (it != d_chooseFunctions.end())
-  {
-    return it->second;
-  }
-
-  NodeManager* nm = NodeManager::currentNM();
-  SkolemManager* sm = nm->getSkolemManager();
-  TypeNode chooseUf = nm->mkFunctionType(setType, setType.getSetElementType());
-  stringstream stream;
-  stream << "chooseUf" << setType.getId();
-  string name = stream.str();
-  Node chooseSkolem = sm->mkDummySkolem(name, chooseUf, "choose function");
-  d_chooseFunctions[setType] = chooseSkolem;
-  return chooseSkolem;
-}
-
 void TheorySetsPrivate::presolve() { d_state.reset(); }
 
 }  // namespace sets
index f464d475b9c5bc1d446e5ecccbf1eeeeea7750fd..a2564b2de4e7ebcbbd21d12ca0c75f7c0ab3600a 100644 (file)
@@ -200,12 +200,7 @@ class TheorySetsPrivate : protected EnvObj
   bool isEntailed(Node n, bool pol) { return d_state.isEntailed(n, pol); }
 
  private:
-  /** get choose function
-   *
-   * Returns the existing uninterpreted function for the choose operator for the
-   * given set type, or creates a new one if it does not exist.
-   */
-  Node getChooseFunction(const TypeNode& setType);
+
   /** expand the definition of the choose operator */
   TrustNode expandChooseOperator(const Node& node,
                                  std::vector<SkolemLemma>& lems);
@@ -231,9 +226,6 @@ class TheorySetsPrivate : protected EnvObj
   /** The theory rewriter for this theory. */
   TheorySetsRewriter d_rewriter;
 
-  /** a map that stores the choose functions for set types */
-  std::map<TypeNode, Node> d_chooseFunctions;
-
   /** a map that maps each set to an existential quantifier generated for
    * operator is_singleton */
   std::map<Node, Node> d_isSingletonNodes;
index b0b19315ecb2dac99cf0efc21588dc36fb1a4b2c..23ef3936ccdfcea5db5746dd3abba513e09bee57 100644 (file)
@@ -1585,6 +1585,10 @@ set(regress_1_tests
   regress1/bug681.smt2
   regress1/bug694-Unapply1.scala-0.smt2
   regress1/bug800.smt2
+  regress1/bags/choose1.smt2
+  regress1/bags/choose2.smt2
+  regress1/bags/choose3.smt2
+  regress1/bags/choose4.smt2
   regress1/bags/difference_remove1.smt2
   regress1/bags/disequality.smt2
   regress1/bags/duplicate_removal1.smt2
diff --git a/test/regress/regress1/bags/choose1.smt2 b/test/regress/regress1/bags/choose1.smt2
new file mode 100644 (file)
index 0000000..b157bbc
--- /dev/null
@@ -0,0 +1,10 @@
+; COMMAND-LINE: --quiet
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun a () Int)
+(assert (not (= A (as emptybag (Bag Int)))))
+(assert (= (bag.choose A) 10))
+(assert (= a (bag.choose A)))
+(assert (exists ((x Int)) (and (= x (bag.choose A)) (= x a))))
+(check-sat)
diff --git a/test/regress/regress1/bags/choose2.smt2 b/test/regress/regress1/bags/choose2.smt2
new file mode 100644 (file)
index 0000000..161c92d
--- /dev/null
@@ -0,0 +1,5 @@
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag Int))
+(assert (distinct (bag.choose A) (bag.choose A)))
+(check-sat)
diff --git a/test/regress/regress1/bags/choose3.smt2 b/test/regress/regress1/bags/choose3.smt2
new file mode 100644 (file)
index 0000000..ffa9ae9
--- /dev/null
@@ -0,0 +1,8 @@
+; COMMAND-LINE: -q
+; EXPECT: sat
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(assert (= (bag.choose A) 10))
+(assert (= A (as emptybag (Bag Int))))
+(check-sat)
diff --git a/test/regress/regress1/bags/choose4.smt2 b/test/regress/regress1/bags/choose4.smt2
new file mode 100644 (file)
index 0000000..a0290b9
--- /dev/null
@@ -0,0 +1,9 @@
+; COMMAND-LINE: --quiet
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun a () Int)
+(assert (not (= A (as emptybag (Bag Int)))))
+(assert (> (bag.count 10 A) 0))
+(assert (= a (bag.choose A)))
+(check-sat)
index e0a5577b4bad123c3a7f2e496a1fcf7aa345568d..9634d55c259fe7edb2d5a1c5b97b92da2bcbac13 100644 (file)
@@ -405,41 +405,6 @@ TEST_F(TestTheoryWhiteBagsNormalForm, difference_remove)
   ASSERT_EQ(output, NormalForm::evaluate(input));
 }
 
-TEST_F(TestTheoryWhiteBagsNormalForm, choose)
-{
-  // Example
-  // -------
-  // input:  (choose (emptybag String))
-  // output: "A"; the first element returned by the type enumerator
-  // input:  (choose (MK_BAG "x" 4))
-  // output: "x"
-  // input:  (choose (union_disjoint (MK_BAG "x" 4) (MK_BAG "y" 1)))
-  // output: "x"; deterministically return the first element
-  Node empty = d_nodeManager->mkConst(
-      EmptyBag(d_nodeManager->mkBagType(d_nodeManager->stringType())));
-  Node x = d_nodeManager->mkConst(String("x"));
-  Node y = d_nodeManager->mkConst(String("y"));
-  Node z = d_nodeManager->mkConst(String("z"));
-  Node x_4 = d_nodeManager->mkBag(
-      d_nodeManager->stringType(), x, d_nodeManager->mkConst(Rational(4)));
-  Node y_1 = d_nodeManager->mkBag(
-      d_nodeManager->stringType(), y, d_nodeManager->mkConst(Rational(1)));
-
-  Node input1 = d_nodeManager->mkNode(BAG_CHOOSE, empty);
-  Node output1 = d_nodeManager->mkConst(String(""));
-
-  ASSERT_EQ(output1, NormalForm::evaluate(input1));
-
-  Node input2 = d_nodeManager->mkNode(BAG_CHOOSE, x_4);
-  Node output2 = x;
-  ASSERT_EQ(output2, NormalForm::evaluate(input2));
-
-  Node union_disjoint = d_nodeManager->mkNode(UNION_DISJOINT, x_4, y_1);
-  Node input3 = d_nodeManager->mkNode(BAG_CHOOSE, union_disjoint);
-  Node output3 = x;
-  ASSERT_EQ(output3, NormalForm::evaluate(input3));
-}
-
 TEST_F(TestTheoryWhiteBagsNormalForm, bag_card)
 {
   // Examples