Add bag inferences for operators: intersection, duplicate_removal, and empty bags...
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Fri, 29 Jan 2021 21:44:28 +0000 (15:44 -0600)
committerGitHub <noreply@github.com>
Fri, 29 Jan 2021 21:44:28 +0000 (15:44 -0600)
This PR adds inferences for operators: intersection, duplicate_removal, and empty bags during post check.
It also fixes a bug in SolverState::getElements

14 files changed:
src/theory/bags/bag_solver.cpp
src/theory/bags/bag_solver.h
src/theory/bags/infer_info.cpp
src/theory/bags/inference_generator.cpp
src/theory/bags/inference_generator.h
src/theory/bags/solver_state.cpp
src/theory/bags/solver_state.h
src/theory/bags/theory_bags.cpp
test/regress/CMakeLists.txt
test/regress/regress1/bags/duplicate_removal1.smt2 [new file with mode: 0644]
test/regress/regress1/bags/duplicate_removal2.smt2 [new file with mode: 0644]
test/regress/regress1/bags/emptybag1.smt2 [new file with mode: 0644]
test/regress/regress1/bags/intersection_min1.smt2 [new file with mode: 0644]
test/regress/regress1/bags/intersection_min2.smt2 [new file with mode: 0644]

index 495f73723d81b6325a8edc155e5f8038df9e6e5b..bdd4a9b301be9ed6f688347f58a366ea6202b1c6 100644 (file)
@@ -27,7 +27,7 @@ namespace theory {
 namespace bags {
 
 BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
-    : d_state(s), d_im(im), d_termReg(tr)
+    : d_state(s), d_ig(&d_state), d_im(im), d_termReg(tr)
 {
   d_zero = NodeManager::currentNM()->mkConst(Rational(0));
   d_one = NodeManager::currentNM()->mkConst(Rational(1));
@@ -41,6 +41,8 @@ void BagSolver::postCheck()
 {
   d_state.initialize();
 
+  checkDisequalBagTerms();
+
   // At this point, all bag and count representatives should be in the solver
   // state.
   for (const Node& bag : d_state.getBags())
@@ -54,11 +56,14 @@ void BagSolver::postCheck()
       Kind k = n.getKind();
       switch (k)
       {
+        case kind::EMPTYBAG: checkEmpty(n); break;
         case kind::MK_BAG: checkMkBag(n); break;
         case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
         case kind::UNION_MAX: checkUnionMax(n); break;
+        case kind::INTERSECTION_MIN: checkIntersectionMin(n); break;
         case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
         case kind::DIFFERENCE_REMOVE: checkDifferenceRemove(n); break;
+        case kind::DUPLICATE_REMOVAL: checkDuplicateRemoval(n); break;
         default: break;
       }
       it++;
@@ -91,16 +96,24 @@ set<Node> BagSolver::getElementsForBinaryOperator(const Node& n)
   return elements;
 }
 
+void BagSolver::checkEmpty(const Node& n)
+{
+  Assert(n.getKind() == EMPTYBAG);
+  for (const Node& e : d_state.getElements(n))
+  {
+    InferInfo i = d_ig.empty(n, e);
+    i.process(&d_im, true);
+  }
+}
+
 void BagSolver::checkUnionDisjoint(const Node& n)
 {
   Assert(n.getKind() == UNION_DISJOINT);
   std::set<Node> elements = getElementsForBinaryOperator(n);
   for (const Node& e : elements)
   {
-    InferenceGenerator ig(&d_state);
-    InferInfo i = ig.unionDisjoint(n, e);
+    InferInfo i = d_ig.unionDisjoint(n, e);
     i.process(&d_im, true);
-    Trace("bags::BagSolver::postCheck") << i << endl;
   }
 }
 
@@ -110,10 +123,19 @@ void BagSolver::checkUnionMax(const Node& n)
   std::set<Node> elements = getElementsForBinaryOperator(n);
   for (const Node& e : elements)
   {
-    InferenceGenerator ig(&d_state);
-    InferInfo i = ig.unionMax(n, e);
+    InferInfo i = d_ig.unionMax(n, e);
+    i.process(&d_im, true);
+  }
+}
+
+void BagSolver::checkIntersectionMin(const Node& n)
+{
+  Assert(n.getKind() == INTERSECTION_MIN);
+  std::set<Node> elements = getElementsForBinaryOperator(n);
+  for (const Node& e : elements)
+  {
+    InferInfo i = d_ig.intersection(n, e);
     i.process(&d_im, true);
-    Trace("bags::BagSolver::postCheck") << i << endl;
   }
 }
 
@@ -123,10 +145,8 @@ void BagSolver::checkDifferenceSubtract(const Node& n)
   std::set<Node> elements = getElementsForBinaryOperator(n);
   for (const Node& e : elements)
   {
-    InferenceGenerator ig(&d_state);
-    InferInfo i = ig.differenceSubtract(n, e);
+    InferInfo i = d_ig.differenceSubtract(n, e);
     i.process(&d_im, true);
-    Trace("bags::BagSolver::postCheck") << i << endl;
   }
 }
 
@@ -138,18 +158,14 @@ void BagSolver::checkMkBag(const Node& n)
       << " are: " << d_state.getElements(n) << std::endl;
   for (const Node& e : d_state.getElements(n))
   {
-    InferenceGenerator ig(&d_state);
-    InferInfo i = ig.mkBag(n, e);
+    InferInfo i = d_ig.mkBag(n, e);
     i.process(&d_im, true);
-    Trace("bags::BagSolver::postCheck") << i << endl;
   }
 }
 void BagSolver::checkNonNegativeCountTerms(const Node& bag, const Node& element)
 {
-  InferenceGenerator ig(&d_state);
-  InferInfo i = ig.nonNegativeCount(bag, element);
+  InferInfo i = d_ig.nonNegativeCount(bag, element);
   i.process(&d_im, true);
-  Trace("bags::BagSolver::postCheck") << i << endl;
 }
 
 void BagSolver::checkDifferenceRemove(const Node& n)
