Refactor bag solver (#7770)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Tue, 4 Jan 2022 15:39:33 +0000 (09:39 -0600)
committerGitHub <noreply@github.com>
Tue, 4 Jan 2022 15:39:33 +0000 (15:39 +0000)
17 files changed:
src/CMakeLists.txt
src/theory/bags/bag_solver.cpp
src/theory/bags/bag_solver.h
src/theory/bags/bags_rewriter.cpp
src/theory/bags/inference_generator.cpp
src/theory/bags/inference_generator.h
src/theory/bags/solver_state.cpp
src/theory/bags/solver_state.h
src/theory/bags/strategy.cpp [new file with mode: 0644]
src/theory/bags/strategy.h [new file with mode: 0644]
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags.h
src/theory/inference_id.cpp
src/theory/inference_id.h
test/regress/CMakeLists.txt
test/regress/regress1/bags/fuzzy3b.smt2 [new file with mode: 0644]
test/unit/theory/theory_bags_rewriter_white.cpp

index fde0088e8a76a6d6f2fe1e6c886db119f8a17530..1ee7260475a8f7f7478d2b22e33b08f93db2d889 100644 (file)
@@ -555,6 +555,8 @@ libcvc5_add_sources(
   theory/bags/rewrites.h
   theory/bags/solver_state.cpp
   theory/bags/solver_state.h
+  theory/bags/strategy.cpp
+  theory/bags/strategy.h
   theory/bags/term_registry.cpp
   theory/bags/term_registry.h
   theory/bags/theory_bags.cpp
index 80ccd6707ebb65e991464a268b712368dd1bcf9f..55367bb8902158e278017867f15e2aaa187b7c28 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "theory/bags/bag_solver.h"
 
+#include "expr/emptybag.h"
 #include "theory/bags/inference_generator.h"
 #include "theory/bags/inference_manager.h"
 #include "theory/bags/normal_form.h"
@@ -50,10 +51,8 @@ BagSolver::BagSolver(Env& env,
 
 BagSolver::~BagSolver() {}
 
-void BagSolver::postCheck()
+void BagSolver::checkBasicOperations()
 {
-  d_state.initialize();
-
   checkDisequalBagTerms();
 
   // At this point, all bag and count representatives should be in the solver
@@ -164,6 +163,39 @@ void BagSolver::checkDifferenceSubtract(const Node& n)
   }
 }
 
+bool BagSolver::checkBagMake()
+{
+  bool sentLemma = false;
+  for (const Node& bag : d_state.getBags())
+  {
+    TypeNode bagType = bag.getType();
+    NodeManager* nm = NodeManager::currentNM();
+    Node empty = nm->mkConst(EmptyBag(bagType));
+    if (d_state.areEqual(empty, bag) || d_state.areDisequal(empty, bag))
+    {
+      continue;
+    }
+
+    // look for BAG_MAKE terms in the equivalent class
+    eq::EqClassIterator it =
+        eq::EqClassIterator(bag, d_state.getEqualityEngine());
+    while (!it.isFinished())
+    {
+      Node n = (*it);
+      if (n.getKind() == BAG_MAKE)
+      {
+        Trace("bags-check") << "splitting on node " << std::endl;
+        InferInfo i = d_ig.bagMake(n);
+        sentLemma |= d_im.lemmaTheoryInference(&i);
+        // it is enough to split only once per equivalent class
+        break;
+      }
+      it++;
+    }
+  }
+  return sentLemma;
+}
+
 void BagSolver::checkBagMake(const Node& n)
 {
   Assert(n.getKind() == BAG_MAKE);
index 45a20e0554ca64490ec2d80ed9d3c95884a244f3..499b7998db510af2f241811266267b35ce74d446 100644 (file)
@@ -40,11 +40,29 @@ class BagSolver : protected EnvObj
   BagSolver(Env& env, SolverState& s, InferenceManager& im, TermRegistry& tr);
   ~BagSolver();
 
-  void postCheck();
+  /**
+   * apply inference rules for basic bag operators:
+   * BAG_MAKE, BAG_UNION_DISJOINT, BAG_UNION_MAX, BAG_INTER_MIN,
+   * BAG_DIFFERENCE_SUBTRACT, BAG_DIFFERENCE_REMOVE, BAG_DUPLICATE_REMOVAL
+   */
+  void checkBasicOperations();
+
+  /**
+   * apply inference rules for BAG_MAKE terms.
+   * For each term (bag x c) that is neither equal nor disequal to the empty
+   * bag, we do a split using the following lemma:
+   * (or
+   *   (and (<  c 1) (= (bag x c) (as bag.empty (Bag E))))
+   *   (and (>= c 1) (not (= (bag x c) (as bag.empty (Bag E))))
+   * where (Bag E) is the type of the bag term
+   * @return true if a new lemma was successfully sent.
+   */
+  bool checkBagMake();
 
  private:
   /** apply inference rules for empty bags */
   void checkEmpty(const Node& n);
+
   /**
    * apply inference rules for BAG_MAKE operator.
    * Example: Suppose n = (bag x c), and we have two count terms (bag.count x n)
index a5fb206aaafc9b179f5b477733ce98bbaff82024..f193bf73cdb6f576a0e5f7a4693e3255b766e904 100644 (file)
@@ -177,13 +177,12 @@ BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
     // (bag.count x bag.empty) = 0
     return BagsRewriteResponse(d_zero, Rewrite::COUNT_EMPTY);
   }
-  if (n[1].getKind() == BAG_MAKE && n[0] == n[1][0])
+  if (n[1].getKind() == BAG_MAKE && n[0] == n[1][0] && n[1][1].isConst()
+      && n[1][1].getConst<Rational>() > Rational(0))
   {
-    // (bag.count x (bag x c)) = (ite (>= c 1) c 0)
+    // (bag.count x (bag x c)) = c, c > 0 is a constant
     Node c = n[1][1];
-    Node geq = d_nm->mkNode(GEQ, c, d_one);
-    Node ite = d_nm->mkNode(ITE, geq, c, d_zero);
-    return BagsRewriteResponse(ite, Rewrite::COUNT_BAG_MAKE);
+    return BagsRewriteResponse(c, Rewrite::COUNT_BAG_MAKE);
   }
   return BagsRewriteResponse(n, Rewrite::NONE);
 }
index 9be0c2b9375e02e9dedf7c68959dfd926724ae49..e2b2207771fba552e7e53ade5fde0e58414e1ca7 100644 (file)
@@ -17,6 +17,7 @@
 
 #include "expr/attribute.h"
 #include "expr/bound_var_manager.h"
+#include "expr/emptybag.h"
 #include "expr/skolem_manager.h"
 #include "theory/bags/inference_manager.h"
 #include "theory/bags/solver_state.h"
@@ -47,12 +48,32 @@ InferInfo InferenceGenerator::nonNegativeCount(Node n, Node e)
 
   InferInfo inferInfo(d_im, InferenceId::BAGS_NON_NEGATIVE_COUNT);
   Node count = d_nm->mkNode(BAG_COUNT, e, n);
-
   Node gte = d_nm->mkNode(GEQ, count, d_zero);
   inferInfo.d_conclusion = gte;
   return inferInfo;
 }
 
+InferInfo InferenceGenerator::bagMake(Node n)
+{
+  Assert(n.getKind() == BAG_MAKE);
+  /*
+   * (or
+   *   (and (<  c 1) (= (bag x c) (as bag.empty (Bag E))))
+   *   (and (>= c 1) (not (= (bag x c) (as bag.empty (Bag E))))
+   */
+  Node x = n[0];
+  Node c = n[1];
+  InferInfo inferInfo(d_im, InferenceId::BAGS_BAG_MAKE_SPLIT);
+  Node empty = d_nm->mkConst(EmptyBag(n.getType()));
+  Node equal = d_nm->mkNode(EQUAL, n, empty);
+  Node geq = d_nm->mkNode(GEQ, c, d_one);
+  Node isEmpty = geq.notNode().andNode(equal);
+  Node isNotEmpty = geq.andNode(equal.notNode());
+  Node orNode = isEmpty.orNode(isNotEmpty);
+  inferInfo.d_conclusion = orNode;
+  return inferInfo;
+}
+
 InferInfo InferenceGenerator::bagMake(Node n, Node e)
 {
   Assert(n.getKind() == BAG_MAKE);
index 8ed3ead382c10d7e103e07279975cd61a7475e53..29da0162959f359792d7e85a2dd7b6672b2de18d 100644 (file)
@@ -47,6 +47,15 @@ class InferenceGenerator
    */
   InferInfo nonNegativeCount(Node n, Node e);
 
+  /**
+   * @param n is (bag x c) of type (Bag E)
+   * @return an inference that represents the following lemma:
+   * (or
+   *   (and (<  c 1) (= (bag x c) (as bag.empty (Bag E))))
+   *   (and (>= c 1) (not (= (bag x c) (as bag.empty (Bag E))))
+   */
+  InferInfo bagMake(Node n);
+
   /**
    * @param n is (bag x c) of type (Bag E)
    * @param e is a node of type E
index ad817062fb512a6999d3c8a354e20f4e54ed3175..52cbb86717d5020977f81b4115102b536076d8d3 100644 (file)
@@ -40,19 +40,35 @@ void SolverState::registerBag(TNode n)
   d_bags.insert(n);
 }
 
-void SolverState::registerCountTerm(TNode n)
+Node SolverState::registerCountTerm(TNode n)
 {
   Assert(n.getKind() == BAG_COUNT);
-  Node element = n[0];
+  Node element = getRepresentative(n[0]);
   Node bag = getRepresentative(n[1]);
-  d_bagElements[bag].insert(element);
+  Node count = d_nm->mkNode(BAG_COUNT, element, bag);
+  Node skolem = d_nm->getSkolemManager()->mkPurifySkolem(count, "bag.count");
+  d_bagElements[bag].push_back(std::make_pair(element, skolem));
+  return count.eqNode(skolem);
 }
 
 const std::set<Node>& SolverState::getBags() { return d_bags; }
 
-const std::set<Node>& SolverState::getElements(Node B)
+std::set<Node> SolverState::getElements(Node B)
 {
   Node bag = getRepresentative(B);
+  std::set<Node> elements;
+  std::vector<std::pair<Node, Node>> pairs = d_bagElements[bag];
+  for (std::pair<Node, Node> pair : pairs)
+  {
+    elements.insert(pair.first);
+  }
+  return elements;
+}
+
+const std::vector<std::pair<Node, Node>>& SolverState::getElementCountPairs(
+    Node n)
+{
+  Node bag = getRepresentative(n);
   return d_bagElements[bag];
 }
 
@@ -65,15 +81,17 @@ void SolverState::reset()
   d_deq.clear();
 }
 
-void SolverState::initialize()
+std::vector<Node> SolverState::initialize()
 {
   reset();
-  collectBagsAndCountTerms();
   collectDisequalBagTerms();
+  return collectBagsAndCountTerms();
 }
 
-void SolverState::collectBagsAndCountTerms()
+std::vector<Node> SolverState::collectBagsAndCountTerms()
 {
+  std::vector<Node> lemmas;
+
   eq::EqClassesIterator repIt = eq::EqClassesIterator(d_ee);
   while (!repIt.isFinished())
   {
@@ -96,14 +114,16 @@ void SolverState::collectBagsAndCountTerms()
         // for terms (bag x c) we need to store x by registering the count term
         // (bag.count x (bag x c))
         Node count = d_nm->mkNode(BAG_COUNT, n[0], n);
-        registerCountTerm(count);
+        Node lemma = registerCountTerm(count);
+        lemmas.push_back(lemma);
         Trace("SolverState::collectBagsAndCountTerms")
             << "registered " << count << endl;
       }
       if (k == BAG_COUNT)
       {
         // this takes care of all count terms in each equivalent class
-        registerCountTerm(n);
+        Node lemma = registerCountTerm(n);
+        lemmas.push_back(lemma);
         Trace("SolverState::collectBagsAndCountTerms")
             << "registered " << n << endl;
       }
@@ -114,7 +134,7 @@ void SolverState::collectBagsAndCountTerms()
   }
 
   Trace("bags-eqc") << "bag representatives: " << d_bags << endl;
-  Trace("bags-eqc") << "bag elements: " << d_bagElements << endl;
+  return lemmas;
 }
 
 void SolverState::collectDisequalBagTerms()
index d6f628537c4facc0507a534ab0ae1b2992c02b8a..64c4b16741d0649c786c44f601854faee55a7e12 100644 (file)
@@ -42,9 +42,10 @@ class SolverState : public TheoryState
   /**
    * @param n has the form (bag.count e A)
    * @pre bag A needs is already registered using registerBag(A)
-   * @return a unique skolem for (bag.count e A)
+   * @return a lemma (= skolem (bag.count eRep ARep)) where
+   * eRep, ARep are representatives of e, A respectively
    */
-  void registerCountTerm(TNode n);
+  Node registerCountTerm(TNode n);
   /** get all bag terms that are representatives in the equality engine.
    * This function is valid after the current solver is initialized during
    * postCheck. See SolverState::initialize and BagSolver::postCheck
@@ -58,11 +59,18 @@ class SolverState : public TheoryState
    * (assert (= 0 (bag.count x B)))
    * element x is associated with bag B, albeit x is definitely not in B.
    */
-  const std::set<Node>& getElements(Node B);
-  /** initialize bag and count terms */
-  void initialize();
+  std::set<Node> getElements(Node B);
+  /**
+   * initialize bag and count terms
+   * @return a list of skolem lemmas to be asserted
+   * */
+  std::vector<Node> initialize();
   /** return disequal bag terms */
   const std::set<Node>& getDisequalBagTerms();
+  /**
+   * return a list of bag elements and their skolem counts
+   */
+  const std::vector<std::pair<Node, Node>>& getElementCountPairs(Node n);
 
  private:
   /** clear all bags data structures */
@@ -70,8 +78,9 @@ class SolverState : public TheoryState
   /**
    * collect bags' representatives and all count terms.
    * This function is called during postCheck
+   * @return a list of skolem lemmas to be asserted
    */
-  void collectBagsAndCountTerms();
+  std::vector<Node> collectBagsAndCountTerms();
   /**
    * collect disequal bag terms. This function is called during postCheck.
    */
@@ -83,8 +92,12 @@ class SolverState : public TheoryState
   NodeManager* d_nm;
   /** collection of bag representatives */
   std::set<Node> d_bags;
-  /** bag -> associated elements */
-  std::map<Node, std::set<Node>> d_bagElements;
+  /**
+   * This cache maps bag representatives to pairs of elements and multiplicity
+   * skolems which are used for model building.
+   * This map is cleared and initialized at the start of each full effort check.
+   */
+  std::map<Node, std::vector<std::pair<Node, Node>>> d_bagElements;
   /** Disequal bag terms */
   std::set<Node> d_deq;
 }; /* class SolverState */
diff --git a/src/theory/bags/strategy.cpp b/src/theory/bags/strategy.cpp
new file mode 100644 (file)
index 0000000..541be9a
--- /dev/null
@@ -0,0 +1,105 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds, Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Implementation of the strategy of the theory of bags.
+ */
+
+#include "theory/bags/strategy.h"
+
+namespace cvc5 {
+namespace theory {
+namespace bags {
+
+std::ostream& operator<<(std::ostream& out, InferStep s)
+{
+  switch (s)
+  {
+    case BREAK: out << "break"; break;
+    case CHECK_INIT: out << "check_init"; break;
+    case CHECK_BAG_MAKE: out << "check_bag_make"; break;
+    case CHECK_BASIC_OPERATIONS: out << "CHECK_BASIC_OPERATIONS"; break;
+    default: out << "?"; break;
+  }
+  return out;
+}
+
+Strategy::Strategy() : d_strategy_init(false) {}
+
+Strategy::~Strategy() {}
+
+bool Strategy::isStrategyInit() const { return d_strategy_init; }
+
+bool Strategy::hasStrategyEffort(Theory::Effort e) const
+{
+  return d_strat_steps.find(e) != d_strat_steps.end();
+}
+
+std::vector<std::pair<InferStep, size_t> >::iterator Strategy::stepBegin(
+    Theory::Effort e)
+{
+  std::map<Theory::Effort, std::pair<size_t, size_t> >::const_iterator it =
+      d_strat_steps.find(e);
+  Assert(it != d_strat_steps.end());
+  return d_infer_steps.begin() + it->second.first;
+}
+
+std::vector<std::pair<InferStep, size_t> >::iterator Strategy::stepEnd(
+    Theory::Effort e)
+{
+  std::map<Theory::Effort, std::pair<size_t, size_t> >::const_iterator it =
+      d_strat_steps.find(e);
+  Assert(it != d_strat_steps.end());
+  return d_infer_steps.begin() + it->second.second;
+}
+
+void Strategy::addStrategyStep(InferStep s, int effort, bool addBreak)
+{
+  // must run check init first
+  Assert((s == CHECK_INIT) == d_infer_steps.empty());
+  d_infer_steps.push_back(std::pair<InferStep, int>(s, effort));
+  if (addBreak)
+  {
+    d_infer_steps.push_back(std::pair<InferStep, int>(BREAK, 0));
+  }
+}
+
+void Strategy::initializeStrategy()
+{
+  // initialize the strategy if not already done so
+  if (!d_strategy_init)
+  {
+    std::map<Theory::Effort, unsigned> step_begin;
+    std::map<Theory::Effort, unsigned> step_end;
+    d_strategy_init = true;
+    // beginning indices
+    step_begin[Theory::EFFORT_FULL] = 0;
+    // add the inference steps
+    addStrategyStep(CHECK_INIT);
+    addStrategyStep(CHECK_BAG_MAKE);
+    addStrategyStep(CHECK_BASIC_OPERATIONS);
+    step_end[Theory::EFFORT_FULL] = d_infer_steps.size() - 1;
+
+    // set the beginning/ending ranges
+    for (const std::pair<const Theory::Effort, unsigned>& it_begin : step_begin)
+    {
+      Theory::Effort e = it_begin.first;
+      std::map<Theory::Effort, unsigned>::iterator it_end = step_end.find(e);
+      Assert(it_end != step_end.end());
+      d_strat_steps[e] =
+          std::pair<unsigned, unsigned>(it_begin.second, it_end->second);
+    }
+  }
+}
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace cvc5
diff --git a/src/theory/bags/strategy.h b/src/theory/bags/strategy.h
new file mode 100644 (file)
index 0000000..7dc5d69
--- /dev/null
@@ -0,0 +1,97 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds, Mudathir Mohamed
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Strategy of the theory of bags.
+ */
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__THEORY__BAGS__STRATEGY_H
+#define CVC5__THEORY__BAGS__STRATEGY_H
+
+#include <map>
+#include <vector>
+
+#include "theory/theory.h"
+
+namespace cvc5 {
+namespace theory {
+namespace bags {
+
+/** inference steps
+ *
+ * Corresponds to a step in the overall strategy of the bags solver. For
+ * details on the individual steps, see documentation on the inference schemas
+ * within Strategy.
+ */
+enum InferStep
+{
+  // indicates that the strategy should break if lemmas or facts are added
+  BREAK,
+  // check initial
+  CHECK_INIT,
+  // check bag operator
+  CHECK_BAG_MAKE,
+  // check basic operations
+  CHECK_BASIC_OPERATIONS
+};
+std::ostream& operator<<(std::ostream& out, InferStep i);
+
+/**
+ * The strategy of theory of bags.
+ *
+ * This stores a sequence of the above enum that indicates the calls to
+ * runInferStep to make on the theory of bags, given by parent.
+ */
+class Strategy
+{
+ public:
+  Strategy();
+  ~Strategy();
+  /** is this strategy initialized? */
+  bool isStrategyInit() const;
+  /** do we have a strategy for effort e? */
+  bool hasStrategyEffort(Theory::Effort e) const;
+  /** begin and end iterators for effort e */
+  std::vector<std::pair<InferStep, size_t> >::iterator stepBegin(
+      Theory::Effort e);
+  std::vector<std::pair<InferStep, size_t> >::iterator stepEnd(
+      Theory::Effort e);
+  /** initialize the strategy
+   *
+   * This initializes the above information based on the options. This makes
+   * a series of calls to addStrategyStep above.
+   */
+  void initializeStrategy();
+
+ private:
+  /** add strategy step
+   *
+   * This adds (s,effort) as a strategy step to the vectors d_infer_steps and
+   * d_infer_step_effort. This indicates that a call to runInferStep should
+   * be run as the next step in the strategy. If addBreak is true, we add
+   * a BREAK to the strategy following this step.
+   */
+  void addStrategyStep(InferStep s, int effort = 0, bool addBreak = true);
+  /** is strategy initialized */
+  bool d_strategy_init;
+  /** the strategy */
+  std::vector<std::pair<InferStep, size_t> > d_infer_steps;
+  /** the range (begin, end) of steps to run at given efforts */
+  std::map<Theory::Effort, std::pair<size_t, size_t> > d_strat_steps;
+}; /* class Strategy */
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace cvc5
+
+#endif /* CVC5__THEORY__BAGS__STRATEGY_H */
index 39598975b862d05606c67834182d4c3256d62fea..cfd6f117862122ddf5df14f80e9a6dfd7ac3f0ab 100644 (file)
@@ -159,11 +159,11 @@ TrustNode TheoryBags::expandChooseOperator(const Node& node,
 void TheoryBags::postCheck(Effort effort)
 {
   d_im.doPendingFacts();
-  // TODO issue #78: add Assert(d_strat.isStrategyInit());
-  if (!d_state.isInConflict() && !d_valuation.needCheck())
-  // TODO issue #78:  add && d_strat.hasStrategyEffort(e))
+  Assert(d_strat.isStrategyInit());
+  if (!d_state.isInConflict() && !d_valuation.needCheck()
+      && d_strat.hasStrategyEffort(effort))
   {
-    Trace("bags::TheoryBags::postCheck") << "effort: " << std::endl;
+    Trace("bags::TheoryBags::postCheck") << "effort: " << effort << std::endl;
 
     // TODO issue #78: add ++(d_statistics.d_checkRuns);
     bool sentLemma = false;
@@ -174,9 +174,12 @@ void TheoryBags::postCheck(Effort effort)
       d_im.reset();
       // TODO issue #78: add ++(d_statistics.d_strategyRuns);
       Trace("bags-check") << "  * Run strategy..." << std::endl;
-      // TODO issue #78: add runStrategy(e);
-
-      d_solver.postCheck();
+      std::vector<Node> lemmas = d_state.initialize();
+      for (Node lemma : lemmas)
+      {
+        d_im.lemma(lemma, InferenceId::BAGS_COUNT_SKOLEM);
+      }
+      runStrategy(effort);
 
       // remember if we had pending facts or lemmas
       hadPending = d_im.hasPending();
@@ -192,7 +195,7 @@ void TheoryBags::postCheck(Effort effort)
       sentLemma = d_im.hasSentLemma();
       if (Trace.isOn("bags-check"))
       {
-        // TODO: clean this Trace("bags-check") << "  ...finish run strategy: ";
+        Trace("bags-check") << "  ...finish run strategy: ";
         Trace("bags-check") << (hadPending ? "hadPending " : "");
         Trace("bags-check") << (sentLemma ? "sentLemma " : "");
         Trace("bags-check") << (d_state.isInConflict() ? "conflict " : "");
@@ -211,6 +214,66 @@ void TheoryBags::postCheck(Effort effort)
   Assert(!d_im.hasPendingLemma());
 }
 
+void TheoryBags::runStrategy(Theory::Effort e)
+{
+  std::vector<std::pair<InferStep, size_t>>::iterator it = d_strat.stepBegin(e);
+  std::vector<std::pair<InferStep, size_t>>::iterator stepEnd =
+      d_strat.stepEnd(e);
+
+  Trace("bags-process") << "----check, next round---" << std::endl;
+  while (it != stepEnd)
+  {
+    InferStep curr = it->first;
+    if (curr == BREAK)
+    {
+      if (d_state.isInConflict() || d_im.hasPending())
+      {
+        break;
+      }
+    }
+    else
+    {
+      if (runInferStep(curr, it->second) || d_state.isInConflict())
+      {
+        break;
+      }
+    }
+    ++it;
+  }
+  Trace("bags-process") << "----finished round---" << std::endl;
+}
+
+/** run the given inference step */
+bool TheoryBags::runInferStep(InferStep s, int effort)
+{
+  Trace("bags-process") << "Run " << s;
+  if (effort > 0)
+  {
+    Trace("bags-process") << ", effort = " << effort;
+  }
+  Trace("bags-process") << "..." << std::endl;
+  switch (s)
+  {
+    case CHECK_INIT: break;
+    case CHECK_BAG_MAKE:
+    {
+      if (d_solver.checkBagMake())
+      {
+        return true;
+      }
+      break;
+    }
+    case CHECK_BASIC_OPERATIONS: d_solver.checkBasicOperations(); break;
+    default: Unreachable(); break;
+  }
+  Trace("bags-process") << "Done " << s
+                        << ", addedFact = " << d_im.hasPendingFact()
+                        << ", addedLemma = " << d_im.hasPendingLemma()
+                        << ", conflict = " << d_state.isInConflict()
+                        << std::endl;
+  return false;
+}
+
 void TheoryBags::notifyFact(TNode atom,
                             bool polarity,
                             TNode fact,
@@ -245,22 +308,24 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
 
     processedBags.insert(r);
 
-    std::set<Node> solverElements = d_state.getElements(r);
-    std::set<Node> elements;
-    // only consider terms in termSet and ignore other elements in the solver
-    std::set_intersection(termSet.begin(),
-                          termSet.end(),
-                          solverElements.begin(),
-                          solverElements.end(),
-                          std::inserter(elements, elements.begin()));
-    Trace("bags-model") << "Elements of bag " << n << " are: " << std::endl
-                        << elements << std::endl;
+    const std::vector<std::pair<Node, Node>>& solverElements =
+        d_state.getElementCountPairs(r);
+    std::vector<std::pair<Node, Node>> elements;
+    for (std::pair<Node, Node> pair : solverElements)
+    {
+      if (termSet.find(pair.first) == termSet.end())
+      {
+        continue;
+      }
+      elements.push_back(pair);
+    }
+
     std::map<Node, Node> elementReps;
-    for (const Node& e : elements)
+    for (std::pair<Node, Node> pair : elements)
     {
-      Node key = d_state.getRepresentative(e);
-      Node countTerm = NodeManager::currentNM()->mkNode(BAG_COUNT, e, r);
-      Node value = m->getRepresentative(countTerm);
+      Node key = d_state.getRepresentative(pair.first);
+      Node countSkolem = pair.second;
+      Node value = m->getRepresentative(countSkolem);
       elementReps[key] = value;
     }
     Node rep = NormalForm::constructBagFromElements(tn, elementReps);
@@ -299,7 +364,12 @@ void TheoryBags::preRegisterTerm(TNode n)
   }
 }
 
-void TheoryBags::presolve() {}
+void TheoryBags::presolve()
+{
+  Debug("bags-presolve") << "Started presolve" << std::endl;
+  d_strat.initializeStrategy();
+  Debug("bags-presolve") << "Finished presolve" << std::endl;
+}
 
 /**************************** eq::NotifyClass *****************************/
 
index 8d15947efdc51e4eb9a20274e6924876fcfcb2b7..cc76c5453d6a3905518656ae3934cf3dc56d58c0 100644 (file)
@@ -25,6 +25,7 @@
 #include "theory/bags/inference_generator.h"
 #include "theory/bags/inference_manager.h"
 #include "theory/bags/solver_state.h"
+#include "theory/bags/strategy.h"
 #include "theory/bags/term_registry.h"
 #include "theory/theory.h"
 #include "theory/theory_eq_notify.h"
@@ -72,6 +73,11 @@ class TheoryBags : public Theory
   void preRegisterTerm(TNode n) override;
   void presolve() override;
 
+  /** run strategy for effort e */
+  void runStrategy(Theory::Effort e);
+  /** run the given inference step */
+  bool runInferStep(InferStep s, int effort);
+
  private:
   /** Functions to handle callbacks from equality engine */
   class NotifyClass : public TheoryEqNotifyClass
@@ -114,6 +120,9 @@ class TheoryBags : public Theory
   /** bag reduction */
   BagReduction d_bagReduction;
 
+  /** The representation of the strategy */
+  Strategy d_strat;
+
   void eqNotifyNewClass(TNode n);
   void eqNotifyMerge(TNode n1, TNode n2);
   void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
index af18d3eefe00d46199df5b78d7b8278845ac034d..a2dcdec8c85c522010df46ab0ebc129af5fc15c8 100644 (file)
@@ -107,6 +107,8 @@ const char* toString(InferenceId i)
 
     case InferenceId::BAGS_NON_NEGATIVE_COUNT: return "BAGS_NON_NEGATIVE_COUNT";
     case InferenceId::BAGS_BAG_MAKE: return "BAGS_BAG_MAKE";
+    case InferenceId::BAGS_BAG_MAKE_SPLIT: return "BAGS_BAG_MAKE_SPLIT";
+    case InferenceId::BAGS_COUNT_SKOLEM: return "BAGS_COUNT_SKOLEM";
     case InferenceId::BAGS_EQUALITY: return "BAGS_EQUALITY";
     case InferenceId::BAGS_DISEQUALITY: return "BAGS_DISEQUALITY";
     case InferenceId::BAGS_EMPTY: return "BAGS_EMPTY";
index 8eaeeab754314a52c0e40b6b4f0257f3c6a39fcd..8d90785034068d91a02af400cc5951b191bae05a 100644 (file)
@@ -170,6 +170,8 @@ enum class InferenceId
   // ---------------------------------- bags theory
   BAGS_NON_NEGATIVE_COUNT,
   BAGS_BAG_MAKE,
+  BAGS_BAG_MAKE_SPLIT,
+  BAGS_COUNT_SKOLEM,
   BAGS_EQUALITY,
   BAGS_DISEQUALITY,
   BAGS_EMPTY,
index 6bca782f13b256b9af3dce6367cb18309a98fd27..3c0e79596d20716722c85e04d368eac02aebc3a5 100644 (file)
@@ -1627,6 +1627,7 @@ set(regress_1_tests
   regress1/bags/fuzzy1.smt2
   regress1/bags/fuzzy2.smt2
   regress1/bags/fuzzy3.smt2
+  regress1/bags/fuzzy3b.smt2
   regress1/bags/fuzzy4.smt2
   regress1/bags/fuzzy5.smt2
   regress1/bags/fuzzy6.smt2
diff --git a/test/regress/regress1/bags/fuzzy3b.smt2 b/test/regress/regress1/bags/fuzzy3b.smt2
new file mode 100644 (file)
index 0000000..9264557
--- /dev/null
@@ -0,0 +1,7 @@
+(set-logic ALL)
+(set-info :status sat)
+(set-option :produce-models true)
+(declare-fun A () (Bag (Tuple Int Int)))
+(declare-fun d () (Tuple Int Int))
+(assert (= A (bag.difference_remove A (bag d 1))))
+(check-sat)
index ff98c308a70153050f156292a2c69e5190a5fe13..af24abc39870b279754b0868782062b845e86f97 100644 (file)
@@ -189,7 +189,7 @@ TEST_F(TestTheoryWhiteBagsRewriter, bag_count)
   ASSERT_TRUE(response1.d_status == REWRITE_AGAIN_FULL
               && response1.d_node == zero);
 
-  // (bag.count x (bag x c) = (ite (>= c 1) c 0)
+  // (bag.count x (bag x c) = c, c > 0 is a constant
   Node bag = d_nodeManager->mkBag(d_nodeManager->stringType(), skolem, three);
   Node n2 = d_nodeManager->mkNode(BAG_COUNT, skolem, bag);
   RewriteResponse response2 = d_rewriter->postRewrite(n2);
@@ -197,7 +197,7 @@ TEST_F(TestTheoryWhiteBagsRewriter, bag_count)
   Node geq = d_nodeManager->mkNode(GEQ, three, one);
   Node ite = d_nodeManager->mkNode(ITE, geq, three, zero);
   ASSERT_TRUE(response2.d_status == REWRITE_AGAIN_FULL
-              && response2.d_node == ite);
+              && response2.d_node == three);
 }
 
 TEST_F(TestTheoryWhiteBagsRewriter, duplicate_removal)