Refactor bags::SolverState (#5783)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Mon, 25 Jan 2021 20:38:45 +0000 (14:38 -0600)
committerGitHub <noreply@github.com>
Mon, 25 Jan 2021 20:38:45 +0000 (14:38 -0600)
Couple of changes:

SolverState now keep tracks of elements per bag instead of per type.
bags::InferInfo now stores multiple conclusions (conjuncts).
BagSolver applies upward/downward closures for bag elements

14 files changed:
src/theory/bags/bag_solver.cpp
src/theory/bags/bag_solver.h
src/theory/bags/bags_rewriter.h
src/theory/bags/infer_info.cpp
src/theory/bags/infer_info.h
src/theory/bags/inference_generator.cpp
src/theory/bags/inference_generator.h
src/theory/bags/inference_manager.h
src/theory/bags/solver_state.cpp
src/theory/bags/solver_state.h
src/theory/bags/theory_bags.cpp
test/regress/CMakeLists.txt
test/regress/regress1/bags/difference_remove1.smt2 [new file with mode: 0644]
test/regress/regress1/bags/issue5759.smt2 [new file with mode: 0644]

index 5621a7c1c073b0b742c8a56f90abec8e5f30687b..495f73723d81b6325a8edc155e5f8038df9e6e5b 100644 (file)
@@ -39,25 +39,63 @@ BagSolver::~BagSolver() {}
 
 void BagSolver::postCheck()
 {
+  d_state.initialize();
+
+  // At this point, all bag and count representatives should be in the solver
+  // state.
+  for (const Node& bag : d_state.getBags())
+  {
+    // iterate through all bags terms in each equivalent class
+    eq::EqClassIterator it =
+        eq::EqClassIterator(bag, d_state.getEqualityEngine());
+    while (!it.isFinished())
+    {
+      Node n = (*it);
+      Kind k = n.getKind();
+      switch (k)
+      {
+        case kind::MK_BAG: checkMkBag(n); break;
+        case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
+        case kind::UNION_MAX: checkUnionMax(n); break;
+        case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
+        case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
+        default: break;
+      }
+      it++;
+    }
+  }
+
+  // add non negative constraints for all multiplicities
   for (const Node& n : d_state.getBags())
   {
-    Kind k = n.getKind();
-    switch (k)
+    for (const Node& e : d_state.getElements(n))
     {
-      case kind::MK_BAG: checkMkBag(n); break;
-      case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
-      case kind::UNION_MAX: checkUnionMax(n); break;
-      case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
-      default: break;
+      checkNonNegativeCountTerms(n, e);
     }
   }
 }
 
+set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
+{
+  set<Node> elements;
+  const set<Node>& downwards = d_state.getElements(n);
+  const set<Node>& upwards0 = d_state.getElements(n[0]);
+  const set<Node>& upwards1 = d_state.getElements(n[1]);
+
+  set_union(downwards.begin(),
+            downwards.end(),
+            upwards0.begin(),
+            upwards0.end(),
+            inserter(elements, elements.begin()));
+  elements.insert(upwards1.begin(), upwards1.end());
+  return elements;
+}
+
 void BagSolver::checkUnionDisjoint(const Node& n)
 {
   Assert(n.getKind() == UNION_DISJOINT);
-  TypeNode elementType = n.getType().getBagElementType();
-  for (const Node& e : d_state.getElements(elementType))
+  std::set<Node> elements = getElementsForBinaryOperator(n);
+  for (const Node& e : elements)
   {
     InferenceGenerator ig(&d_state);
     InferInfo i = ig.unionDisjoint(n, e);
@@ -69,8 +107,8 @@ void BagSolver::checkUnionDisjoint(const Node& n)
 void BagSolver::checkUnionMax(const Node& n)
 {
   Assert(n.getKind() == UNION_MAX);
-  TypeNode elementType = n.getType().getBagElementType();
-  for (const Node& e : d_state.getElements(elementType))
+  std::set<Node> elements = getElementsForBinaryOperator(n);
+  for (const Node& e : elements)
   {
     InferenceGenerator ig(&d_state);
     InferInfo i = ig.unionMax(n, e);
@@ -82,8 +120,8 @@ void BagSolver::checkUnionMax(const Node& n)
 void BagSolver::checkDifferenceSubtract(const Node& n)
 {
   Assert(n.getKind() == DIFFERENCE_SUBTRACT);
-  TypeNode elementType = n.getType().getBagElementType();
-  for (const Node& e : d_state.getElements(elementType))
+  std::set<Node> elements = getElementsForBinaryOperator(n);
+  for (const Node& e : elements)
   {
     InferenceGenerator ig(&d_state);
     InferInfo i = ig.differenceSubtract(n, e);
@@ -91,11 +129,14 @@ void BagSolver::checkDifferenceSubtract(const Node& n)
     Trace("bags::BagSolver::postCheck") << i << endl;
   }
 }
+
 void BagSolver::checkMkBag(const Node& n)
 {
   Assert(n.getKind() == MK_BAG);
-  TypeNode elementType = n.getType().getBagElementType();
-  for (const Node& e : d_state.getElements(elementType))
+  Trace("bags::BagSolver::postCheck")
+      << "BagSolver::checkMkBag Elements of " << n
+      << " are: " << d_state.getElements(n) << std::endl;
+  for (const Node& e : d_state.getElements(n))
   {
     InferenceGenerator ig(&d_state);
     InferInfo i = ig.mkBag(n, e);
@@ -103,6 +144,26 @@ void BagSolver::checkMkBag(const Node& n)
     Trace("bags::BagSolver::postCheck") << i << endl;
   }
 }
+void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
+{
+  InferenceGenerator ig(&d_state);
+  InferInfo i = ig.nonNegativeCount(bag, element);
+  i.process(&d_im, true);
+  Trace("bags::BagSolver::postCheck") << i << endl;
+}
+
+void BagSolver::checkDifferenceRemove(const Node& n)
+{
+  Assert(n.getKind() == DIFFERENCE_REMOVE);
+  std::set<Node> elements = getElementsForBinaryOperator(n);
+  for (const Node& e : elements)
+  {
+    InferenceGenerator ig(&d_state);
+    InferInfo i = ig.differenceRemove(n, e);
+    i.process(&d_im, true);
+    Trace("bags::BagSolver::postCheck") << i << endl;
+  }
+}
 
 }  // namespace bags
 }  // namespace theory
index 48583d134f59dd5faab0a4db1d5bb4b016536700..b4b18c00cab16f27edc2e89ab1b3a6d5641de427 100644 (file)
@@ -41,14 +41,31 @@ class BagSolver
   void postCheck();
 
  private:
-  /** apply inference rules for MK_BAG operator */
+  /**
+   * apply inference rules for MK_BAG operator.
+   * Example: Suppose n = (bag x c), and we have two count terms (bag.count x n)
+   * and (bag.count y n).
+   * This function will add inferences for the count terms as documented in
+   * InferenceGenerator::mkBag.
+   * Note that element y may not be in bag n. See the documentation of
+   * SolverState::getElements.
+   */
   void checkMkBag(const Node& n);
+  /**
+   * @param n is a bag that has the form (op A B)
+   * @return the set union of known elements in (op A B) , A, and B.
+   */
+  std::set<Node> getElementsForBinaryOperator(const Node& n);
   /** apply inference rules for union disjoint */
   void checkUnionDisjoint(const Node& n);
   /** apply inference rules for union max */
   void checkUnionMax(const Node& n);
   /** apply inference rules for difference subtract */
   void checkDifferenceSubtract(const Node& n);
+  /** apply inference rules for difference remove */
+  void checkDifferenceRemove(const Node& n);
+  /** apply non negative constraints for multiplicities */
+  void checkNonNegativeCountTerms(const Node& bag, const Node& element);
 
   /** The solver state object */
   SolverState& d_state;
index fb76fb1c22a99557fdaac9beac69decd5941df71..48cd9c419b6f2172cb50dd1d83523dc7288e2a89 100644 (file)
@@ -70,8 +70,8 @@ class BagsRewriter : public TheoryRewriter
 
   /**
    * rewrites for n include:
-   * - (mkBag x 0) = (emptybag T) where T is the type of x
-   * - (mkBag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a
+   * - (bag x 0) = (emptybag T) where T is the type of x
+   * - (bag x (-c)) = (emptybag T) where T is the type of x, and c > 0 is a
    *   constant
    * - otherwise = n
    */
@@ -87,7 +87,7 @@ class BagsRewriter : public TheoryRewriter
 
   /**
    *  rewrites for n include:
-   *  - (duplicate_removal (mkBag x n)) = (mkBag x 1)
+   *  - (duplicate_removal (bag x n)) = (bag x 1)
    *     where n is a positive constant
    */
   BagsRewriteResponse rewriteDuplicateRemoval(const TNode& n) const;
@@ -171,13 +171,13 @@ class BagsRewriter : public TheoryRewriter
   BagsRewriteResponse rewriteDifferenceRemove(const TNode& n) const;
   /**
    * rewrites for n include:
-   * - (bag.choose (mkBag x c)) = x where c is a constant > 0
+   * - (bag.choose (bag x c)) = x where c is a constant > 0
    * - otherwise = n
    */
   BagsRewriteResponse rewriteChoose(const TNode& n) const;
   /**
    * rewrites for n include:
-   * - (bag.card (mkBag x c)) = c where c is a constant > 0
+   * - (bag.card (bag x c)) = c where c is a constant > 0
    * - (bag.card (union-disjoint A B)) = (+ (bag.card A) (bag.card B))
    * - otherwise = n
    */
@@ -185,19 +185,19 @@ class BagsRewriter : public TheoryRewriter
 
   /**
    * rewrites for n include:
-   * - (bag.is_singleton (mkBag x c)) = (c == 1)
+   * - (bag.is_singleton (bag x c)) = (c == 1)
    */
   BagsRewriteResponse rewriteIsSingleton(const TNode& n) const;
 
   /**
    *  rewrites for n include:
-   *  - (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
+   *  - (bag.from_set (singleton (singleton_op Int) x)) = (bag x 1)
    */
   BagsRewriteResponse rewriteFromSet(const TNode& n) const;
 
   /**
    *  rewrites for n include:
-   *  - (bag.to_set (mkBag x n)) = (singleton (singleton_op T) x)
+   *  - (bag.to_set (bag x n)) = (singleton (singleton_op T) x)
    *     where n is a positive constant and T is the type of the bag's elements
    */
   BagsRewriteResponse rewriteToSet(const TNode& n) const;
index 1244a43ac08d8935b264dac78859ac2c22f4059b..5b3274617307b2c6da8f9077b9a37ab66cd764a5 100644 (file)
@@ -25,6 +25,8 @@ const char* toString(Inference i)
   switch (i)
   {
     case Inference::NONE: return "NONE";
+    case Inference::BAG_NON_NEGATIVE_COUNT: return "BAG_NON_NEGATIVE_COUNT";
+    case Inference::BAG_MK_BAG_SAME_ELEMENT: return "BAG_MK_BAG_SAME_ELEMENT";
     case Inference::BAG_MK_BAG: return "BAG_MK_BAG";
     case Inference::BAG_EQUALITY: return "BAG_EQUALITY";
     case Inference::BAG_DISEQUALITY: return "BAG_DISEQUALITY";
@@ -62,9 +64,19 @@ bool InferInfo::process(TheoryInferenceManager* im, bool asLemma)
   if (asLemma)
   {
     TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
-    return im->trustedLemma(trustedLemma);
+    im->trustedLemma(trustedLemma);
   }
-  Unimplemented();
+  else
+  {
+    Unimplemented();
+  }
+  for (const auto& pair : d_skolems)
+  {
+    Node n = pair.first.eqNode(pair.second);
+    TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr);
+    im->trustedLemma(trustedLemma);
+  }
+  return true;
 }
 
 bool InferInfo::isTrivial() const
@@ -87,21 +99,15 @@ bool InferInfo::isFact() const
   return !atom.isConst() && atom.getKind() != kind::OR;
 }
 
-Node InferInfo::getPremises() const
-{
-  // d_noExplain is a subset of d_ant
-  NodeManager* nm = NodeManager::currentNM();
-  return nm->mkAnd(d_premises);
-}
-
 std::ostream& operator<<(std::ostream& out, const InferInfo& ii)
 {
-  out << "(infer " << ii.d_id << " " << ii.d_conclusion << std::endl;
+  out << "(infer :id " << ii.d_id << std::endl;
+  out << ":conclusion " << ii.d_conclusion << std::endl;
   if (!ii.d_premises.empty())
   {
     out << " :premise (" << ii.d_premises << ")" << std::endl;
   }
-
+  out << ":skolems " << ii.d_skolems << std::endl;
   out << ")";
   return out;
 }
index 3edbef737ffa2f08b7cf097c1d9108b8930bf8d2..ecfc354d11400f8a58de33badbd1bb27b755e0f4 100644 (file)
@@ -33,6 +33,8 @@ namespace bags {
 enum class Inference : uint32_t
 {
   NONE,
+  BAG_NON_NEGATIVE_COUNT,
+  BAG_MK_BAG_SAME_ELEMENT,
   BAG_MK_BAG,
   BAG_EQUALITY,
   BAG_DISEQUALITY,
@@ -81,7 +83,7 @@ class InferInfo : public TheoryInference
   bool process(TheoryInferenceManager* im, bool asLemma) override;
   /** The inference identifier */
   Inference d_id;
-  /** The conclusion */
+  /** The conclusions */
   Node d_conclusion;
   /**
    * The premise(s) of the inference, interpreted conjunctively. These are
@@ -90,11 +92,10 @@ class InferInfo : public TheoryInference
   std::vector<Node> d_premises;
 
   /**
-   * A list of new skolems introduced as a result of this inference. They
-   * are mapped to by a length status, indicating the length constraint that
-   * can be assumed for them.
+   * A map of nodes to their skolem variables introduced as a result of this
+   * inference.
    */
-  std::vector<Node> d_newSkolem;
+  std::map<Node, Node> d_skolems;
   /**  Is this infer info trivial? True if d_conc is true. */
   bool isTrivial() const;
   /**
@@ -108,8 +109,6 @@ class InferInfo : public TheoryInference
    * engine with no new external premises (d_noExplain).
    */
   bool isFact() const;
-  /** Get premises */
-  Node getPremises() const;
 };
 
 /**
index 759ea1f0c470200bb7e3fd6356b7eaaa572e8198..7ef126911cf7842a247527c52ef24a9f5055a6c1 100644 (file)
@@ -32,18 +32,33 @@ InferenceGenerator::InferenceGenerator(SolverState* state) : d_state(state)
   d_one = d_nm->mkConst(Rational(1));
 }
 
+InferInfo InferenceGenerator::nonNegativeCount(Node n, Node e)
+{
+  Assert(n.getType().isBag());
+  Assert(e.getType() == n.getType().getBagElementType());
+
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_NON_NEGATIVE_COUNT;
+  Node count = d_nm->mkNode(kind::BAG_COUNT, e, n);
+
+  Node gte = d_nm->mkNode(kind::GEQ, count, d_zero);
+  inferInfo.d_conclusion = gte;
+  return inferInfo;
+}
+
 InferInfo InferenceGenerator::mkBag(Node n, Node e)
 {
   Assert(n.getKind() == kind::MK_BAG);
   Assert(e.getType() == n.getType().getBagElementType());
 
   InferInfo inferInfo;
-  inferInfo.d_id = Inference::BAG_MK_BAG;
-  Node count = getMultiplicitySkolem(e, n, inferInfo);
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(e, skolem);
   if (n[0] == e)
   {
-    // TODO: refactor this with the rewriter
+    // TODO issue #78: refactor this with BagRewriter
     // (=> true (= (bag.count e (bag e c)) c))
+    inferInfo.d_id = Inference::BAG_MK_BAG_SAME_ELEMENT;
     inferInfo.d_conclusion = count.eqNode(n[1]);
   }
   else
@@ -51,7 +66,7 @@ InferInfo InferenceGenerator::mkBag(Node n, Node e)
     // (=>
     //   true
     //   (= (bag.count e (bag x c)) (ite (= e x) c 0)))
-
+    inferInfo.d_id = Inference::BAG_MK_BAG;
     Node same = d_nm->mkNode(kind::EQUAL, n[0], e);
     Node ite = d_nm->mkNode(kind::ITE, same, n[1], d_zero);
     Node equal = count.eqNode(ite);
@@ -60,30 +75,12 @@ InferInfo InferenceGenerator::mkBag(Node n, Node e)
   return inferInfo;
 }
 
-InferInfo InferenceGenerator::bagEquality(Node n, Node e)
-{
-  Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag());
-  Assert(e.getType() == n[0].getType().getBagElementType());
-
-  Node A = n[0];
-  Node B = n[1];
-  InferInfo inferInfo;
-  inferInfo.d_id = Inference::BAG_EQUALITY;
-  inferInfo.d_premises.push_back(n);
-  Node countA = getMultiplicitySkolem(e, A, inferInfo);
-  Node countB = getMultiplicitySkolem(e, B, inferInfo);
-
-  Node equal = countA.eqNode(countB);
-  inferInfo.d_conclusion = equal;
-  return inferInfo;
-}
-
 struct BagsDeqAttributeId
 {
 };
 typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
 
-InferInfo InferenceGenerator::bagDisequality(Node n)
+InferInfo InferenceGenerator::bagDisequality(Node n, Node reason)
 {
   Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL);
   Assert(n[0][0].getType().isBag());
@@ -93,22 +90,19 @@ InferInfo InferenceGenerator::bagDisequality(Node n)
 
   InferInfo inferInfo;
   inferInfo.d_id = Inference::BAG_DISEQUALITY;
+  inferInfo.d_premises.push_back(reason);
 
   TypeNode elementType = A.getType().getBagElementType();
-
   BoundVarManager* bvm = d_nm->getBoundVarManager();
   Node element = bvm->mkBoundVar<BagsDeqAttribute>(n, elementType);
-  SkolemManager* sm = d_nm->getSkolemManager();
   Node skolem =
-      sm->mkSkolem(element,
-                   n,
-                   "bag_disequal",
-                   "an extensional lemma for disequality of two bags");
+      d_sm->mkSkolem(element,
+                     n,
+                     "bag_disequal",
+                     "an extensional lemma for disequality of two bags");
 
-  inferInfo.d_newSkolem.push_back(skolem);
-
-  Node countA = getMultiplicitySkolem(skolem, A, inferInfo);
-  Node countB = getMultiplicitySkolem(skolem, B, inferInfo);
+  Node countA = getMultiplicityTerm(skolem, A);
+  Node countB = getMultiplicityTerm(skolem, B);
 
   Node disEqual = countA.eqNode(countB).notNode();
 
@@ -117,13 +111,20 @@ InferInfo InferenceGenerator::bagDisequality(Node n)
   return inferInfo;
 }
 
+Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo)
+{
+  Node skolem = d_sm->mkPurifySkolem(n, "skolem_bag", "skolem bag");
+  inferInfo.d_skolems[n] = skolem;
+  return skolem;
+}
+
 InferInfo InferenceGenerator::bagEmpty(Node e)
 {
   EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType()));
   Node empty = d_nm->mkConst(emptyBag);
   InferInfo inferInfo;
   inferInfo.d_id = Inference::BAG_EMPTY;
-  Node count = getMultiplicitySkolem(e, empty, inferInfo);
+  Node count = getMultiplicityTerm(e, empty);
 
   Node equal = count.eqNode(d_zero);
   inferInfo.d_conclusion = equal;
@@ -140,9 +141,11 @@ InferInfo InferenceGenerator::unionDisjoint(Node n, Node e)
   InferInfo inferInfo;
   inferInfo.d_id = Inference::BAG_UNION_DISJOINT;
 
-  Node countA = getMultiplicitySkolem(e, A, inferInfo);
-  Node countB = getMultiplicitySkolem(e, B, inferInfo);
-  Node count = getMultiplicitySkolem(e, n, inferInfo);
+  Node countA = getMultiplicityTerm(e, A);
+  Node countB = getMultiplicityTerm(e, B);
+
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(e, skolem);
 
   Node sum = d_nm->mkNode(kind::PLUS, countA, countB);
   Node equal = count.eqNode(sum);
@@ -161,9 +164,11 @@ InferInfo InferenceGenerator::unionMax(Node n, Node e)
   InferInfo inferInfo;
   inferInfo.d_id = Inference::BAG_UNION_MAX;
 
-  Node countA = getMultiplicitySkolem(e, A, inferInfo);
-  Node countB = getMultiplicitySkolem(e, B, inferInfo);
-  Node count = getMultiplicitySkolem(e, n, inferInfo);
+  Node countA = getMultiplicityTerm(e, A);
+  Node countB = getMultiplicityTerm(e, B);
+
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(e, skolem);
 
   Node gt = d_nm->mkNode(kind::GT, countA, countB);
   Node max = d_nm->mkNode(kind::ITE, gt, countA, countB);
@@ -183,9 +188,10 @@ InferInfo InferenceGenerator::intersection(Node n, Node e)
   InferInfo inferInfo;
   inferInfo.d_id = Inference::BAG_INTERSECTION_MIN;
 
-  Node countA = getMultiplicitySkolem(e, A, inferInfo);
-  Node countB = getMultiplicitySkolem(e, B, inferInfo);
-  Node count = getMultiplicitySkolem(e, n, inferInfo);
+  Node countA = getMultiplicityTerm(e, A);
+  Node countB = getMultiplicityTerm(e, B);
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(e, skolem);
 
   Node lt = d_nm->mkNode(kind::LT, countA, countB);
   Node min = d_nm->mkNode(kind::ITE, lt, countA, countB);
@@ -204,9 +210,10 @@ InferInfo InferenceGenerator::differenceSubtract(Node n, Node e)
   InferInfo inferInfo;
   inferInfo.d_id = Inference::BAG_DIFFERENCE_SUBTRACT;
 
-  Node countA = getMultiplicitySkolem(e, A, inferInfo);
-  Node countB = getMultiplicitySkolem(e, B, inferInfo);
-  Node count = getMultiplicitySkolem(e, n, inferInfo);
+  Node countA = getMultiplicityTerm(e, A);
+  Node countB = getMultiplicityTerm(e, B);
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(e, skolem);
 
   Node subtract = d_nm->mkNode(kind::MINUS, countA, countB);
   Node gte = d_nm->mkNode(kind::GEQ, countA, countB);
@@ -226,9 +233,11 @@ InferInfo InferenceGenerator::differenceRemove(Node n, Node e)
   InferInfo inferInfo;
   inferInfo.d_id = Inference::BAG_DIFFERENCE_REMOVE;
 
-  Node countA = getMultiplicitySkolem(e, A, inferInfo);
-  Node countB = getMultiplicitySkolem(e, B, inferInfo);
-  Node count = getMultiplicitySkolem(e, n, inferInfo);
+  Node countA = getMultiplicityTerm(e, A);
+  Node countB = getMultiplicityTerm(e, B);
+
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(e, skolem);
 
   Node notInB = d_nm->mkNode(kind::EQUAL, countB, d_zero);
   Node difference = d_nm->mkNode(kind::ITE, notInB, countA, d_zero);
@@ -246,8 +255,9 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e)
   InferInfo inferInfo;
   inferInfo.d_id = Inference::BAG_DUPLICATE_REMOVAL;
 
-  Node countA = getMultiplicitySkolem(e, A, inferInfo);
-  Node count = getMultiplicitySkolem(e, n, inferInfo);
+  Node countA = getMultiplicityTerm(e, A);
+  Node skolem = getSkolem(n, inferInfo);
+  Node count = getMultiplicityTerm(e, skolem);
 
   Node gte = d_nm->mkNode(kind::GEQ, countA, d_one);
   Node ite = d_nm->mkNode(kind::ITE, gte, d_one, d_zero);
@@ -256,16 +266,10 @@ InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e)
   return inferInfo;
 }
 
-Node InferenceGenerator::getMultiplicitySkolem(Node element,
-                                               Node bag,
-                                               InferInfo& inferInfo)
+Node InferenceGenerator::getMultiplicityTerm(Node element, Node bag)
 {
   Node count = d_nm->mkNode(kind::BAG_COUNT, element, bag);
-  Node skolem = d_state->registerBagElement(count);
-  eq::EqualityEngine* ee = d_state->getEqualityEngine();
-  ee->assertEquality(skolem.eqNode(count), true, d_nm->mkConst(true));
-  inferInfo.d_newSkolem.push_back(skolem);
-  return skolem;
+  return count;
 }
 
 }  // namespace bags
index b56997088cbc9118cd79387ba67075a1dbca0f92..9eee46e43579ca6dcd9b32a8c15faf421e3d5ab8 100644 (file)
@@ -38,33 +38,38 @@ class InferenceGenerator
   InferenceGenerator(SolverState* state);
 
   /**
-   * @param n is (bag x c) of type (Bag E)
+   * @param A is a bag of type (Bag E)
    * @param e is a node of type E
    * @return an inference that represents the following implication
    * (=>
    *   true
-   *   (= (bag.count e (bag x c)) (ite (= e x) c 0)))
+   *   (>= (bag.count e A) 0)
    */
-  InferInfo mkBag(Node n, Node e);
+  InferInfo nonNegativeCount(Node n, Node e);
 
   /**
-   * @param n is (= A B) where A, B are bags of type (Bag E)
-   * @param e is a node of Type E
+   * @param n is (bag x c) of type (Bag E)
+   * @param e is a node of type E
    * @return an inference that represents the following implication
    * (=>
-   *   (= A B)
-   *   (= (count e A) (count e B)))
+   *   true
+   *   (= (bag.count e skolem) c))
+   *   if e is exactly node x. Node skolem is a fresh variable equals (bag x c).
+   *   Otherwise the following inference is returned
+   * (=>
+   *   true
+   *   (= (bag.count e skolem) (ite (= e x) c 0)))
    */
-  InferInfo bagEquality(Node n, Node e);
+  InferInfo mkBag(Node n, Node e);
   /**
    * @param n is (not (= A B)) where A, B are bags of type (Bag E)
    * @return an inference that represents the following implication
    * (=>
    *   (not (= A B))
    *   (not (= (count e A) (count e B))))
-   *   where e is a fresh skolem of type E
+   *   where e is a fresh skolem of type E.
    */
-  InferInfo bagDisequality(Node n);
+  InferInfo bagDisequality(Node n, Node reason);
   /**
    * @param e is a node of Type E
    * @return an inference that represents the following implication
@@ -79,10 +84,9 @@ class InferenceGenerator
    * @return an inference that represents the following implication
    * (=>
    *   true
-   *   (= (count e k_{(union_disjoint A B)})
+   *   (= (count e skolem)
    *      (+ (count e A) (count e B))))
-   *  where k_{(union_disjoint A B)} is a unique purification skolem
-   *  for (union_disjoint A B)
+   *  where skolem is a fresh variable equals (union_disjoint A B)
    */
   InferInfo unionDisjoint(Node n, Node e);
   /**
@@ -91,11 +95,13 @@ class InferenceGenerator
    * @return an inference that represents the following implication
    * (=>
    *   true
-   *   (= (count e (union_max A B))
+   *   (=
+   *     (count e skolem)
    *     (ite
-   *     (> (count e A) (count e B))
-   *     (count e A)
-   *     (count e B)))))
+   *       (> (count e A) (count e B))
+   *       (count e A)
+   *       (count e B)))))
+   * where skolem is a fresh variable equals (union_max A B)
    */
   InferInfo unionMax(Node n, Node e);
   /**
@@ -104,11 +110,13 @@ class InferenceGenerator
    * @return an inference that represents the following implication
    * (=>
    *   true
-   *   (= (count e (intersection_min A B))
+   *   (=
+   *     (count e skolem)
    *     (ite(
-   *     (< (count e A) (count e B))
-   *     (count e A)
-   *     (count e B)))))
+   *       (< (count e A) (count e B))
+   *       (count e A)
+   *       (count e B)))))
+   * where skolem is a fresh variable equals (intersection_min A B)
    */
   InferInfo intersection(Node n, Node e);
   /**
@@ -117,11 +125,13 @@ class InferenceGenerator
    * @return an inference that represents the following implication
    * (=>
    *   true
-   *   (= (count e (difference_subtract A B))
+   *   (=
+   *     (count e skolem)
    *     (ite
-   *        (>= (count e A) (count e B))
-   *        (- (count e A) (count e B))
-   *        0))))
+   *       (>= (count e A) (count e B))
+   *       (- (count e A) (count e B))
+   *       0))))
+   * where skolem is a fresh variable equals (difference_subtract A B)
    */
   InferInfo differenceSubtract(Node n, Node e);
   /**
@@ -130,11 +140,13 @@ class InferenceGenerator
    * @return an inference that represents the following implication
    * (=>
    *   true
-   *   (= (count e (difference_remove A B))
+   *   (=
+   *     (count e skolem)
    *     (ite
-   *        (= (count e B) 0)
-   *        (count e A)
-   *        0))))
+   *       (= (count e B) 0)
+   *       (count e A)
+   *       0))))
+   * where skolem is a fresh variable equals (difference_remove A B)
    */
   InferInfo differenceRemove(Node n, Node e);
   /**
@@ -143,20 +155,24 @@ class InferenceGenerator
    * @return an inference that represents the following implication
    * (=>
    *   true
-   *   (= (count e (duplicate_removal A))
-   *     (ite (>= (count e A) 1) 1 0))))
+   *   (=
+   *    (count e skolem)
+   *    (ite (>= (count e A) 1) 1 0))))
+   * where skolem is a fresh variable equals (duplicate_removal A)
    */
   InferInfo duplicateRemoval(Node n, Node e);
 
   /**
    * @param element of type T
    * @param bag of type (bag T)
-   * @param inferInfo to store new skolem
-   * @return  a skolem for (bag.count element bag)
+   * @return  a count term (bag.count element bag)
    */
-  Node getMultiplicitySkolem(Node element, Node bag, InferInfo& inferInfo);
+  Node getMultiplicityTerm(Node element, Node bag);
 
  private:
+  /** generate skolem variable for node n and add it to inferInfo */
+  Node getSkolem(Node& n, InferInfo& inferInfo);
+
   NodeManager* d_nm;
   SkolemManager* d_sm;
   SolverState* d_state;
index 67025548ceb4c8b387ff1a24411237b6a464c62a..71a014582452ac2bd8cd631f96a31cf3c62cf20c 100644 (file)
@@ -45,7 +45,7 @@ class InferenceManager : public InferenceManagerBuffered
    * process the pending lemmas and then the pending phase requirements.
    * Notice that we process the pending lemmas even if there were facts.
    */
-  // TODO: refactor this before merge with theory of strings
+  // TODO issue #78: refactor this with theory of strings
   void doPending();
 
  private:
index 744f6de9fde04735afda8e8509d9bc71aa734426..9bcb6ae3c54ca29a0e0f8773665d358aceff68b0 100644 (file)
@@ -33,52 +33,89 @@ SolverState::SolverState(context::Context* c,
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
+  d_nm = NodeManager::currentNM();
 }
 
-struct BagsCountAttributeId
+void SolverState::registerBag(TNode n)
 {
-};
-typedef expr::Attribute<BagsCountAttributeId, Node> BagsCountAttribute;
-
-void SolverState::registerClass(TNode n)
-{
-  TypeNode t = n.getType();
-  if (!t.isBag())
-  {
-    return;
-  }
+  Assert(n.getType().isBag());
   d_bags.insert(n);
 }
 
-Node SolverState::registerBagElement(TNode n)
+void SolverState::registerCountTerm(TNode n)
 {
   Assert(n.getKind() == BAG_COUNT);
-  Node element = n[0];
-  TypeNode elementType = element.getType();
-  Node bag = n[1];
-  d_elements[elementType].insert(element);
-  NodeManager* nm = NodeManager::currentNM();
-  BoundVarManager* bvm = nm->getBoundVarManager();
-  Node multiplicity = bvm->mkBoundVar<BagsCountAttribute>(n, nm->integerType());
-  Node equal = n.eqNode(multiplicity);
-  SkolemManager* sm = nm->getSkolemManager();
-  Node skolem = sm->mkSkolem(
-      multiplicity,
-      equal,
-      "bag_multiplicity",
-      "an extensional lemma for multiplicity of an element in a bag");
-  d_count[bag][element] = skolem;
-  Trace("bags::SolverState::registerBagElement")
-      << "New skolem: " << skolem << " for " << n << std::endl;
-
-  return skolem;
+  Node element = getRepresentative(n[0]);
+  Node bag = getRepresentative(n[1]);
+  d_bagElements[bag].insert(element);
+}
+
+const std::set<Node>& SolverState::getBags() { return d_bags; }
+
+const std::set<Node>& SolverState::getElements(Node B)
+{
+  Node bag = getRepresentative(B);
+  return d_bagElements[B];
+}
+
+void SolverState::reset()
+{
+  d_bagElements.clear();
+  d_bags.clear();
 }
 
-std::set<Node>& SolverState::getBags() { return d_bags; }
+void SolverState::initialize()
+{
+  reset();
+  collectBagsAndCountTerms();
+}
+
+void SolverState::collectBagsAndCountTerms()
+{
+  Trace("SolverState::collectBagsAndCountTerms")
+      << "SolverState::collectBagsAndCountTerms start" << endl;
+  eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee);
+  while (!repIt.isFinished())
+  {
+    Node eqc = (*repIt);
+    Trace("SolverState::collectBagsAndCountTerms")
+        << "[" << eqc << "]: " << endl;
+
+    if (eqc.getType().isBag())
+    {
+      registerBag(eqc);
+    }
 
-std::set<Node>& SolverState::getElements(TypeNode t) { return d_elements[t]; }
+    eq::EqClassIterator it = eq::EqClassIterator(eqc, d_ee);
+    while (!it.isFinished())
+    {
+      Node n = (*it);
+      Kind k = n.getKind();
+      if (k == MK_BAG)
+      {
+        // for terms (bag x c) we need to store x by registering the count term
+        // (bag.count x (bag x c))
+        Node count = d_nm->mkNode(BAG_COUNT, n[0], n);
+        registerCountTerm(count);
+        Trace("SolverState::collectBagsAndCountTerms")
+            << "registered " << count << endl;
+      }
+      if (k == BAG_COUNT)
+      {
+        // this takes care of all count terms in each equivalent class
+        registerCountTerm(n);
+        Trace("SolverState::collectBagsAndCountTerms")
+            << "registered " << n << endl;
+      }
+      ++it;
+    }
 
-std::map<Node, Node>& SolverState::getBagElements(Node B) { return d_count[B]; }
+    ++repIt;
+  }
+
+  Trace("SolverState::collectBagsAndCountTerms")
+      << "SolverState::collectBagsAndCountTerms end" << endl;
+}
 
 }  // namespace bags
 }  // namespace theory
