Add skolem lemmas for bags card terms (#7995)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Tue, 15 Mar 2022 19:28:41 +0000 (14:28 -0500)
committerGitHub <noreply@github.com>
Tue, 15 Mar 2022 19:28:41 +0000 (19:28 +0000)
This PR refactors the way skolem lemmas are generated for bags, count terms, and card terms.
As a side effect, this refactoring fixed cvc5/cvc5-projects#481

22 files changed:
src/expr/skolem_manager.cpp
src/expr/skolem_manager.h
src/theory/bags/bag_solver.cpp
src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
src/theory/bags/card_solver.cpp
src/theory/bags/card_solver.h
src/theory/bags/infer_info.cpp
src/theory/bags/infer_info.h
src/theory/bags/inference_generator.cpp
src/theory/bags/inference_generator.h
src/theory/bags/solver_state.cpp
src/theory/bags/solver_state.h
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags.h
src/theory/inference_id.cpp
src/theory/inference_id.h
src/theory/theory_state.cpp
src/theory/theory_state.h
test/regress/CMakeLists.txt
test/regress/regress1/bags/murxla5.smt2 [new file with mode: 0644]
test/unit/theory/theory_bags_rewriter_white.cpp

index ba1f0b0ea00d3a354f3c3490e306fea5c691a62d..7df8711a918645386ce467b175cb1b252cd9f4b0 100644 (file)
@@ -86,6 +86,7 @@ const char* toString(SkolemFunId id)
     case SkolemFunId::BAGS_FOLD_ELEMENTS: return "BAGS_FOLD_ELEMENTS";
     case SkolemFunId::BAGS_FOLD_UNION_DISJOINT: return "BAGS_FOLD_UNION_DISJOINT";
     case SkolemFunId::BAGS_MAP_PREIMAGE: return "BAGS_MAP_PREIMAGE";
+    case SkolemFunId::BAGS_MAP_PREIMAGE_SIZE: return "BAGS_MAP_PREIMAGE_SIZE";
     case SkolemFunId::BAGS_MAP_PREIMAGE_INDEX: return "BAGS_MAP_PREIMAGE_INDEX";
     case SkolemFunId::BAGS_MAP_SUM: return "BAGS_MAP_SUM";
     case SkolemFunId::HO_TYPE_MATCH_PRED: return "HO_TYPE_MATCH_PRED";
index 733fbabdcff86b716a7d488afd008651e4976401..7ca3d2dc765db5ee4c57a089d0ab9b5704ebae3a 100644 (file)
@@ -147,6 +147,12 @@ enum class SkolemFunId
    * where uf: Int -> E is a skolem function, and E is the type of elements of A
    */
   BAGS_MAP_PREIMAGE,
+  /**
+   * A skolem variable for the size of the preimage of {y} that is unique per
+   * terms (map f A), y which might be an element in (map f A). (see the
+   * documentation for BAGS_MAP_PREIMAGE)
+   */
+  BAGS_MAP_PREIMAGE_SIZE,
   /**
    * A skolem variable for the index that is unique per terms
    * (map f A), y, preImageSize, y, e which might be an element in A.
index 219e3187db9c370b7cfd3865f294e12bd8fa0cbc..f657d0e6a6e5b9c0e890f08730dea2c506bc2a2c 100644 (file)
@@ -246,9 +246,9 @@ void BagSolver::checkDuplicateRemoval(Node n)
 
 void BagSolver::checkDisequalBagTerms()
 {
-  for (const Node& n : d_state.getDisequalBagTerms())
+  for (const auto& [equality, witness] : d_state.getDisequalBagTerms())
   {
-    InferInfo info = d_ig.bagDisequality(n);
+    InferInfo info = d_ig.bagDisequality(equality, witness);
     d_im.lemmaTheoryInference(&info);
   }
 }
@@ -295,7 +295,7 @@ void BagSolver::checkFilter(Node n)
 
   set<Node> elements;
   const set<Node>& downwards = d_state.getElements(n);
-  const set<Node>& upwards = d_state.getElements(n[0]);
+  const set<Node>& upwards = d_state.getElements(n[1]);
   elements.insert(downwards.begin(), downwards.end());
   elements.insert(upwards.begin(), upwards.end());
 
@@ -316,6 +316,7 @@ void BagSolver::checkProduct(Node n)
   Assert(n.getKind() == TABLE_PRODUCT);
   const set<Node>& elementsA = d_state.getElements(n[0]);
   const set<Node>& elementsB = d_state.getElements(n[1]);
+
   for (const Node& e1 : elementsA)
   {
     for (const Node& e2 : elementsB)
index 031910cdd2a4c2a5367fb818d29c5b131db6a33a..2506f13e258c600c146999a5d2510952630b6b07 100644 (file)
@@ -454,14 +454,6 @@ BagsRewriteResponse BagsRewriter::rewriteCard(const TNode& n) const
     return BagsRewriteResponse(n[0][1], Rewrite::CARD_BAG_MAKE);
   }
 
-  if (n[0].getKind() == BAG_UNION_DISJOINT)
-  {
-    // (bag.card (bag.union-disjoint A B)) = (+ (bag.card A) (bag.card B))
-    Node A = d_nm->mkNode(BAG_CARD, n[0][0]);
-    Node B = d_nm->mkNode(BAG_CARD, n[0][1]);
-    Node plus = d_nm->mkNode(ADD, A, B);
-    return BagsRewriteResponse(plus, Rewrite::CARD_DISJOINT);
-  }
   return BagsRewriteResponse(n, Rewrite::NONE);
 }
 
index f05766c5344010b1895eea49e59f26f766a39632..be72f4017f56567a82a54256763df73f731f935b 100644 (file)
@@ -186,7 +186,6 @@ class BagsRewriter : public TheoryRewriter
   /**
    * rewrites for n include:
    * - (bag.card (bag x c)) = c where c is a constant > 0
-   * - (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
    * - otherwise = n
    */
   BagsRewriteResponse rewriteCard(const TNode& n) const;
@@ -250,12 +249,13 @@ class BagsRewriter : public TheoryRewriter
   /**
    *  rewrites for n include:
    *  - (bag.product A (as bag.empty T2)) = (as bag.empty T)
-   *  - (bag.product (as bag.empty T2)) = (f t ... (f t (f t x))) n times, where n > 0
+   *  - (bag.product (as bag.empty T2)) = (f t ... (f t (f t x))) n times, where
+   * n > 0
    *  - (bag.fold f t (bag.union_disjoint A B)) =
    *       (bag.fold f (bag.fold f t A) B) where A < B to break symmetry
    *  where f: T1 -> T2 -> T2
    */
-  BagsRewriteResponse postRewriteProduct(const TNode& n)const;
+  BagsRewriteResponse postRewriteProduct(const TNode& n) const;
 
  private:
   /** Reference to the rewriter statistics. */
index 94c281c9af946fd1d56dc22eaf714fbc104f65b7..146149e3bd2549b01df16ddf52fb9b22da90f489 100644 (file)
@@ -65,8 +65,6 @@ std::set<Node> CardSolver::getChildren(Node bag)
 
 void CardSolver::checkCardinalityGraph()
 {
-  generateRelatedCardinalityTerms();
-
   for (const auto& pair : d_state.getCardinalityTerms())
   {
     Trace("bags-card") << "CardSolver::checkCardinalityGraph cardTerm: " << pair
@@ -111,98 +109,6 @@ void CardSolver::checkCardinalityGraph()
   }
 }
 
-void CardSolver::generateRelatedCardinalityTerms()
-{
-  const set<Node>& bags = d_state.getBags();
-  for (const auto& pair : d_state.getCardinalityTerms())
-  {
-    Assert(pair.first.getKind() == BAG_CARD);
-    // get the representative of the bag in the card term
-    Node rep = d_state.getRepresentative(pair.first[0]);
-    // enumerate all bag terms that are related to the current bag
-    for (const auto& bag : bags)
-    {
-      if (rep == bag)
-      {
-        continue;
-      }
-
-      eq::EqClassIterator it = eq::EqClassIterator(
-          d_state.getRepresentative(bag), d_state.getEqualityEngine());
-      while (!it.isFinished())
-      {
-        Node n = (*it);
-        Kind k = n.getKind();
-        switch (k)
-        {
-          case BAG_EMPTY: break;
-          case BAG_MAKE: break;
-          case BAG_UNION_DISJOINT:
-          {
-            Node A = d_state.getRepresentative(n[0]);
-            Node B = d_state.getRepresentative(n[1]);
-            if (A == rep || B == rep)
-            {
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, A));
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, B));
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, n));
-            }
-            break;
-          }
-          case BAG_UNION_MAX:
-          {
-            Node A = d_state.getRepresentative(n[0]);
-            Node B = d_state.getRepresentative(n[1]);
-            if (A == rep || B == rep)
-            {
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, A));
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, B));
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, n));
-              // break the intersection symmetry using the node id
-              Node inter = A <= B ? d_nm->mkNode(BAG_INTER_MIN, A, B)
-                                  : d_nm->mkNode(BAG_INTER_MIN, B, A);
-              Node subtractAB =
-                  d_nm->mkNode(kind::BAG_DIFFERENCE_SUBTRACT, A, B);
-              Node subtractBA =
-                  d_nm->mkNode(kind::BAG_DIFFERENCE_SUBTRACT, B, A);
-              d_state.registerBag(inter);
-              d_state.registerBag(subtractAB);
-              d_state.registerBag(subtractBA);
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, inter));
-              d_state.registerCardinalityTerm(
-                  d_nm->mkNode(BAG_CARD, subtractAB));
-              d_state.registerCardinalityTerm(
-                  d_nm->mkNode(BAG_CARD, subtractBA));
-            }
-            break;
-          }
-          case BAG_INTER_MIN: break;
-          case BAG_DIFFERENCE_SUBTRACT:
-          {
-            Node A = d_state.getRepresentative(n[0]);
-            Node B = d_state.getRepresentative(n[1]);
-            if (A == rep || B == rep)
-            {
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, A));
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, B));
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, n));
-              // break the intersection symmetry using the node id
-              Node inter = A <= B ? d_nm->mkNode(BAG_INTER_MIN, A, B)
-                                  : d_nm->mkNode(BAG_INTER_MIN, B, A);
-              d_state.registerBag(inter);
-              d_state.registerCardinalityTerm(d_nm->mkNode(BAG_CARD, inter));
-            }
-            break;
-          }
-          case BAG_DIFFERENCE_REMOVE: break;
-          default: break;
-        }
-        it++;
-      }
-    }
-  }
-}
-
 void CardSolver::checkEmpty(const std::pair<Node, Node>& pair, const Node& n)
 {
   Assert(n.getKind() == BAG_EMPTY);
@@ -238,9 +144,6 @@ void CardSolver::checkUnionMax(const std::pair<Node, Node>& pair, const Node& n)
   // break the intersection symmetry using the node id
   Node interAB = A <= B ? d_nm->mkNode(BAG_INTER_MIN, A, B)
                         : d_nm->mkNode(BAG_INTER_MIN, B, A);
-  d_state.registerBag(subtractAB);
-  d_state.registerBag(subtractBA);
-  d_state.registerBag(interAB);
   Node subtractABRep = d_state.getRepresentative(subtractAB);
   Node subtractBARep = d_state.getRepresentative(subtractBA);
   Node interABRep = d_state.getRepresentative(interAB);
@@ -323,6 +226,8 @@ void CardSolver::addChildren(const Node& premise,
       // child.
       const std::set<Node>& oldChildren = *d_cardGraph[parent].begin();
       d_cardGraph[parent].insert(children);
+      Trace("bags-card") << "CardSolver::addChildren parent: " << parent
+                         << std::endl;
       Trace("bags-card") << "CardSolver::addChildren set1: " << oldChildren
                          << std::endl;
       Trace("bags-card") << "CardSolver::addChildren set2: " << children
@@ -352,9 +257,6 @@ void CardSolver::checkIntersectionMin(const std::pair<Node, Node>& pair,
   // break the intersection symmetry using the node id
   Node interAB = A <= B ? d_nm->mkNode(BAG_INTER_MIN, A, B)
                         : d_nm->mkNode(BAG_INTER_MIN, B, A);
-  d_state.registerBag(subtractAB);
-  d_state.registerBag(subtractBA);
-  d_state.registerBag(interAB);
   Node subtractABRep = d_state.getRepresentative(subtractAB);
   Node subtractBARep = d_state.getRepresentative(subtractBA);
   Node interABRep = d_state.getRepresentative(interAB);
@@ -372,7 +274,6 @@ void CardSolver::checkDifferenceSubtract(const std::pair<Node, Node>& pair,
   // break the intersection symmetry using the node id
   Node interAB = A <= B ? d_nm->mkNode(BAG_INTER_MIN, A, B)
                         : d_nm->mkNode(BAG_INTER_MIN, B, A);
-  d_state.registerBag(interAB);
   Node interABRep = d_state.getRepresentative(interAB);
   addChildren(bag.eqNode(n), A, {bag, interABRep});
 }
index 4bec4ea23014fc2220c55e11ffbcc49691a24b7b..5af0bc7013477ac3edeae1081684dcad2f357c64 100644 (file)
@@ -61,28 +61,6 @@ class CardSolver : protected EnvObj
   std::set<Node> getChildren(Node bag);
 
  private:
-  /**
-   * Generate all cardinality terms needed in the cardinality graph.
-   * suppose (bag.card bag) is a term, and r is the representative of bag.
-   * Suppose A, B are bag terms and r in {A, B}.
-   * - If (bag.union_disjoint A B) is a term, add the following terms:
-   *   (bag.card A)
-   *   (bag.card B)
-   *   (bag.card (bag.union_disjoint A B))
-   * - If (bag.union_max A B) is a term, add the following terms:
-   *   (bag.card A)
-   *   (bag.card B)
-   *   (bag.card (bag.difference_subtract A B))
-   *   (bag.card (bag.inter_min A B))
-   *   (bag.card (bag.difference_subtract B A))
-   * - If (bag.difference_subtract A B) is a term, add the following terms:
-   *   (bag.card A)
-   *   (bag.card B)
-   *   (bag.card (bag.inter_min A B))
-   *   (bag.card (bag.difference_subtract A B))
-   */
-  void generateRelatedCardinalityTerms();
-
   /** apply inference rules for empty bags */
   void checkEmpty(const std::pair<Node, Node>& pair, const Node& n);
   /** apply inference rules for bag make */
index 187cbff99992f5c7cd7255e4e5754fd2e09db4aa..aa949ebc8678d6764783d37f6fecb20154160e3a 100644 (file)
  */
 
 #include "theory/bags/infer_info.h"
-
+#include "theory/inference_manager_buffered.h"
 #include "theory/bags/inference_manager.h"
 
 namespace cvc5 {
 namespace theory {
 namespace bags {
 
-InferInfo::InferInfo(TheoryInferenceManager* im, InferenceId id)
+InferInfo::InferInfo(InferenceManagerBuffered* im, InferenceId id)
     : TheoryInference(id), d_im(im)
 {
 }
@@ -36,12 +36,11 @@ TrustNode InferInfo::processLemma(LemmaProperty& p)
   for (const auto& pair : d_skolems)
   {
     Node n = pair.first.eqNode(pair.second);
-    TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr);
-    d_im->trustedLemma(trustedLemma, getId(), p);
+    d_im->addPendingLemma(n, InferenceId::BAGS_SKOLEM);
   }
 
   Trace("bags::InferInfo::process") << (*this) << std::endl;
-
+  d_im->addPendingLemma(lemma, getId());
   return TrustNode::mkTrustLemma(lemma, nullptr);
 }
 
index 4fde553f00b45e6cb734b76bcd97babaa2ff996a..bebeaa077e920905894370426f20d751deab080d 100644 (file)
@@ -28,7 +28,7 @@
 namespace cvc5 {
 namespace theory {
 
-class TheoryInferenceManager;
+class InferenceManagerBuffered;
 
 namespace bags {
 
@@ -40,12 +40,12 @@ namespace bags {
 class InferInfo : public TheoryInference
 {
  public:
-  InferInfo(TheoryInferenceManager* im, InferenceId id);
+  InferInfo(InferenceManagerBuffered* im, InferenceId id);
   ~InferInfo() {}
   /** Process lemma */
   TrustNode processLemma(LemmaProperty& p) override;
   /** Pointer to the class used for processing this info */
-  TheoryInferenceManager* d_im;
+  InferenceManagerBuffered* d_im;
   /** The conclusion */
   Node d_conclusion;
   /**
index 2aab8473dc6ab17ec3640c733b4a9ad4a1c794dd..63a90cf185f2281c9cba988a93025de47c1148f4 100644 (file)
@@ -45,6 +45,30 @@ InferenceGenerator::InferenceGenerator(SolverState* state, InferenceManager* im)
   d_one = d_nm->mkConstInt(Rational(1));
 }
 
+Node InferenceGenerator::registerCountTerm(Node n)
+{
+  Assert(n.getKind() == BAG_COUNT);
+  Node element = d_state->getRepresentative(n[0]);
+  Node bag = d_state->getRepresentative(n[1]);
+  Node count = d_nm->mkNode(BAG_COUNT, element, bag);
+  Node skolem = registerAndAssertSkolemLemma(count, "bag.count");
+  d_state->registerCountTerm(bag, element, skolem);
+  return skolem;
+}
+
+void InferenceGenerator::registerCardinalityTerm(Node n)
+{
+  Assert(n.getKind() == BAG_CARD);
+  Node bag = d_state->getRepresentative(n[0]);
+  Node cardTerm = d_nm->mkNode(BAG_CARD, bag);
+  Node skolem = registerAndAssertSkolemLemma(cardTerm, "bag.card");
+  d_state->registerCardinalityTerm(cardTerm, skolem);
+  Node premise = n[0].eqNode(bag);
+  Node conclusion = skolem.eqNode(n);
+  Node lemma = premise.notNode().orNode(conclusion);
+  d_im->addPendingLemma(lemma, InferenceId::BAGS_SKOLEM);
+}
+
 InferInfo InferenceGenerator::nonNegativeCount(Node n, Node e)
 {
   Assert(n.getType().isBag());
@@ -102,7 +126,7 @@ InferInfo InferenceGenerator::bagMake(Node n, Node e)
   Node same = d_nm->mkNode(EQUAL, e, x);
   Node geq = d_nm->mkNode(GEQ, c, d_one);
   Node andNode = same.andNode(geq);
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
   Node equalC = d_nm->mkNode(EQUAL, count, c);
   Node equalZero = d_nm->mkNode(EQUAL, count, d_zero);
@@ -132,43 +156,34 @@ struct SecondIndexVarAttributeId
 typedef expr::Attribute<SecondIndexVarAttributeId, Node>
     SecondIndexVarAttribute;
 
-struct BagsDeqAttributeId
+InferInfo InferenceGenerator::bagDisequality(Node equality, Node witness)
 {
-};
-typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
-
-InferInfo InferenceGenerator::bagDisequality(Node n)
-{
-  Assert(n.getKind() == EQUAL && n[0].getType().isBag());
-
-  Node A = n[0];
-  Node B = n[1];
+  Assert(equality.getKind() == EQUAL && equality[0].getType().isBag());
+  Node A = equality[0];
+  Node B = equality[1];
 
   InferInfo inferInfo(d_im, InferenceId::BAGS_DISEQUALITY);
 
-  TypeNode elementType = A.getType().getBagElementType();
-  BoundVarManager* bvm = d_nm->getBoundVarManager();
-  Node element = bvm->mkBoundVar<BagsDeqAttribute>(n, elementType);
-  Node skolem =
-      d_sm->mkSkolem(element,
-                     n,
-                     "bag_disequal",
-                     "an extensional lemma for disequality of two bags");
-
-  Node countA = getMultiplicityTerm(skolem, A);
-  Node countB = getMultiplicityTerm(skolem, B);
+  Node countA = getMultiplicityTerm(witness, A);
+  Node skolemA = registerCountTerm(countA);
+  Node countB = getMultiplicityTerm(witness, B);
+  Node skolemB = registerCountTerm(countB);
 
-  Node disEqual = countA.eqNode(countB).notNode();
+  Node disequal = skolemA.eqNode(skolemB).notNode();
 
-  inferInfo.d_premises.push_back(n.notNode());
-  inferInfo.d_conclusion = disEqual;
+  inferInfo.d_premises.push_back(equality.notNode());
+  inferInfo.d_conclusion = disequal;
   return inferInfo;
 }
 
-Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo)
+Node InferenceGenerator::registerAndAssertSkolemLemma(Node& n,
+                                                      const std::string& prefix)
 {
-  Node skolem = d_sm->mkPurifySkolem(n, "skolem_bag", "skolem bag");
-  inferInfo.d_skolems[n] = skolem;
+  Node skolem = d_sm->mkPurifySkolem(n, prefix);
+  Node lemma = n.eqNode(skolem);
+  d_im->addPendingLemma(lemma, InferenceId::BAGS_SKOLEM);
+  Trace("bags-skolems") << "bags-skolems:  " << skolem << " = " << n
+                        << std::endl;
   return skolem;
 }
 
@@ -178,7 +193,7 @@ InferInfo InferenceGenerator::empty(Node n, Node e)
   Assert(e.getType().isSubtypeOf(n.getType().getBagElementType()));
 
   InferInfo inferInfo(d_im, InferenceId::BAGS_EMPTY);
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
 
   Node equal = count.eqNode(d_zero);
@@ -198,7 +213,7 @@ InferInfo InferenceGenerator::unionDisjoint(Node n, Node e)
   Node countA = getMultiplicityTerm(e, A);
   Node countB = getMultiplicityTerm(e, B);
 
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
 
   Node sum = d_nm->mkNode(ADD, countA, countB);
@@ -220,7 +235,7 @@ InferInfo InferenceGenerator::unionMax(Node n, Node e)
   Node countA = getMultiplicityTerm(e, A);
   Node countB = getMultiplicityTerm(e, B);
 
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
 
   Node gt = d_nm->mkNode(GT, countA, countB);
@@ -242,7 +257,7 @@ InferInfo InferenceGenerator::intersection(Node n, Node e)
 
   Node countA = getMultiplicityTerm(e, A);
   Node countB = getMultiplicityTerm(e, B);
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
 
   Node lt = d_nm->mkNode(LT, countA, countB);
@@ -263,7 +278,7 @@ InferInfo InferenceGenerator::differenceSubtract(Node n, Node e)
 
   Node countA = getMultiplicityTerm(e, A);
   Node countB = getMultiplicityTerm(e, B);
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
 
   Node subtract = d_nm->mkNode(SUB, countA, countB);
@@ -286,7 +301,7 @@ InferInfo InferenceGenerator::differenceRemove(Node n, Node e)
   Node countA = getMultiplicityTerm(e, A);
   Node countB = getMultiplicityTerm(e, B);
 
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
 
   Node notInB = d_nm->mkNode(LEQ, countB, d_zero);
@@ -305,7 +320,7 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e)
   InferInfo inferInfo(d_im, InferenceId::BAGS_DUPLICATE_REMOVAL);
 
   Node countA = getMultiplicityTerm(e, A);
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
 
   Node gte = d_nm->mkNode(GEQ, countA, d_one);
@@ -320,7 +335,7 @@ InferInfo InferenceGenerator::cardEmpty(const std::pair<Node, Node>& pair,
 {
   Assert(pair.first.getKind() == BAG_CARD);
   Assert(n.getKind() == BAG_EMPTY && n.getType() == pair.first[0].getType());
-  InferInfo inferInfo(d_im, InferenceId::BAGS_CARD);
+  InferInfo inferInfo(d_im, InferenceId::BAGS_CARD_EMPTY);
   Node premise = pair.first[0].eqNode(n);
   Node conclusion = pair.second.eqNode(d_zero);
   inferInfo.d_conclusion = premise.notNode().orNode(conclusion);
@@ -358,7 +373,7 @@ InferInfo InferenceGenerator::cardUnionDisjoint(Node premise,
   Node unionDisjoints = child;
   Node card = d_nm->mkNode(BAG_CARD, child);
   std::vector<Node> lemmas;
-  lemmas.push_back(d_state->registerCardinalityTerm(card));
+  registerCardinalityTerm(card);
   Node sum = d_state->getCardinalitySkolem(card);
   ++it;
   while (it != children.end())
@@ -368,14 +383,13 @@ InferInfo InferenceGenerator::cardUnionDisjoint(Node premise,
     unionDisjoints =
         d_nm->mkNode(kind::BAG_UNION_DISJOINT, unionDisjoints, child);
     card = d_nm->mkNode(BAG_CARD, child);
-    lemmas.push_back(d_state->registerCardinalityTerm(card));
-    d_state->getCardinalitySkolem(card);
+    registerCardinalityTerm(card);
     Node skolem = d_state->getCardinalitySkolem(card);
     sum = d_nm->mkNode(ADD, sum, skolem);
     ++it;
   }
   Node parentCard = d_nm->mkNode(BAG_CARD, parent);
-  lemmas.push_back(d_state->registerCardinalityTerm(parentCard));
+  registerCardinalityTerm(parentCard);
   Node parentSkolem = d_state->getCardinalitySkolem(parentCard);
 
   Node bags = parent.eqNode(unionDisjoints);
@@ -420,10 +434,11 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDown(Node n, Node e)
   Node baseCase = d_nm->mkNode(EQUAL, sum_zero, d_zero);
 
   // guess the size of the preimage of e
-  Node preImageSize = d_sm->mkDummySkolem("preImageSize", d_nm->integerType());
+  Node preImageSize = d_sm->mkSkolemFunction(
+      SkolemFunId::BAGS_MAP_PREIMAGE_SIZE, d_nm->integerType(), {n, e});
 
   // (= (sum preImageSize) (bag.count e skolem))
-  Node mapSkolem = getSkolem(n, inferInfo);
+  Node mapSkolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node countE = getMultiplicityTerm(e, mapSkolem);
   Node totalSum = d_nm->mkNode(APPLY_UF, sum, preImageSize);
   Node totalSumEqualCountE = d_nm->mkNode(EQUAL, totalSum, countE);
@@ -485,8 +500,6 @@ std::tuple<InferInfo, Node, Node> InferenceGenerator::mapDown(Node n, Node e)
       AND, {baseCase, totalSumEqualCountE, forAll_i, preImageGTE_zero});
   inferInfo.d_conclusion = conclusion;
 
-  std::map<Node, Node> m;
-  m[e] = conclusion;
   Trace("bags::InferenceGenerator::mapDown")
       << "conclusion: " << inferInfo.d_conclusion << std::endl;
   return std::tuple(inferInfo, uf, preImageSize);
@@ -532,7 +545,7 @@ InferInfo InferenceGenerator::filterDownwards(Node n, Node e)
   InferInfo inferInfo(d_im, InferenceId::BAGS_FILTER_DOWN);
 
   Node countA = getMultiplicityTerm(e, A);
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
 
   Node member = d_nm->mkNode(GEQ, count, d_one);
@@ -554,7 +567,7 @@ InferInfo InferenceGenerator::filterUpwards(Node n, Node e)
   InferInfo inferInfo(d_im, InferenceId::BAGS_FILTER_UP);
 
   Node countA = getMultiplicityTerm(e, A);
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
 
   Node member = d_nm->mkNode(GEQ, countA, d_one);
@@ -580,7 +593,7 @@ InferInfo InferenceGenerator::productUp(Node n, Node e1, Node e2)
   Node countA = getMultiplicityTerm(e1, A);
   Node countB = getMultiplicityTerm(e2, B);
 
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(tuple, skolem);
 
   Node multiply = d_nm->mkNode(MULT, countA, countB);
@@ -613,7 +626,7 @@ InferInfo InferenceGenerator::productDown(Node n, Node e)
   Node countA = getMultiplicityTerm(a, A);
   Node countB = getMultiplicityTerm(b, B);
 
-  Node skolem = getSkolem(n, inferInfo);
+  Node skolem = registerAndAssertSkolemLemma(n, "skolem_bag");
   Node count = getMultiplicityTerm(e, skolem);
 
   Node multiply = d_nm->mkNode(MULT, countA, countB);
index ed6f856e2bd1a033a2f7fec58ade74a053f671fe..ef8939b6431f59ceb57aed15e4fafaf690d51b72 100644 (file)
@@ -37,6 +37,20 @@ class InferenceGenerator
  public:
   InferenceGenerator(SolverState* state, InferenceManager* im);
 
+  /**
+   * @param n a node of the form (bag.count e A)
+   * @return a skolem that equals (bag.count repE repA) where
+   * repE, repA are representatives of e, A respectively
+   */
+  Node registerCountTerm(Node n);
+
+  /**
+   * @param n a node of the form (bag.card A)
+   * @return a skolem that equals (bag.card repA) where repA is the
+   * representative of A
+   */
+  void registerCardinalityTerm(Node n);
+
   /**
    * @param A is a bag of type (Bag E)
    * @param e is a node of type E
@@ -73,15 +87,16 @@ class InferenceGenerator
    */
   InferInfo bagMake(Node n, Node e);
   /**
-   * @param n is (= A B) where A, B are bags of type (Bag E), and
+   * @param equality is (= A B) where A, B are bags of type (Bag E), and
    * (not (= A B)) is an assertion in the equality engine
+   * @param witness a skolem node that witnesses the disequality
    * @return an inference that represents the following implication
    * (=>
    *   (not (= A B))
-   *   (not (= (bag.count e A) (bag.count e B))))
-   *   where e is a fresh skolem of type E.
+   *   (not (= (bag.count witness A) (bag.count witness B))))
+   *   where witness is a skolem of type E.
    */
-  InferInfo bagDisequality(Node n);
+  InferInfo bagDisequality(Node equality, Node witness);
   /**
    * @param n is (as bag.empty (Bag E))
    * @param e is a node of Type E
@@ -211,7 +226,7 @@ class InferenceGenerator
   /**
    * @param n is (bag.map f A) where f is a function (-> E T), A a bag of type
    * (Bag E)
-   * @param e is a node of Type E
+   * @param e is a node of Type T
    * @return an inference that represents the following implication
    * (and
    *   (= (sum 0) 0)
@@ -321,8 +336,10 @@ class InferenceGenerator
   Node getMultiplicityTerm(Node element, Node bag);
 
  private:
-  /** generate skolem variable for node n and add it to inferInfo */
-  Node getSkolem(Node& n, InferInfo& inferInfo);
+  /**
+   * generate skolem variable for node n and add pending lemma for the equality
+   */
+  Node registerAndAssertSkolemLemma(Node& n, const std::string& prefix);
 
   NodeManager* d_nm;
   SkolemManager* d_sm;
index 604c05cb4d584496d352cc167866e50c9f507500..3a66d5c9a51b702b75eed41f7b8e95b96feac314 100644 (file)
@@ -38,43 +38,30 @@ void SolverState::registerBag(TNode n)
 {
   Assert(n.getType().isBag());
   d_bags.insert(n);
-  if (!d_ee->hasTerm(n))
-  {
-    d_ee->addTerm(n);
-  }
 }
 
-Node SolverState::registerCountTerm(TNode n)
+void SolverState::registerCountTerm(Node bag, Node element, Node skolem)
 {
-  Assert(n.getKind() == BAG_COUNT);
-  Node element = getRepresentative(n[0]);
-  Node bag = getRepresentative(n[1]);
-  Node count = d_nm->mkNode(BAG_COUNT, element, bag);
-  Node skolem = d_nm->getSkolemManager()->mkPurifySkolem(count, "bag.count");
+  Assert(bag.getType().isBag() && bag == getRepresentative(bag));
+  Assert(element.getType().isSubtypeOf(bag.getType().getBagElementType())
+         && element == getRepresentative(element));
+  Assert(skolem.isVar() && skolem.getType().isInteger());
   std::pair<Node, Node> pair = std::make_pair(element, skolem);
   if (std::find(d_bagElements[bag].begin(), d_bagElements[bag].end(), pair)
       == d_bagElements[bag].end())
   {
     d_bagElements[bag].push_back(pair);
   }
-  return count.eqNode(skolem);
 }
 
-Node SolverState::registerCardinalityTerm(TNode n)
+void SolverState::registerCardinalityTerm(Node n, Node skolem)
 {
   Assert(n.getKind() == BAG_CARD);
-  if (!d_ee->hasTerm(n))
-  {
-    d_ee->addTerm(n);
-  }
-  Node bag = getRepresentative(n[0]);
-  Node cardTerm = d_nm->mkNode(BAG_CARD, bag);
-  Node skolem = d_nm->getSkolemManager()->mkPurifySkolem(cardTerm, "bag.card");
-  d_cardTerms[cardTerm] = skolem;
-  return cardTerm.eqNode(skolem).andNode(skolem.eqNode(n));
+  Assert(skolem.isVar());
+  d_cardTerms[n] = skolem;
 }
 
-Node SolverState::getCardinalitySkolem(TNode n)
+Node SolverState::getCardinalitySkolem(Node n)
 {
   Assert(n.getKind() == BAG_CARD);
   Node bag = getRepresentative(n[0]);
@@ -110,76 +97,10 @@ const std::vector<std::pair<Node, Node>>& SolverState::getElementCountPairs(
   return d_bagElements[bag];
 }
 
-const std::set<Node>& SolverState::getDisequalBagTerms() { return d_deq; }
-
-void SolverState::reset()
+struct BagsDeqAttributeId
 {
-  d_bagElements.clear();
-  d_bags.clear();
-  d_deq.clear();
-  d_cardTerms.clear();
-}
-
-std::vector<Node> SolverState::initialize()
-{
-  reset();
-  collectDisequalBagTerms();
-  return collectBagsAndCountTerms();
-}
-
-std::vector<Node> SolverState::collectBagsAndCountTerms()
-{
-  std::vector<Node> lemmas;
-
-  eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee);
-  while (!repIt.isFinished())
-  {
-    Node eqc = (*repIt);
-    Trace("bags-eqc") << "(eqc " << eqc << std::endl << "";
-
-    if (eqc.getType().isBag())
-    {
-      registerBag(eqc);
-    }
-
-    eq::EqClassIterator it = eq::EqClassIterator(eqc, d_ee);
-    while (!it.isFinished())
-    {
-      Node n = (*it);
-      Trace("bags-eqc") << (*it) << " ";
-      Kind k = n.getKind();
-      if (k == BAG_MAKE)
-      {
-        // for terms (bag x c) we need to store x by registering the count term
-        // (bag.count x (bag x c))
-        Node count = d_nm->mkNode(BAG_COUNT, n[0], n);
-        Node lemma = registerCountTerm(count);
-        lemmas.push_back(lemma);
-        Trace("SolverState::collectBagsAndCountTerms")
-            << "registered " << count << endl;
-      }
-      if (k == BAG_COUNT)
-      {
-        // this takes care of all count terms in each equivalent class
-        Node lemma = registerCountTerm(n);
-        lemmas.push_back(lemma);
-        Trace("SolverState::collectBagsAndCountTerms")
-            << "registered " << n << endl;
-      }
-      if (k == BAG_CARD)
-      {
-        Node lemma = registerCardinalityTerm(n);
-        lemmas.push_back(lemma);
-      }
-      ++it;
-    }
-    Trace("bags-eqc") << std::endl << " ) " << std::endl;
-    ++repIt;
-  }
-
-  Trace("bags-eqc") << "(bagRepresentatives " << d_bags << ")" << std::endl;
-  return lemmas;
-}
+};
+typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
 
 void SolverState::collectDisequalBagTerms()
 {
@@ -189,13 +110,38 @@ void SolverState::collectDisequalBagTerms()
     Node n = (*it);
     if (n.getKind() == EQUAL && n[0].getType().isBag())
     {
-      Trace("bags-eqc") << "(disequalTerms " << n << " )" << std::endl;
-      d_deq.insert(n);
+      Trace("bags-eqc") << "Disequal terms: " << n << std::endl;
+      Node A = getRepresentative(n[0]);
+      Node B = getRepresentative(n[1]);
+      Node equal = A <= B ? A.eqNode(B) : B.eqNode(A);
+      if (d_deq.find(equal) == d_deq.end())
+      {
+        TypeNode elementType = A.getType().getBagElementType();
+        BoundVarManager* bvm = d_nm->getBoundVarManager();
+        Node element = bvm->mkBoundVar<BagsDeqAttribute>(equal, elementType);
+        SkolemManager* sm = d_nm->getSkolemManager();
+        Node skolem =
+            sm->mkSkolem(element,
+                         n,
+                         "bag_disequal",
+                         "an extensional lemma for disequality of two bags");
+        d_deq[equal] = skolem;
+      }
     }
     ++it;
   }
 }
 
+const std::map<Node, Node>& SolverState::getDisequalBagTerms() { return d_deq; }
+
+void SolverState::reset()
+{
+  d_bagElements.clear();
+  d_bags.clear();
+  d_deq.clear();
+  d_cardTerms.clear();
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace cvc5
index 4e3997793f1b16553c3ee2b2e4901e3844126eac..d59e4e6bd45cc8c7de2175b0e5657f024921a0db 100644 (file)
@@ -40,25 +40,25 @@ class SolverState : public TheoryState
   void registerBag(TNode n);
 
   /**
-   * @param n has the form (bag.count e A)
-   * @pre bag A is already registered using registerBag(A)
-   * @return a lemma (= skolem (bag.count eRep ARep)) where
-   * eRep, ARep are representatives of e, A respectively
+   * register the pair <element, skolem> with the given bag
+   * @param bag a representative of type (Bag E)
+   * @param element a representative of type E
+   * @param skolem an integer variable
+   * @pre (= (bag.count element bag) skolem)
    */
-  Node registerCountTerm(TNode n);
+  void registerCountTerm(Node bag, Node element, Node skolem);
 
   /**
-   * This function generates a skolem variable for the given card term and
-   * stores both of them in a cache.
-   * @param n has the form (bag.card A)
-   * @return a lemma that the card term equals the skolem variable
+   * store cardinality term and its skolem in a cahce
+   * @param n has the form (bag.card A) where A is a representative
+   * @param skolem for n
    */
-  Node registerCardinalityTerm(TNode n);
+  void registerCardinalityTerm(Node n, Node skolem);
 
   /**
    * @param n has the form (bag.card A)
    */
-  Node getCardinalitySkolem(TNode n);
+  Node getCardinalitySkolem(Node n);
 
   bool hasCardinalityTerms() const;
 
@@ -84,30 +84,24 @@ class SolverState : public TheoryState
    */
   std::set<Node> getElements(Node B);
   /**
-   * initialize bag and count terms
-   * @return a list of skolem lemmas to be asserted
-   * */
-  std::vector<Node> initialize();
-  /** return disequal bag terms */
-  const std::set<Node>& getDisequalBagTerms();
+   * return disequal bag terms where keys are equality nodes and values are
+   * skolems that witness the negation of these equalities
+   */
+  const std::map<Node, Node>& getDisequalBagTerms();
   /**
    * return a list of bag elements and their skolem counts
    */
   const std::vector<std::pair<Node, Node>>& getElementCountPairs(Node n);
 
- private:
   /** clear all bags data structures */
   void reset();
-  /**
-   * collect bags' representatives and all count terms.
-   * This function is called during postCheck
-   * @return a list of skolem lemmas to be asserted
-   */
-  std::vector<Node> collectBagsAndCountTerms();
+
   /**
    * collect disequal bag terms. This function is called during postCheck.
    */
   void collectDisequalBagTerms();
+
+ private:
   /** constants */
   Node d_true;
   Node d_false;
@@ -121,8 +115,12 @@ class SolverState : public TheoryState
    * This map is cleared and initialized at the start of each full effort check.
    */
   std::map<Node, std::vector<std::pair<Node, Node>>> d_bagElements;
-  /** Disequal bag terms */
-  std::set<Node> d_deq;
+  /**
+   * A map from equalities between bag terms to elements that witness their
+   * disequalities. This map is cleared and initialized at the start of each
+   * full effort check.
+   */
+  std::map<Node, Node> d_deq;
   /** a map from card terms to their skolem variables */
   std::map<Node, Node> d_cardTerms;
 }; /* class SolverState */
index 39b5a7e3cb74ec34380725d106ccf054ad42e929..3db8637d9b53fee8271f93f2751525e1675a8b4c 100644 (file)
@@ -139,6 +139,57 @@ TrustNode TheoryBags::expandChooseOperator(const Node& node,
   return TrustNode::mkTrustRewrite(node, x, nullptr);
 }
 
+void TheoryBags::initialize()
+{
+  d_state.reset();
+  d_state.collectDisequalBagTerms();
+  collectBagsAndCountTerms();
+}
+
+void TheoryBags::collectBagsAndCountTerms()
+{
+  eq::EqualityEngine* ee = d_state.getEqualityEngine();
+  eq::EqClassesIterator repIt = eq::EqClassesIterator(ee);
+  while (!repIt.isFinished())
+  {
+    Node eqc = (*repIt);
+    Trace("bags-eqc") << "Eqc [ " << eqc << " ] = { ";
+
+    if (eqc.getType().isBag())
+    {
+      d_state.registerBag(eqc);
+    }
+
+    eq::EqClassIterator it = eq::EqClassIterator(eqc, ee);
+    while (!it.isFinished())
+    {
+      Node n = (*it);
+      Trace("bags-eqc") << (*it) << " ";
+      Kind k = n.getKind();
+      if (k == BAG_MAKE)
+      {
+        // for terms (bag x c) we need to store x by registering the count term
+        // (bag.count x (bag x c))
+        NodeManager* nm = NodeManager::currentNM();
+        Node count = nm->mkNode(BAG_COUNT, n[0], n);
+        d_ig.registerCountTerm(count);
+      }
+      if (k == BAG_COUNT)
+      {
+        // this takes care of all count terms in each equivalent class
+        d_ig.registerCountTerm(n);
+      }
+      if (k == BAG_CARD)
+      {
+        d_ig.registerCardinalityTerm(n);
+      }
+      ++it;
+    }
+    Trace("bags-eqc") << " } " << std::endl;
+    ++repIt;
+  }
+}
+
 void TheoryBags::postCheck(Effort effort)
 {
   d_im.doPendingFacts();
@@ -157,12 +208,8 @@ void TheoryBags::postCheck(Effort effort)
       d_im.reset();
       // TODO issue #78: add ++(d_statistics.d_strategyRuns);
       Trace("bags-check") << "  * Run strategy..." << std::endl;
-      std::vector<Node> lemmas = d_state.initialize();
+      initialize();
       d_cardSolver.reset();
-      for (Node lemma : lemmas)
-      {
-        d_im.lemma(lemma, InferenceId::BAGS_COUNT_SKOLEM);
-      }
       runStrategy(effort);
 
       // remember if we had pending facts or lemmas
@@ -275,7 +322,8 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
 
   Trace("bags-model") << "Term set: " << termSet << std::endl;
 
-  std::set<Node> processedBags;
+  // a map from bag representatives to their constructed values
+  std::map<Node, Node> processedBags;
 
   // get the relevant bag equivalence classes
   for (const Node& n : termSet)
@@ -287,14 +335,13 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
       continue;
     }
     Node r = d_state.getRepresentative(n);
+
     if (processedBags.find(r) != processedBags.end())
     {
       // skip bags whose representatives are already processed
       continue;
     }
 
-    processedBags.insert(r);
-
     const std::vector<std::pair<Node, Node>>& solverElements =
         d_state.getElementCountPairs(r);
     std::vector<std::pair<Node, Node>> elements;
@@ -317,8 +364,6 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
     }
     Node constructedBag = BagsUtils::constructBagFromElements(tn, elementReps);
     constructedBag = rewrite(constructedBag);
-    Trace("bags-model") << "constructed bag for " << n
-                        << " is: " << constructedBag << std::endl;
     NodeManager* nm = NodeManager::currentNM();
     if (d_state.hasCardinalityTerms())
     {
@@ -355,8 +400,6 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
             constructedBag =
                 nm->mkNode(kind::BAG_UNION_DISJOINT, constructedBag, slackBag);
             constructedBag = rewrite(constructedBag);
-            Trace("bags-model") << "constructed bag for " << n
-                                << " is: " << constructedBag << std::endl;
           }
         }
       }
@@ -379,7 +422,10 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
     }
     m->assertEquality(constructedBag, n, true);
     m->assertSkeleton(constructedBag);
+    processedBags[r] = constructedBag;
   }
+
+  Trace("bags-model") << "processedBags:  " << processedBags << std::endl;
   return true;
 }
 
index 18e306e9593b22abb6bff75fa3305b56311cf06b..d0d2a83eedf814b497f94ea0546f5dfe35082294 100644 (file)
@@ -59,6 +59,15 @@ class TheoryBags : public Theory
   TrustNode ppRewrite(TNode atom, std::vector<SkolemLemma>& lems) override;
   //--------------------------------- end initialization
 
+  /**
+   * initialize bag and count terms
+   */
+  void initialize();
+  /**
+   * collect bags' representatives and all count terms.
+   */
+  void collectBagsAndCountTerms();
+
   //--------------------------------- standard check
   /** Post-check, called after the fact queue of the theory is processed. */
   void postCheck(Effort effort) override;
index bca36e196edfb9698b677deea8b538c1b9220120..b3b5d6c97f9a3bdcfb338726fbe988e6a13046a6 100644 (file)
@@ -118,7 +118,7 @@ const char* toString(InferenceId i)
     case InferenceId::BAGS_NON_NEGATIVE_COUNT: return "BAGS_NON_NEGATIVE_COUNT";
     case InferenceId::BAGS_BAG_MAKE: return "BAGS_BAG_MAKE";
     case InferenceId::BAGS_BAG_MAKE_SPLIT: return "BAGS_BAG_MAKE_SPLIT";
-    case InferenceId::BAGS_COUNT_SKOLEM: return "BAGS_COUNT_SKOLEM";
+    case InferenceId::BAGS_SKOLEM: return "BAGS_SKOLEM";
     case InferenceId::BAGS_EQUALITY: return "BAGS_EQUALITY";
     case InferenceId::BAGS_DISEQUALITY: return "BAGS_DISEQUALITY";
     case InferenceId::BAGS_EMPTY: return "BAGS_EMPTY";
@@ -135,6 +135,7 @@ const char* toString(InferenceId i)
     case InferenceId::BAGS_FILTER_UP: return "BAGS_FILTER_UP";
     case InferenceId::BAGS_FOLD: return "BAGS_FOLD";
     case InferenceId::BAGS_CARD: return "BAGS_CARD";
+    case InferenceId::BAGS_CARD_EMPTY: return "BAGS_CARD_EMPTY";
     case InferenceId::TABLES_PRODUCT_UP: return "TABLES_PRODUCT_UP";
     case InferenceId::TABLES_PRODUCT_DOWN: return "TABLES_PRODUCT_DOWN";
 
index ddfbdb665b904262f4148bdcc8522ebcac31e49f..1beeaf04d360d9036c2b3ed56bcea4815795a62e 100644 (file)
@@ -185,7 +185,7 @@ enum class InferenceId
   BAGS_NON_NEGATIVE_COUNT,
   BAGS_BAG_MAKE,
   BAGS_BAG_MAKE_SPLIT,
-  BAGS_COUNT_SKOLEM,
+  BAGS_SKOLEM,
   BAGS_EQUALITY,
   BAGS_DISEQUALITY,
   BAGS_EMPTY,
@@ -201,6 +201,7 @@ enum class InferenceId
   BAGS_FILTER_UP,
   BAGS_FOLD,
   BAGS_CARD,
+  BAGS_CARD_EMPTY,
   TABLES_PRODUCT_UP,
   TABLES_PRODUCT_DOWN,
   // ---------------------------------- end bags theory
index 08ad20e01ab6a87d590fa1ee769d662cf0e1a6e6..7bf95028f43579bfd5af4c74a7a9233755fbb3ef 100644 (file)
@@ -27,10 +27,16 @@ TheoryState::TheoryState(Env& env, Valuation val)
 
 void TheoryState::setEqualityEngine(eq::EqualityEngine* ee) { d_ee = ee; }
 
-bool TheoryState::hasTerm(TNode a) const
+bool TheoryState::hasTerm(TNode t) const
 {
   Assert(d_ee != nullptr);
-  return d_ee->hasTerm(a);
+  return d_ee->hasTerm(t);
+}
+
+void TheoryState::addTerm(TNode t)
+{
+  Assert(d_ee != nullptr);
+  d_ee->addTerm(t);
 }
 
 TNode TheoryState::getRepresentative(TNode t) const
index 4d2c29e1bd26e32497d86619c98f110336fb405f..99d3f6cadbb9c832bb35a4c4b774dc86e6542e30 100644 (file)
@@ -44,7 +44,9 @@ class TheoryState : protected EnvObj
   void setEqualityEngine(eq::EqualityEngine* ee);
   //-------------------------------------- equality information
   /** Is t registered as a term in the equality engine of this class? */
-  virtual bool hasTerm(TNode a) const;
+  virtual bool hasTerm(TNode t) const;
+  /** Add term t to the equality engine if it is not registered */
+  virtual void addTerm(TNode t);
   /**
    * Get the representative of t in the equality engine of this class, or t
    * itself if it is not registered as a term.
index 398c08cbbc5636f328d5b0de8497727fa60c84db..61fd9a13e015c47a1821947c4344e18428c066b8 100644 (file)
@@ -1761,6 +1761,7 @@ set(regress_1_tests
   regress1/bags/murxla2.smt2
   regress1/bags/murxla3.smt2
   regress1/bags/murxla4.smt2
+  regress1/bags/murxla5.smt2
   regress1/bags/product1.smt2
   regress1/bags/product2.smt2
   regress1/bags/product3.smt2
diff --git a/test/regress/regress1/bags/murxla5.smt2 b/test/regress/regress1/bags/murxla5.smt2
new file mode 100644 (file)
index 0000000..1d9bebf
--- /dev/null
@@ -0,0 +1,5 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-const x Bool)
+(declare-const x9 (Bag String))
+(check-sat-assuming (((_ divisible 861286585) (bag.card (ite false (bag.difference_remove x9 x9) x9))) (or (bag.subbag x9 (bag.inter_min x9 (ite x x9 (bag.difference_remove x9 x9)))))))
index 6f57cab991f28a4909e7544680527722c1c2fedf..9f46a3f46fe9468fde6945458e896fc80f3a1144 100644 (file)
@@ -678,15 +678,6 @@ TEST_F(TestTheoryWhiteBagsRewriter, bag_card)
   RewriteResponse response2 = d_rewriter->postRewrite(n2);
   ASSERT_TRUE(response2.d_node == c
               && response2.d_status == REWRITE_AGAIN_FULL);
-
-  // (bag.card (bag.union_disjoint A B)) = (+ (bag.card A) (bag.card B))
-  Node n3 = d_nodeManager->mkNode(BAG_CARD, unionDisjointAB);
-  Node cardA = d_nodeManager->mkNode(BAG_CARD, A);
-  Node cardB = d_nodeManager->mkNode(BAG_CARD, B);
-  Node plus = d_nodeManager->mkNode(ADD, cardA, cardB);
-  RewriteResponse response3 = d_rewriter->postRewrite(n3);
-  ASSERT_TRUE(response3.d_node == plus
-              && response3.d_status == REWRITE_AGAIN_FULL);
 }
 
 TEST_F(TestTheoryWhiteBagsRewriter, is_singleton)