@@ -158,10 +174,34 @@ void BagSolver::checkDifferenceRemove(const Node& n)
   std::set<Node> elements = getElementsForBinaryOperator(n);
   for (const Node& e : elements)
   {
-    InferenceGenerator ig(&d_state);
-    InferInfo i = ig.differenceRemove(n, e);
+    InferInfo i = d_ig.differenceRemove(n, e);
     i.process(&d_im, true);
-    Trace("bags::BagSolver::postCheck") << i << endl;
+  }
+}
+
+void BagSolver::checkDuplicateRemoval(Node n)
+{
+  Assert(n.getKind() == DUPLICATE_REMOVAL);
+  set<Node> elements;
+  const set<Node>& downwards = d_state.getElements(n);
+  const set<Node>& upwards = d_state.getElements(n[0]);
+
+  elements.insert(downwards.begin(), downwards.end());
+  elements.insert(upwards.begin(), upwards.end());
+
+  for (const Node& e : elements)
+  {
+    InferInfo i = d_ig.duplicateRemoval(n, e);
+    i.process(&d_im, true);
+  }
+}
+
+void BagSolver::checkDisequalBagTerms()
+{
+  for (const Node& n : d_state.getDisequalBagTerms())
+  {
+    InferInfo info = d_ig.bagDisequality(n);
+    info.process(&d_im, true);
   }
 }
 
index b4b18c00cab16f27edc2e89ab1b3a6d5641de427..b19e1f11e351786f19c636b1f641d9c43d2ad003 100644 (file)
@@ -20,6 +20,7 @@
 #include "context/cdhashset.h"
 #include "context/cdlist.h"
 #include "theory/bags/infer_info.h"
+#include "theory/bags/inference_generator.h"
 #include "theory/bags/inference_manager.h"
 #include "theory/bags/normal_form.h"
 #include "theory/bags/solver_state.h"
@@ -41,6 +42,8 @@ class BagSolver
   void postCheck();
 
  private:
+  /** apply inference rules for empty bags */
+  void checkEmpty(const Node& n);
   /**
    * apply inference rules for MK_BAG operator.
    * Example: Suppose n = (bag x c), and we have two count terms (bag.count x n)
@@ -60,15 +63,23 @@ class BagSolver
   void checkUnionDisjoint(const Node& n);
   /** apply inference rules for union max */
   void checkUnionMax(const Node& n);