index 8d70ee8f73d1f21fe97b2f6937c4f129ccfe7f59..175317529e4b178bf725d4067fb6a4e68bb4ab93 100644 (file)
@@ -31,24 +31,52 @@ class SolverState : public TheoryState
  public:
   SolverState(context::Context* c, context::UserContext* u, Valuation val);
 
-  void registerClass(TNode n);
+  /**
+   * This function adds the bag representative n to the set d_bags if it is not
+   * already there. This function is called during postCheck to collect bag
+   * terms in the equality engine. See the documentation of
+   * @link SolverState::collectBagsAndCountTerms
+   */
+  void registerBag(TNode n);
 
-  Node registerBagElement(TNode n);
-
-  std::set<Node>& getBags();
-
-  std::set<Node>& getElements(TypeNode t);
-
-  std::map<Node, Node>& getBagElements(Node B);
+  /**
+   * @param n has the form (bag.count e A)
+   * @pre bag A needs is already registered using registerBag(A)
+   * @return a unique skolem for (bag.count e A)
+   */
+  void registerCountTerm(TNode n);
+  /** get all bag terms that are representatives in the equality engine.
+   * This function is valid after the current solver is initialized during
+   * postCheck. See SolverState::initialize and BagSolver::postCheck
+   */
+  const std::set<Node>& getBags();
+  /**
+   * @pre B is a registered bag
+   * @return all elements associated with bag B so far
+   * Note that associated elements are not necessarily elements in B
+   * Example:
+   * (assert (= 0 (bag.count x B)))
+   * element x is associated with bag B, albeit x is definitely not in B.
+   */
+  const std::set<Node>& getElements(Node B);
+  /** initialize bag and count terms */
+  void initialize();
 
  private:
+  /** clear all bags data structures */
+  void reset();
+  /** collect bags' representatives and all count terms.
+   * This function is called during postCheck */
+  void collectBagsAndCountTerms();
   /** constants */
   Node d_true;
   Node d_false;
+  /** node manager for this solver state */
+  NodeManager* d_nm;
+  /** collection of bag representatives */
   std::set<Node> d_bags;
-  std::map<TypeNode, std::set<Node>> d_elements;
-  /** bag -> element -> multiplicity */
-  std::map<Node, std::map<Node, Node>> d_count;
+  /** bag -> associated elements */
+  std::map<Node, std::set<Node>> d_bagElements;
 }; /* class SolverState */
 
 }  // namespace bags
index 21a9d0e53207085ebf8e86a52088ab29c2258d10..153e9017dcfd134457c413684a5e98845de065e4 100644 (file)
@@ -78,22 +78,22 @@ void TheoryBags::finishInit()
 void TheoryBags::postCheck(Effort effort)
 {
   d_im.doPendingFacts();
-  // TODO: clean this before merge Assert(d_strat.isStrategyInit());
+  // TODO issue #78: add Assert(d_strat.isStrategyInit());
   if (!d_state.isInConflict() && !d_valuation.needCheck())
-  // TODO: clean this before merge && d_strat.hasStrategyEffort(e))
+  // TODO issue #78:  add && d_strat.hasStrategyEffort(e))
   {
     Trace("bags::TheoryBags::postCheck") << "effort: " << std::endl;
 
-    // TODO: clean this before merge ++(d_statistics.d_checkRuns);
+    // TODO issue #78: add ++(d_statistics.d_checkRuns);
     bool sentLemma = false;
     bool hadPending = false;
     Trace("bags-check") << "Full effort check..." << std::endl;
     do
     {
       d_im.reset();
-      // TODO: clean this before merge ++(d_statistics.d_strategyRuns);
+      // TODO issue #78: add ++(d_statistics.d_strategyRuns);
       Trace("bags-check") << "  * Run strategy..." << std::endl;
-      // TODO: clean this before merge runStrategy(e);
+      // TODO issue #78: add runStrategy(e);
 
       d_solver.postCheck();
 
@@ -153,14 +153,22 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
       continue;
     }
     Node r = d_state.getRepresentative(n);
-    std::map<Node, Node> elements = d_state.getBagElements(r);
+    std::set<Node> solverElements = d_state.getElements(r);
+    std::set<Node> elements;
+    // only consider terms in termSet and ignore other elements in the solver
+    std::set_intersection(termSet.begin(),
+                          termSet.end(),
+                          solverElements.begin(),
+                          solverElements.end(),
+                          std::inserter(elements, elements.begin()));
     Trace("bags-model") << "Elements of bag " << n << " are: " << std::endl
                         << elements << std::endl;
     std::map<Node, Node> elementReps;
-    for (std::pair<Node, Node> pair : elements)
+    for (const Node& e : elements)
     {
-      Node key = d_state.getRepresentative(pair.first);
-      Node value = d_state.getRepresentative(pair.second);
+      Node key = d_state.getRepresentative(e);
+      Node countTerm = NodeManager::currentNM()->mkNode(BAG_COUNT, e, r);
+      Node value = d_state.getRepresentative(countTerm);
       elementReps[key] = value;
     }
     Node rep = NormalForm::constructBagFromElements(tn, elementReps);
@@ -211,38 +219,7 @@ void TheoryBags::presolve() {}
 
 /**************************** eq::NotifyClass *****************************/
 
-void TheoryBags::eqNotifyNewClass(TNode n)
-{
-  Kind k = n.getKind();
-  d_state.registerClass(n);
-  if (n.getKind() == MK_BAG)
-  {
-    // TODO: refactor this before merge
-    /*
-     * (bag x m) generates the lemma (and (= s (count x (bag x m))) (= s m))
-     * where s is a fresh skolem variable
-     */
-    NodeManager* nm = NodeManager::currentNM();
-    Node count = nm->mkNode(BAG_COUNT, n[0], n);
-    Node skolem = d_state.registerBagElement(count);
-    Node countSkolem = count.eqNode(skolem);
-    Node skolemMultiplicity = n[1].eqNode(skolem);
-    Node lemma = countSkolem.andNode(skolemMultiplicity);
-    TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
-    d_im.trustedLemma(trustedLemma);
-  }
-  if (k == BAG_COUNT)
-  {
-    /*
-     * (count x A) generates the lemma (= s (count x A))
-     * where s is a fresh skolem variable
-     */
-    Node skolem = d_state.registerBagElement(n);
-    Node lemma = n.eqNode(skolem);
-    TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
-    d_im.trustedLemma(trustedLemma);
-  }
-}
+void TheoryBags::eqNotifyNewClass(TNode n) {}
 
 void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {}
 
@@ -251,10 +228,8 @@ void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason)
   TypeNode t1 = n1.getType();
   if (t1.isBag())
   {
-    InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode());
-    Node lemma = reason.impNode(info.d_conclusion);
-    TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
-    d_im.trustedLemma(trustedLemma);
+    InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode(), reason);
+    info.process(d_inferManager, true);
   }
 }
 