+  /** apply inference rules for intersection_min operator */
+  void checkIntersectionMin(const Node& n);
   /** apply inference rules for difference subtract */
   void checkDifferenceSubtract(const Node& n);
   /** apply inference rules for difference remove */
   void checkDifferenceRemove(const Node& n);
+  /** apply inference rules for duplicate removal operator */
+  void checkDuplicateRemoval(Node n);
   /** apply non negative constraints for multiplicities */
   void checkNonNegativeCountTerms(const Node& bag, const Node& element);
+  /** apply inference rules for disequal bag terms */
+  void checkDisequalBagTerms();
 
   /** The solver state object */
   SolverState& d_state;
+  /** The inference generator object*/
+  InferenceGenerator d_ig;
   /** Reference to the inference manager for the theory of bags */
   InferenceManager& d_im;
   /** Reference to the term registry of theory of bags */
index 5b3274617307b2c6da8f9077b9a37ab66cd764a5..9bf546af13845910d61f77f8dadc0e0015d08192 100644 (file)
@@ -76,6 +76,9 @@ bool InferInfo::process(TheoryInferenceManager* im, bool asLemma)
     TrustNode trustedLemma = TrustNode::mkTrustLemma(n, nullptr);
     im->trustedLemma(trustedLemma);
   }
+
+  Trace("bags::InferInfo::process") << (*this) << std::endl;
+
   return true;
 }
 
index 7ef126911cf7842a247527c52ef24a9f5055a6c1..708c25f34f5762860b3ecf68f3ad97e6a045fa78 100644 (file)
@@ -80,17 +80,15 @@ struct BagsDeqAttributeId
 };
 typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
 
-InferInfo InferenceGenerator::bagDisequality(Node n, Node reason)
+InferInfo InferenceGenerator::bagDisequality(Node n)
 {
-  Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL);
-  Assert(n[0][0].getType().isBag());
+  Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag());
 
-  Node A = n[0][0];
-  Node B = n[0][1];
+  Node A = n[0];
+  Node B = n[1];
 
   InferInfo inferInfo;
   inferInfo.d_id = Inference::BAG_DISEQUALITY;
-  inferInfo.d_premises.push_back(reason);
 
   TypeNode elementType = A.getType().getBagElementType();
   BoundVarManager* bvm = d_nm->getBoundVarManager();
@@ -106,7 +104,7 @@ InferInfo InferenceGenerator::bagDisequality(Node n, Node reason)
 
   Node disEqual = countA.eqNode(countB).notNode();
 
-  inferInfo.d_premises.push_back(n);
+  inferInfo.d_premises.push_back(n.notNode());
   inferInfo.d_conclusion = disEqual;
   return inferInfo;
 }
@@ -118,13 +116,15 @@ Node InferenceGenerator::getSkolem(Node& n, InferInfo& inferInfo)
   return skolem;
 }
 
-InferInfo InferenceGenerator::bagEmpty(Node e)
+InferInfo InferenceGenerator::empty(Node n, Node e)
 {
-  EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType()));
-  Node empty = d_nm->mkConst(emptyBag);
+  Assert(n.getKind() == kind::EMPTYBAG);
+  Assert(e.getType() == n.getType().getBagElementType());
+
   InferInfo inferInfo;
+  Node skolem = getSkolem(n, inferInfo);
   inferInfo.d_id = Inference::BAG_EMPTY;
-  Node count = getMultiplicityTerm(e, empty);
+  Node count = getMultiplicityTerm(e, skolem);
 
   Node equal = count.eqNode(d_zero);
   inferInfo.d_conclusion = equal;
index 9eee46e43579ca6dcd9b32a8c15faf421e3d5ab8..4a852530ac2de20355507c23cf3b92ab4df84563 100644 (file)
@@ -62,22 +62,25 @@ class InferenceGenerator
    */
   InferInfo mkBag(Node n, Node e);
   /**
-   * @param n is (not (= A B)) where A, B are bags of type (Bag E)
+   * @param n is (= A B) where A, B are bags of type (Bag E), and
+   * (not (= A B)) is an assertion in the equality engine
    * @return an inference that represents the following implication
    * (=>
    *   (not (= A B))
    *   (not (= (count e A) (count e B))))
    *   where e is a fresh skolem of type E.
    */
-  InferInfo bagDisequality(Node n, Node reason);
+  InferInfo bagDisequality(Node n);
   /**
+   * @param n is (as emptybag (Bag E))
    * @param e is a node of Type E
    * @return an inference that represents the following implication
    * (=>
    *   true
-   *   (= 0 (count e (as emptybag (Bag E)))))
+   *   (= 0 (count e skolem)))
+   *   where skolem = (as emptybag (Bag String))
    */
-  InferInfo bagEmpty(Node e);
+  InferInfo empty(Node n, Node e);
   /**
    * @param n is (union_disjoint A B) where A, B are bags of type (Bag E)
    * @param e is a node of Type E
index 9bcb6ae3c54ca29a0e0f8773665d358aceff68b0..adca85068e38a4288b9afc334a38dd53273aa3be 100644 (file)
@@ -55,31 +55,32 @@ const std::set<Node>& SolverState::getBags() { return d_bags; }
 const std::set<Node>& SolverState::getElements(Node B)
 {
   Node bag = getRepresentative(B);
-  return d_bagElements[B];
+  return d_bagElements[bag];
 }
 
+const std::set<Node>& SolverState::getDisequalBagTerms() { return d_deq; }
+
 void SolverState::reset()
 {
   d_bagElements.clear();
   d_bags.clear();
+  d_deq.clear();
 }
 
 void SolverState::initialize()
 {
   reset();
   collectBagsAndCountTerms();
+  collectDisequalBagTerms();
 }
 
 void SolverState::collectBagsAndCountTerms()
 {
-  Trace("SolverState::collectBagsAndCountTerms")
-      << "SolverState::collectBagsAndCountTerms start" << endl;
   eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee);
   while (!repIt.isFinished())
   {
     Node eqc = (*repIt);
-    Trace("SolverState::collectBagsAndCountTerms")
-        << "[" << eqc << "]: " << endl;
+    Trace("bags-eqc") << "Eqc [ " << eqc << " ] = { ";
 
     if (eqc.getType().isBag())
     {
@@ -90,6 +91,7 @@ void SolverState::collectBagsAndCountTerms()
     while (!it.isFinished())
     {
       Node n = (*it);
+      Trace("bags-eqc") << (*it) << " ";
       Kind k = n.getKind();
       if (k == MK_BAG)
       {
@@ -109,12 +111,27 @@ void SolverState::collectBagsAndCountTerms()
       }
       ++it;
     }
-
+    Trace("bags-eqc") << " } " << std::endl;
     ++repIt;
   }
 
-  Trace("SolverState::collectBagsAndCountTerms")
-      << "SolverState::collectBagsAndCountTerms end" << endl;
+  Trace("bags-eqc") << "bag representatives: " << d_bags << endl;
+  Trace("bags-eqc") << "bag elements: " << d_bagElements << endl;
+}
+
+void SolverState::collectDisequalBagTerms()
+{
+  eq::EqClassIterator it = eq::EqClassIterator(d_false, d_ee);
+  while (!it.isFinished())
+  {
+    Node n = (*it);
+    if (n.getKind() == EQUAL && n[0].getType().isBag())
+    {
+      Trace("bags-eqc") << "Disequal terms: " << n << std::endl;
+      d_deq.insert(n);
+    }
+    ++it;
+  }
 }
 
 }  // namespace bags
index 175317529e4b178bf725d4067fb6a4e68bb4ab93..7670e5dec9363bc8528660bfc3cc8681965da3d5 100644 (file)
@@ -61,13 +61,21 @@ class SolverState : public TheoryState
   const std::set<Node>& getElements(Node B);
   /** initialize bag and count terms */
   void initialize();
+  /** return disequal bag terms */
+  const std::set<Node>& getDisequalBagTerms();
 
  private:
   /** clear all bags data structures */
   void reset();
-  /** collect bags' representatives and all count terms.
-   * This function is called during postCheck */
+  /**
+   * collect bags' representatives and all count terms.
+   * This function is called during postCheck
+   */
   void collectBagsAndCountTerms();
+  /**
+   * collect disequal bag terms. This function is called during postCheck.
+   */
+  void collectDisequalBagTerms();
   /** constants */
   Node d_true;
   Node d_false;
@@ -77,6 +85,8 @@ class SolverState : public TheoryState
   std::set<Node> d_bags;
   /** bag -> associated elements */
   std::map<Node, std::set<Node>> d_bagElements;