index cf4b0386d5d9ef78871a1acf2436477baafaa08f..810ed81282e3350ad1c2104527a43455e7e947d6 100644 (file)
@@ -1427,7 +1427,9 @@ set(regress_1_tests
   regress1/bug681.smt2
   regress1/bug694-Unapply1.scala-0.smt2
   regress1/bug800.smt2
+  regress1/bags/difference_remove1.smt2
   regress1/bags/disequality.smt2
+  regress1/bags/issue5759.smt2
   regress1/bags/subbag1.smt2
   regress1/bags/subbag2.smt2
   regress1/bags/union_disjoint.smt2
diff --git a/test/regress/regress1/bags/difference_remove1.smt2 b/test/regress/regress1/bags/difference_remove1.smt2
new file mode 100644 (file)
index 0000000..f5a87c1
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(assert (= A (union_max (bag x 1) (bag y 2))))
+(assert (= A (union_disjoint B (bag y 2))))
+(assert (= x y))
+(check-sat)
diff --git a/test/regress/regress1/bags/issue5759.smt2 b/test/regress/regress1/bags/issue5759.smt2
new file mode 100644 (file)
index 0000000..ba3752e
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic ALL)
+(set-info :status sat)
+(set-option :produce-models true)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(assert (not (= A (union_max (bag x 1) (bag 0 1)))))
+(assert (= A (union_disjoint B (bag 0 1))))
+(assert (= x 1))
+(check-sat)
\ No newline at end of file