+  /** Disequal bag terms */
+  std::set<Node> d_deq;
 }; /* class SolverState */
 
 }  // namespace bags
index 153e9017dcfd134457c413684a5e98845de065e4..15e8e00e7a5532010bc0d1b17ca59aaa86f1bf04 100644 (file)
@@ -199,7 +199,6 @@ void TheoryBags::preRegisterTerm(TNode n)
     case BAG_FROM_SET:
     case BAG_TO_SET:
     case BAG_IS_SINGLETON:
-    case DUPLICATE_REMOVAL:
     {
       std::stringstream ss;
       ss << "Term of kind " << n.getKind() << " is not supported yet";
@@ -223,15 +222,7 @@ void TheoryBags::eqNotifyNewClass(TNode n) {}
 
 void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {}
 
-void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason)
-{
-  TypeNode t1 = n1.getType();
-  if (t1.isBag())
-  {
-    InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode(), reason);
-    info.process(d_inferManager, true);
-  }
-}
+void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason) {}
 
 void TheoryBags::NotifyClass::eqNotifyNewClass(TNode n)
 {
index 94be987f7f766f8213af3e512ead6e72cf05de31..128f9c567241afb25e974569c4a0e8c2c9385cbf 100644 (file)
@@ -1430,6 +1430,11 @@ set(regress_1_tests
   regress1/bug800.smt2
   regress1/bags/difference_remove1.smt2
   regress1/bags/disequality.smt2
+  regress1/bags/duplicate_removal1.smt2
+  regress1/bags/duplicate_removal2.smt2
+  regress1/bags/emptybag1.smt2
+  regress1/bags/intersection_min1.smt2
+  regress1/bags/intersection_min2.smt2
   regress1/bags/issue5759.smt2
   regress1/bags/subbag1.smt2
   regress1/bags/subbag2.smt2
diff --git a/test/regress/regress1/bags/duplicate_removal1.smt2 b/test/regress/regress1/bags/duplicate_removal1.smt2
new file mode 100644 (file)
index 0000000..2b662c6
--- /dev/null
@@ -0,0 +1,8 @@
+(set-logic ALL)
+(set-info :status sat)
+(set-option :produce-models true)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(assert (= B (duplicate_removal A)))
+(assert (distinct (as emptybag (Bag String)) A B))
+(check-sat)
diff --git a/test/regress/regress1/bags/duplicate_removal2.smt2 b/test/regress/regress1/bags/duplicate_removal2.smt2
new file mode 100644 (file)
index 0000000..7dacaae
--- /dev/null
@@ -0,0 +1,8 @@
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(assert (= B (duplicate_removal A)))
+(assert (distinct (as emptybag (Bag String)) A B))
+(assert (= B (union_max A B)))
+(check-sat)
\ No newline at end of file
diff --git a/test/regress/regress1/bags/emptybag1.smt2 b/test/regress/regress1/bags/emptybag1.smt2
new file mode 100644 (file)
index 0000000..f7f9259
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag String))
+(declare-fun x () String)
+(declare-fun y () Int)
+(assert (= x "x"))
+(assert (= A (as emptybag (Bag String))))
+(assert (= (bag.count x A) y))
+(assert(> y 1))
+(check-sat)
diff --git a/test/regress/regress1/bags/intersection_min1.smt2 b/test/regress/regress1/bags/intersection_min1.smt2
new file mode 100644 (file)
index 0000000..f5a515b
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic ALL)
+(set-info :status sat)
+(set-option :produce-models true)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(declare-fun C () (Bag String))
+(assert (= C (intersection_min A B)))
+(assert (distinct (as emptybag (Bag String)) C))
+(assert (distinct A B C))
+(check-sat)
\ No newline at end of file
diff --git a/test/regress/regress1/bags/intersection_min2.smt2 b/test/regress/regress1/bags/intersection_min2.smt2
new file mode 100644 (file)
index 0000000..66afa2f
--- /dev/null
@@ -0,0 +1,9 @@
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag String))
+(declare-fun B () (Bag String))
+(declare-fun C () (Bag String))
+(assert (= C (intersection_min A B)))
+(assert (= C (union_disjoint A B)))
+(assert (distinct (as emptybag (Bag String)) C))
+(check-sat)