Add bags inference generator (#5731)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Fri, 8 Jan 2021 16:07:50 +0000 (10:07 -0600)
committerGitHub <noreply@github.com>
Fri, 8 Jan 2021 16:07:50 +0000 (10:07 -0600)
This PR adds inference generator for basic bag rules.

29 files changed:
src/CMakeLists.txt
src/api/cvc4cpp.cpp
src/expr/type_node.cpp
src/parser/smt2/smt2.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/bags/bag_solver.cpp [new file with mode: 0644]
src/theory/bags/bag_solver.h [new file with mode: 0644]
src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
src/theory/bags/infer_info.cpp [new file with mode: 0644]
src/theory/bags/infer_info.h [new file with mode: 0644]
src/theory/bags/inference_generator.cpp [new file with mode: 0644]
src/theory/bags/inference_generator.h [new file with mode: 0644]
src/theory/bags/inference_manager.cpp
src/theory/bags/inference_manager.h
src/theory/bags/normal_form.cpp
src/theory/bags/normal_form.h
src/theory/bags/solver_state.cpp
src/theory/bags/solver_state.h
src/theory/bags/theory_bags.cpp
src/theory/bags/theory_bags.h
src/theory/strings/base_solver.h
test/regress/CMakeLists.txt
test/regress/regress1/bags/disequality.smt2 [new file with mode: 0644]
test/regress/regress1/bags/subbag1.smt2 [new file with mode: 0644]
test/regress/regress1/bags/subbag2.smt2 [new file with mode: 0644]
test/regress/regress1/bags/union_disjoint.smt2 [new file with mode: 0644]
test/regress/regress1/bags/union_max1.smt2 [new file with mode: 0644]
test/regress/regress1/bags/union_max2.smt2 [new file with mode: 0644]

index f01c948db40ab035c1dc25b0fc5e5d04449ee1af..7e294443c2fbd7150c33c3fd3862d0cdb8e25804 100644 (file)
@@ -456,8 +456,14 @@ libcvc4_add_sources(
   theory/atom_requests.h
   theory/bags/bags_rewriter.cpp
   theory/bags/bags_rewriter.h
+  theory/bags/bag_solver.cpp
+  theory/bags/bag_solver.h
   theory/bags/bags_statistics.cpp
   theory/bags/bags_statistics.h
+  theory/bags/infer_info.cpp
+  theory/bags/infer_info.h
+  theory/bags/inference_generator.cpp
+  theory/bags/inference_generator.h
   theory/bags/inference_manager.cpp
   theory/bags/inference_manager.h
   theory/bags/make_bag_op.cpp
index b11a01628626bf61337b6621624256670cf4efa6..49974d30d54a504338fed3286a65ebcf39ed155a 100644 (file)
@@ -3418,6 +3418,20 @@ Term Solver::mkTermHelper(Kind kind, const std::vector<Term>& children) const
       Node singleton = getNodeManager()->mkSingleton(type, *children[0].d_node);
       res = Term(this, singleton).getExpr();
     }
+    else if (kind == api::MK_BAG)
+    {
+      // the type of the term is the same as the type of the internal node
+      // see Term::getSort()
+      TypeNode type = children[0].d_node->getType();
+      // Internally NodeManager::mkBag needs a type argument
+      // to construct a bag, since there is no difference between
+      // integers and reals (both are Rationals).
+      // At the API, mkReal and mkInteger are different and therefore the
+      // element type can be used safely here.
+      Node bag = getNodeManager()->mkBag(
+          type, *children[0].d_node, *children[1].d_node);
+      res = Term(this, bag).getExpr();
+    }
     else
     {
       res = d_exprMgr->mkExpr(k, echildren);
index 49fbe73eff4688016f5e13ad6280bba0baa426e9..3c038546a4ce964b8c3f677fd6f490ca3399b45a 100644 (file)
@@ -151,12 +151,12 @@ bool TypeNode::isFiniteInternal(bool usortFinite)
       TypeNode tnc = getArrayConstituentType();
       if (!tnc.isFiniteInternal(usortFinite))
       {
-        // arrays with consistuent type that is infinite are infinite
+        // arrays with constituent type that is infinite are infinite
         ret = false;
       }
       else if (getArrayIndexType().isFiniteInternal(usortFinite))
       {
-        // arrays with both finite consistuent and index types are finite
+        // arrays with both finite constituent and index types are finite
         ret = true;
       }
       else
@@ -170,6 +170,11 @@ bool TypeNode::isFiniteInternal(bool usortFinite)
     {
       ret = getSetElementType().isFiniteInternal(usortFinite);
     }
+    else if (isBag())
+    {
+      // there are infinite bags for all element types
+      ret = false;
+    }
     else if (isFunction())
     {
       ret = true;
index f069a486f5755ec80dc375b5c2a0c0f6bdaaddb5..36703fd6dbb5c6e0894e863a320c9417ece87b95 100644 (file)
@@ -642,7 +642,7 @@ Command* Smt2::setLogic(std::string name, bool fromCommand)
     addOperator(api::INTERSECTION_MIN, "intersection_min");
     addOperator(api::DIFFERENCE_SUBTRACT, "difference_subtract");
     addOperator(api::DIFFERENCE_REMOVE, "difference_remove");
-    addOperator(api::SUBBAG, "bag.is_included");
+    addOperator(api::SUBBAG, "subbag");
     addOperator(api::BAG_COUNT, "bag.count");
     addOperator(api::DUPLICATE_REMOVAL, "duplicate_removal");
     addOperator(api::MK_BAG, "bag");
index 50bb79a9a4526c8d9f807cc430e61efbd21e3d9e..ccf9c41643f9c19a3a431b3e65768415f221472a 100644 (file)
@@ -813,7 +813,30 @@ void Smt2Printer::toStream(std::ostream& out,
   case kind::UNIVERSE_SET:out << "(as univset " << n.getType() << ")";break;
 
   // bags
-  case kind::BAG_TYPE:  out << smtKindString(k, d_variant) << " "; break;
+  case kind::BAG_TYPE:
+  case kind::UNION_MAX:
+  case kind::UNION_DISJOINT:
+  case kind::INTERSECTION_MIN:
+  case kind::DIFFERENCE_SUBTRACT:
+  case kind::DIFFERENCE_REMOVE:
+  case kind::SUBBAG:
+  case kind::BAG_COUNT:
+  case kind::DUPLICATE_REMOVAL:
+  case kind::BAG_CARD:
+  case kind::BAG_CHOOSE:
+  case kind::BAG_IS_SINGLETON:
+  case kind::BAG_FROM_SET:
+  case kind::BAG_TO_SET: out << smtKindString(k, d_variant) << " "; break;
+  case kind::MK_BAG:
+  {
+    // print (bag (mkBag_op Real) 1 3) as (bag 1.0 3)
+    out << smtKindString(k, d_variant) << " ";
+    TypeNode elemType = n.getType().getBagElementType();
+    toStreamCastToType(
+        out, n[0], toDepth < 0 ? toDepth : toDepth - 1, elemType);
+    out << " " << n[1] << ")";
+    return;
+  }
 
     // fp theory
   case kind::FLOATINGPOINT_FP:
@@ -1170,6 +1193,20 @@ static string smtKindString(Kind k, Variant v)
 
   // bag theory
   case kind::BAG_TYPE: return "Bag";
+  case kind::UNION_MAX: return "union_max";
+  case kind::UNION_DISJOINT: return "union_disjoint";
+  case kind::INTERSECTION_MIN: return "intersection_min";
+  case kind::DIFFERENCE_SUBTRACT: return "difference_subtract";
+  case kind::DIFFERENCE_REMOVE: return "difference_remove";
+  case kind::SUBBAG: return "subbag";
+  case kind::BAG_COUNT: return "bag.count";
+  case kind::DUPLICATE_REMOVAL: return "duplicate_removal";
+  case kind::MK_BAG: return "bag";
+  case kind::BAG_CARD: return "bag.card";
+  case kind::BAG_CHOOSE: return "bag.choose";
+  case kind::BAG_IS_SINGLETON: return "bag.is_singleton";
+  case kind::BAG_FROM_SET: return "bag.from_set";
+  case kind::BAG_TO_SET: return "bag.to_set";
 
     // fp theory
   case kind::FLOATINGPOINT_FP: return "fp";
diff --git a/src/theory/bags/bag_solver.cpp b/src/theory/bags/bag_solver.cpp
new file mode 100644 (file)
index 0000000..5621a7c
--- /dev/null
@@ -0,0 +1,109 @@
+/*********************                                                        */
+/*! \file bag_solver.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief solver for the theory of bags.
+ **
+ ** solver for the theory of bags.
+ **/
+
+#include "theory/bags/bag_solver.h"
+
+#include "theory/bags/inference_generator.h"
+
+using namespace std;
+using namespace CVC4::context;
+using namespace CVC4::kind;
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+BagSolver::BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr)
+    : d_state(s), d_im(im), d_termReg(tr)
+{
+  d_zero = NodeManager::currentNM()->mkConst(Rational(0));
+  d_one = NodeManager::currentNM()->mkConst(Rational(1));
+  d_true = NodeManager::currentNM()->mkConst(true);
+  d_false = NodeManager::currentNM()->mkConst(false);
+}
+
+BagSolver::~BagSolver() {}
+
+void BagSolver::postCheck()
+{
+  for (const Node& n : d_state.getBags())
+  {
+    Kind k = n.getKind();
+    switch (k)
+    {
+      case kind::MK_BAG: checkMkBag(n); break;
+      case kind::UNION_DISJOINT: checkUnionDisjoint(n); break;
+      case kind::UNION_MAX: checkUnionMax(n); break;
+      case kind::DIFFERENCE_SUBTRACT: checkDifferenceSubtract(n); break;
+      default: break;
+    }
+  }
+}
+
+void BagSolver::checkUnionDisjoint(const Node& n)
+{
+  Assert(n.getKind() == UNION_DISJOINT);
+  TypeNode elementType = n.getType().getBagElementType();
+  for (const Node& e : d_state.getElements(elementType))
+  {
+    InferenceGenerator ig(&d_state);
+    InferInfo i = ig.unionDisjoint(n, e);
+    i.process(&d_im, true);
+    Trace("bags::BagSolver::postCheck") << i << endl;
+  }
+}
+
+void BagSolver::checkUnionMax(const Node& n)
+{
+  Assert(n.getKind() == UNION_MAX);
+  TypeNode elementType = n.getType().getBagElementType();
+  for (const Node& e : d_state.getElements(elementType))
+  {
+    InferenceGenerator ig(&d_state);
+    InferInfo i = ig.unionMax(n, e);
+    i.process(&d_im, true);
+    Trace("bags::BagSolver::postCheck") << i << endl;
+  }
+}
+
+void BagSolver::checkDifferenceSubtract(const Node& n)
+{
+  Assert(n.getKind() == DIFFERENCE_SUBTRACT);
+  TypeNode elementType = n.getType().getBagElementType();
+  for (const Node& e : d_state.getElements(elementType))
+  {
+    InferenceGenerator ig(&d_state);
+    InferInfo i = ig.differenceSubtract(n, e);
+    i.process(&d_im, true);
+    Trace("bags::BagSolver::postCheck") << i << endl;
+  }
+}
+void BagSolver::checkMkBag(const Node& n)
+{
+  Assert(n.getKind() == MK_BAG);
+  TypeNode elementType = n.getType().getBagElementType();
+  for (const Node& e : d_state.getElements(elementType))
+  {
+    InferenceGenerator ig(&d_state);
+    InferInfo i = ig.mkBag(n, e);
+    i.process(&d_im, true);
+    Trace("bags::BagSolver::postCheck") << i << endl;
+  }
+}
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/bags/bag_solver.h b/src/theory/bags/bag_solver.h
new file mode 100644 (file)
index 0000000..48583d1
--- /dev/null
@@ -0,0 +1,70 @@
+/*********************                                                        */
+/*! \file bag_solver.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief solver for the theory of bags.
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__THEORY__BAG__SOLVER_H
+#define CVC4__THEORY__BAG__SOLVER_H
+
+#include "context/cdhashset.h"
+#include "context/cdlist.h"
+#include "theory/bags/infer_info.h"
+#include "theory/bags/inference_manager.h"
+#include "theory/bags/normal_form.h"
+#include "theory/bags/solver_state.h"
+#include "theory/bags/term_registry.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+/** The solver for the theory of bags
+ *
+ */
+class BagSolver
+{
+ public:
+  BagSolver(SolverState& s, InferenceManager& im, TermRegistry& tr);
+  ~BagSolver();
+
+  void postCheck();
+
+ private:
+  /** apply inference rules for MK_BAG operator */
+  void checkMkBag(const Node& n);
+  /** apply inference rules for union disjoint */
+  void checkUnionDisjoint(const Node& n);
+  /** apply inference rules for union max */
+  void checkUnionMax(const Node& n);
+  /** apply inference rules for difference subtract */
+  void checkDifferenceSubtract(const Node& n);
+
+  /** The solver state object */
+  SolverState& d_state;
+  /** Reference to the inference manager for the theory of bags */
+  InferenceManager& d_im;
+  /** Reference to the term registry of theory of bags */
+  TermRegistry& d_termReg;
+  /** Commonly used constants */
+  Node d_true;
+  Node d_false;
+  Node d_zero;
+  Node d_one;
+}; /* class BagSolver */
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* CVC4__THEORY__BAG__SOLVER_H */
index aee57c74d7c6192dea471d20853ade665f0510f6..66886bfbff1208435a83c4bb14acf30916fc6846 100644 (file)
@@ -41,6 +41,8 @@ BagsRewriter::BagsRewriter(HistogramStat<Rewrite>* statistics)
     : d_statistics(statistics)
 {
   d_nm = NodeManager::currentNM();
+  d_zero = d_nm->mkConst(Rational(0));
+  d_one = d_nm->mkConst(Rational(1));
 }
 
 RewriteResponse BagsRewriter::postRewrite(TNode n)
@@ -51,7 +53,7 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
     // no need to rewrite n if it is already in a normal form
     response = BagsRewriteResponse(n, Rewrite::NONE);
   }
-  else if(n.getKind() == EQUAL)
+  else if (n.getKind() == EQUAL)
   {
     response = postRewriteEqual(n);
   }
@@ -162,12 +164,11 @@ BagsRewriteResponse BagsRewriter::rewriteBagCount(const TNode& n) const
   if (n[1].isConst() && n[1].getKind() == EMPTYBAG)
   {
     // (bag.count x emptybag) = 0
-    return BagsRewriteResponse(d_nm->mkConst(Rational(0)),
-                               Rewrite::COUNT_EMPTY);
+    return BagsRewriteResponse(d_zero, Rewrite::COUNT_EMPTY);
   }
   if (n[1].getKind() == MK_BAG && n[0] == n[1][0])
   {
-    // (bag.count x (mkBag x c) = c where c > 0 is a constant
+    // (bag.count x (mkBag x c) = c
     return BagsRewriteResponse(n[1][1], Rewrite::COUNT_MK_BAG);
   }
   return BagsRewriteResponse(n, Rewrite::NONE);
@@ -181,8 +182,7 @@ BagsRewriteResponse BagsRewriter::rewriteDuplicateRemoval(const TNode& n) const
   {
     // (duplicate_removal (mkBag x n)) = (mkBag x 1)
     //  where n is a positive constant
-    Node one = NodeManager::currentNM()->mkConst(Rational(1));
-    Node bag = d_nm->mkBag(n[0][0].getType(), n[0][0], one);
+    Node bag = d_nm->mkBag(n[0][0].getType(), n[0][0], d_one);
     return BagsRewriteResponse(bag, Rewrite::DUPLICATE_REMOVAL_MK_BAG);
   }
   return BagsRewriteResponse(n, Rewrite::NONE);
@@ -444,8 +444,7 @@ BagsRewriteResponse BagsRewriter::rewriteIsSingleton(const TNode& n) const
   if (n[0].getKind() == MK_BAG)
   {
     // (bag.is_singleton (mkBag x c)) = (c == 1)
-    Node one = d_nm->mkConst(Rational(1));
-    Node equal = n[0][1].eqNode(one);
+    Node equal = n[0][1].eqNode(d_one);
     return BagsRewriteResponse(equal, Rewrite::IS_SINGLETON_MK_BAG);
   }
   return BagsRewriteResponse(n, Rewrite::NONE);
@@ -457,9 +456,8 @@ BagsRewriteResponse BagsRewriter::rewriteFromSet(const TNode& n) const
   if (n[0].getKind() == SINGLETON)
   {
     // (bag.from_set (singleton (singleton_op Int) x)) = (mkBag x 1)
-    Node one = d_nm->mkConst(Rational(1));
     TypeNode type = n[0].getType().getSetElementType();
-    Node bag = d_nm->mkBag(type, n[0][0], one);
+    Node bag = d_nm->mkBag(type, n[0][0], d_one);
     return BagsRewriteResponse(bag, Rewrite::FROM_SINGLETON);
   }
   return BagsRewriteResponse(n, Rewrite::NONE);
@@ -484,20 +482,20 @@ BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const
   Assert(n.getKind() == kind::EQUAL);
   if (n[0] == n[1])
   {
-    Node ret = NodeManager::currentNM()->mkConst(true);
+    Node ret = d_nm->mkConst(true);
     return BagsRewriteResponse(ret, Rewrite::EQ_REFL);
   }
 
   if (n[0].isConst() && n[1].isConst())
   {
-    Node ret = NodeManager::currentNM()->mkConst(false);
+    Node ret = d_nm->mkConst(false);
     return BagsRewriteResponse(ret, Rewrite::EQ_CONST_FALSE);
   }
 
   // standard ordering
   if (n[0] > n[1])
   {
-    Node ret = NodeManager::currentNM()->mkNode(kind::EQUAL, n[1], n[0]);
+    Node ret = d_nm->mkNode(kind::EQUAL, n[1], n[0]);
     return BagsRewriteResponse(ret, Rewrite::EQ_SYM);
   }
   return BagsRewriteResponse(n, Rewrite::NONE);
index 8425e3b1fa1729c591388ea1f587ce341c1c03cc..fb76fb1c22a99557fdaac9beac69decd5941df71 100644 (file)
@@ -80,7 +80,7 @@ class BagsRewriter : public TheoryRewriter
   /**
    * rewrites for n include:
    * - (bag.count x emptybag) = 0
-   * - (bag.count x (mkBag x c) = c where c > 0 is a constant
+   * - (bag.count x (bag x c) = c
    * - otherwise = n
    */
   BagsRewriteResponse rewriteBagCount(const TNode& n) const;
@@ -213,6 +213,8 @@ class BagsRewriter : public TheoryRewriter
  private:
   /** Reference to the rewriter statistics. */
   NodeManager* d_nm;
+  Node d_zero;
+  Node d_one;
   /** Reference to the rewriter statistics. */
   HistogramStat<Rewrite>* d_statistics;
 }; /* class TheoryBagsRewriter */
diff --git a/src/theory/bags/infer_info.cpp b/src/theory/bags/infer_info.cpp
new file mode 100644 (file)
index 0000000..1244a43
--- /dev/null
@@ -0,0 +1,111 @@
+/*********************                                                        */
+/*! \file infer_info.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Implementation of inference information utility.
+ **/
+
+#include "theory/bags/infer_info.h"
+
+#include "theory/bags/inference_manager.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+const char* toString(Inference i)
+{
+  switch (i)
+  {
+    case Inference::NONE: return "NONE";
+    case Inference::BAG_MK_BAG: return "BAG_MK_BAG";
+    case Inference::BAG_EQUALITY: return "BAG_EQUALITY";
+    case Inference::BAG_DISEQUALITY: return "BAG_DISEQUALITY";
+    case Inference::BAG_EMPTY: return "BAG_EMPTY";
+    case Inference::BAG_UNION_DISJOINT: return "BAG_UNION_DISJOINT";
+    case Inference::BAG_UNION_MAX: return "BAG_UNION_MAX";
+    case Inference::BAG_INTERSECTION_MIN: return "BAG_INTERSECTION_MIN";
+    case Inference::BAG_DIFFERENCE_SUBTRACT: return "BAG_DIFFERENCE_SUBTRACT";
+    case Inference::BAG_DIFFERENCE_REMOVE: return "BAG_DIFFERENCE_REMOVE";
+    case Inference::BAG_DUPLICATE_REMOVAL: return "BAG_DUPLICATE_REMOVAL";
+    default: return "?";
+  }
+}
+
+std::ostream& operator<<(std::ostream& out, Inference i)
+{
+  out << toString(i);
+  return out;
+}
+
+InferInfo::InferInfo() : d_id(Inference::NONE) {}
+
+bool InferInfo::process(TheoryInferenceManager* im, bool asLemma)
+{
+  Node lemma = d_conclusion;
+  if (d_premises.size() >= 2)
+  {
+    Node andNode = NodeManager::currentNM()->mkNode(kind::AND, d_premises);
+    lemma = andNode.impNode(lemma);
+  }
+  else if (d_premises.size() == 1)
+  {
+    lemma = d_premises[0].impNode(lemma);
+  }
+  if (asLemma)
+  {
+    TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
+    return im->trustedLemma(trustedLemma);
+  }
+  Unimplemented();
+}
+
+bool InferInfo::isTrivial() const
+{
+  Assert(!d_conclusion.isNull());
+  return d_conclusion.isConst() && d_conclusion.getConst<bool>();
+}
+
+bool InferInfo::isConflict() const
+{
+  Assert(!d_conclusion.isNull());
+  return d_conclusion.isConst() && !d_conclusion.getConst<bool>();
+}
+
+bool InferInfo::isFact() const
+{
+  Assert(!d_conclusion.isNull());
+  TNode atom =
+      d_conclusion.getKind() == kind::NOT ? d_conclusion[0] : d_conclusion;
+  return !atom.isConst() && atom.getKind() != kind::OR;
+}
+
+Node InferInfo::getPremises() const
+{
+  // d_noExplain is a subset of d_ant
+  NodeManager* nm = NodeManager::currentNM();
+  return nm->mkAnd(d_premises);
+}
+
+std::ostream& operator<<(std::ostream& out, const InferInfo& ii)
+{
+  out << "(infer " << ii.d_id << " " << ii.d_conclusion << std::endl;
+  if (!ii.d_premises.empty())
+  {
+    out << " :premise (" << ii.d_premises << ")" << std::endl;
+  }
+
+  out << ")";
+  return out;
+}
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/bags/infer_info.h b/src/theory/bags/infer_info.h
new file mode 100644 (file)
index 0000000..3edbef7
--- /dev/null
@@ -0,0 +1,128 @@
+/*********************                                                        */
+/*! \file infer_info.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Inference information utility
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__THEORY__BAGS__INFER_INFO_H
+#define CVC4__THEORY__BAGS__INFER_INFO_H
+
+#include <map>
+#include <vector>
+
+#include "expr/node.h"
+#include "theory/theory_inference.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+/**
+ * Types of inferences used in the procedure
+ */
+enum class Inference : uint32_t
+{
+  NONE,
+  BAG_MK_BAG,
+  BAG_EQUALITY,
+  BAG_DISEQUALITY,
+  BAG_EMPTY,
+  BAG_UNION_DISJOINT,
+  BAG_UNION_MAX,
+  BAG_INTERSECTION_MIN,
+  BAG_DIFFERENCE_SUBTRACT,
+  BAG_DIFFERENCE_REMOVE,
+  BAG_DUPLICATE_REMOVAL
+};
+
+/**
+ * Converts an inference to a string. Note: This function is also used in
+ * `safe_print()`. Changing this functions name or signature will result in
+ * `safe_print()` printing "<unsupported>" instead of the proper strings for
+ * the enum values.
+ *
+ * @param i The inference
+ * @return The name of the inference
+ */
+const char* toString(Inference i);
+
+/**
+ * Writes an inference name to a stream.
+ *
+ * @param out The stream to write to
+ * @param i The inference to write to the stream
+ * @return The stream
+ */
+std::ostream& operator<<(std::ostream& out, Inference i);
+
+class InferenceManager;
+
+/**
+ * An inference. This is a class to track an unprocessed call to either
+ * send a fact, lemma, or conflict that is waiting to be asserted to the
+ * equality engine or sent on the output channel.
+ */
+class InferInfo : public TheoryInference
+{
+ public:
+  InferInfo();
+  ~InferInfo() {}
+  /** Process this inference */
+  bool process(TheoryInferenceManager* im, bool asLemma) override;
+  /** The inference identifier */
+  Inference d_id;
+  /** The conclusion */
+  Node d_conclusion;
+  /**
+   * The premise(s) of the inference, interpreted conjunctively. These are
+   * literals that currently hold in the equality engine.
+   */
+  std::vector<Node> d_premises;
+
+  /**
+   * A list of new skolems introduced as a result of this inference. They
+   * are mapped to by a length status, indicating the length constraint that
+   * can be assumed for them.
+   */
+  std::vector<Node> d_newSkolem;
+  /**  Is this infer info trivial? True if d_conc is true. */
+  bool isTrivial() const;
+  /**
+   * Does this infer info correspond to a conflict? True if d_conc is false
+   * and it has no new premises (d_noExplain).
+   */
+  bool isConflict() const;
+  /**
+   * Does this infer info correspond to a "fact". A fact is an inference whose
+   * conclusion should be added as an equality or predicate to the equality
+   * engine with no new external premises (d_noExplain).
+   */
+  bool isFact() const;
+  /** Get premises */
+  Node getPremises() const;
+};
+
+/**
+ * Writes an inference info to a stream.
+ *
+ * @param out The stream to write to
+ * @param ii The inference info to write to the stream
+ * @return The stream
+ */
+std::ostream& operator<<(std::ostream& out, const InferInfo& ii);
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* CVC4__THEORY__BAGS__INFER_INFO_H */
diff --git a/src/theory/bags/inference_generator.cpp b/src/theory/bags/inference_generator.cpp
new file mode 100644 (file)
index 0000000..759ea1f
--- /dev/null
@@ -0,0 +1,273 @@
+/*********************                                                        */
+/*! \file inference_generator.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Inference generator utility
+ **/
+
+#include "inference_generator.h"
+
+#include "expr/attribute.h"
+#include "expr/bound_var_manager.h"
+#include "expr/skolem_manager.h"
+#include "theory/uf/equality_engine.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+InferenceGenerator::InferenceGenerator(SolverState* state) : d_state(state)
+{
+  d_nm = NodeManager::currentNM();
+  d_sm = d_nm->getSkolemManager();
+  d_true = d_nm->mkConst(true);
+  d_zero = d_nm->mkConst(Rational(0));
+  d_one = d_nm->mkConst(Rational(1));
+}
+
+InferInfo InferenceGenerator::mkBag(Node n, Node e)
+{
+  Assert(n.getKind() == kind::MK_BAG);
+  Assert(e.getType() == n.getType().getBagElementType());
+
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_MK_BAG;
+  Node count = getMultiplicitySkolem(e, n, inferInfo);
+  if (n[0] == e)
+  {
+    // TODO: refactor this with the rewriter
+    // (=> true (= (bag.count e (bag e c)) c))
+    inferInfo.d_conclusion = count.eqNode(n[1]);
+  }
+  else
+  {
+    // (=>
+    //   true
+    //   (= (bag.count e (bag x c)) (ite (= e x) c 0)))
+
+    Node same = d_nm->mkNode(kind::EQUAL, n[0], e);
+    Node ite = d_nm->mkNode(kind::ITE, same, n[1], d_zero);
+    Node equal = count.eqNode(ite);
+    inferInfo.d_conclusion = equal;
+  }
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::bagEquality(Node n, Node e)
+{
+  Assert(n.getKind() == kind::EQUAL && n[0].getType().isBag());
+  Assert(e.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  Node B = n[1];
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_EQUALITY;
+  inferInfo.d_premises.push_back(n);
+  Node countA = getMultiplicitySkolem(e, A, inferInfo);
+  Node countB = getMultiplicitySkolem(e, B, inferInfo);
+
+  Node equal = countA.eqNode(countB);
+  inferInfo.d_conclusion = equal;
+  return inferInfo;
+}
+
+struct BagsDeqAttributeId
+{
+};
+typedef expr::Attribute<BagsDeqAttributeId, Node> BagsDeqAttribute;
+
+InferInfo InferenceGenerator::bagDisequality(Node n)
+{
+  Assert(n.getKind() == kind::NOT && n[0].getKind() == kind::EQUAL);
+  Assert(n[0][0].getType().isBag());
+
+  Node A = n[0][0];
+  Node B = n[0][1];
+
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_DISEQUALITY;
+
+  TypeNode elementType = A.getType().getBagElementType();
+
+  BoundVarManager* bvm = d_nm->getBoundVarManager();
+  Node element = bvm->mkBoundVar<BagsDeqAttribute>(n, elementType);
+  SkolemManager* sm = d_nm->getSkolemManager();
+  Node skolem =
+      sm->mkSkolem(element,
+                   n,
+                   "bag_disequal",
+                   "an extensional lemma for disequality of two bags");
+
+  inferInfo.d_newSkolem.push_back(skolem);
+
+  Node countA = getMultiplicitySkolem(skolem, A, inferInfo);
+  Node countB = getMultiplicitySkolem(skolem, B, inferInfo);
+
+  Node disEqual = countA.eqNode(countB).notNode();
+
+  inferInfo.d_premises.push_back(n);
+  inferInfo.d_conclusion = disEqual;
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::bagEmpty(Node e)
+{
+  EmptyBag emptyBag = EmptyBag(d_nm->mkBagType(e.getType()));
+  Node empty = d_nm->mkConst(emptyBag);
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_EMPTY;
+  Node count = getMultiplicitySkolem(e, empty, inferInfo);
+
+  Node equal = count.eqNode(d_zero);
+  inferInfo.d_conclusion = equal;
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::unionDisjoint(Node n, Node e)
+{
+  Assert(n.getKind() == kind::UNION_DISJOINT && n[0].getType().isBag());
+  Assert(e.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  Node B = n[1];
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_UNION_DISJOINT;
+
+  Node countA = getMultiplicitySkolem(e, A, inferInfo);
+  Node countB = getMultiplicitySkolem(e, B, inferInfo);
+  Node count = getMultiplicitySkolem(e, n, inferInfo);
+
+  Node sum = d_nm->mkNode(kind::PLUS, countA, countB);
+  Node equal = count.eqNode(sum);
+
+  inferInfo.d_conclusion = equal;
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::unionMax(Node n, Node e)
+{
+  Assert(n.getKind() == kind::UNION_MAX && n[0].getType().isBag());
+  Assert(e.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  Node B = n[1];
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_UNION_MAX;
+
+  Node countA = getMultiplicitySkolem(e, A, inferInfo);
+  Node countB = getMultiplicitySkolem(e, B, inferInfo);
+  Node count = getMultiplicitySkolem(e, n, inferInfo);
+
+  Node gt = d_nm->mkNode(kind::GT, countA, countB);
+  Node max = d_nm->mkNode(kind::ITE, gt, countA, countB);
+  Node equal = count.eqNode(max);
+
+  inferInfo.d_conclusion = equal;
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::intersection(Node n, Node e)
+{
+  Assert(n.getKind() == kind::INTERSECTION_MIN && n[0].getType().isBag());
+  Assert(e.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  Node B = n[1];
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_INTERSECTION_MIN;
+
+  Node countA = getMultiplicitySkolem(e, A, inferInfo);
+  Node countB = getMultiplicitySkolem(e, B, inferInfo);
+  Node count = getMultiplicitySkolem(e, n, inferInfo);
+
+  Node lt = d_nm->mkNode(kind::LT, countA, countB);
+  Node min = d_nm->mkNode(kind::ITE, lt, countA, countB);
+  Node equal = count.eqNode(min);
+  inferInfo.d_conclusion = equal;
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::differenceSubtract(Node n, Node e)
+{
+  Assert(n.getKind() == kind::DIFFERENCE_SUBTRACT && n[0].getType().isBag());
+  Assert(e.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  Node B = n[1];
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_DIFFERENCE_SUBTRACT;
+
+  Node countA = getMultiplicitySkolem(e, A, inferInfo);
+  Node countB = getMultiplicitySkolem(e, B, inferInfo);
+  Node count = getMultiplicitySkolem(e, n, inferInfo);
+
+  Node subtract = d_nm->mkNode(kind::MINUS, countA, countB);
+  Node gte = d_nm->mkNode(kind::GEQ, countA, countB);
+  Node difference = d_nm->mkNode(kind::ITE, gte, subtract, d_zero);
+  Node equal = count.eqNode(difference);
+  inferInfo.d_conclusion = equal;
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::differenceRemove(Node n, Node e)
+{
+  Assert(n.getKind() == kind::DIFFERENCE_REMOVE && n[0].getType().isBag());
+  Assert(e.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  Node B = n[1];
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_DIFFERENCE_REMOVE;
+
+  Node countA = getMultiplicitySkolem(e, A, inferInfo);
+  Node countB = getMultiplicitySkolem(e, B, inferInfo);
+  Node count = getMultiplicitySkolem(e, n, inferInfo);
+
+  Node notInB = d_nm->mkNode(kind::EQUAL, countB, d_zero);
+  Node difference = d_nm->mkNode(kind::ITE, notInB, countA, d_zero);
+  Node equal = count.eqNode(difference);
+  inferInfo.d_conclusion = equal;
+  return inferInfo;
+}
+
+InferInfo InferenceGenerator::duplicateRemoval(Node n, Node e)
+{
+  Assert(n.getKind() == kind::DUPLICATE_REMOVAL && n[0].getType().isBag());
+  Assert(e.getType() == n[0].getType().getBagElementType());
+
+  Node A = n[0];
+  InferInfo inferInfo;
+  inferInfo.d_id = Inference::BAG_DUPLICATE_REMOVAL;
+
+  Node countA = getMultiplicitySkolem(e, A, inferInfo);
+  Node count = getMultiplicitySkolem(e, n, inferInfo);
+
+  Node gte = d_nm->mkNode(kind::GEQ, countA, d_one);
+  Node ite = d_nm->mkNode(kind::ITE, gte, d_one, d_zero);
+  Node equal = count.eqNode(ite);
+  inferInfo.d_conclusion = equal;
+  return inferInfo;
+}
+
+Node InferenceGenerator::getMultiplicitySkolem(Node element,
+                                               Node bag,
+                                               InferInfo& inferInfo)
+{
+  Node count = d_nm->mkNode(kind::BAG_COUNT, element, bag);
+  Node skolem = d_state->registerBagElement(count);
+  eq::EqualityEngine* ee = d_state->getEqualityEngine();
+  ee->assertEquality(skolem.eqNode(count), true, d_nm->mkConst(true));
+  inferInfo.d_newSkolem.push_back(skolem);
+  return skolem;
+}
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/bags/inference_generator.h b/src/theory/bags/inference_generator.h
new file mode 100644 (file)
index 0000000..b569970
--- /dev/null
@@ -0,0 +1,172 @@
+/*********************                                                        */
+/*! \file inference_generator.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Mudathir Mohamed
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Inference generator utility
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef CVC4__THEORY__BAGS__INFERENCE_GENERATOR_H
+#define CVC4__THEORY__BAGS__INFERENCE_GENERATOR_H
+
+#include <map>
+#include <vector>
+
+#include "expr/node.h"
+#include "infer_info.h"
+#include "theory/bags/solver_state.h"
+
+namespace CVC4 {
+namespace theory {
+namespace bags {
+
+/**
+ * An inference generator class. This class is used by the core solver to
+ * generate lemmas
+ */
+class InferenceGenerator
+{
+ public:
+  InferenceGenerator(SolverState* state);
+
+  /**
+   * @param n is (bag x c) of type (Bag E)
+   * @param e is a node of type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   true
+   *   (= (bag.count e (bag x c)) (ite (= e x) c 0)))
+   */
+  InferInfo mkBag(Node n, Node e);
+
+  /**
+   * @param n is (= A B) where A, B are bags of type (Bag E)
+   * @param e is a node of Type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   (= A B)
+   *   (= (count e A) (count e B)))
+   */
+  InferInfo bagEquality(Node n, Node e);
+  /**
+   * @param n is (not (= A B)) where A, B are bags of type (Bag E)
+   * @return an inference that represents the following implication
+   * (=>
+   *   (not (= A B))
+   *   (not (= (count e A) (count e B))))
+   *   where e is a fresh skolem of type E
+   */
+  InferInfo bagDisequality(Node n);
+  /**
+   * @param e is a node of Type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   true
+   *   (= 0 (count e (as emptybag (Bag E)))))
+   */
+  InferInfo bagEmpty(Node e);
+  /**
+   * @param n is (union_disjoint A B) where A, B are bags of type (Bag E)
+   * @param e is a node of Type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   true
+   *   (= (count e k_{(union_disjoint A B)})
+   *      (+ (count e A) (count e B))))
+   *  where k_{(union_disjoint A B)} is a unique purification skolem
+   *  for (union_disjoint A B)
+   */
+  InferInfo unionDisjoint(Node n, Node e);
+  /**
+   * @param n is (union_disjoint A B) where A, B are bags of type (Bag E)
+   * @param e is a node of Type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   true
+   *   (= (count e (union_max A B))
+   *     (ite
+   *     (> (count e A) (count e B))
+   *     (count e A)
+   *     (count e B)))))
+   */
+  InferInfo unionMax(Node n, Node e);
+  /**
+   * @param n is (intersection_min A B) where A, B are bags of type (Bag E)
+   * @param e is a node of Type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   true
+   *   (= (count e (intersection_min A B))
+   *     (ite(
+   *     (< (count e A) (count e B))
+   *     (count e A)
+   *     (count e B)))))
+   */
+  InferInfo intersection(Node n, Node e);
+  /**
+   * @param n is (difference_subtract A B) where A, B are bags of type (Bag E)
+   * @param e is a node of Type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   true
+   *   (= (count e (difference_subtract A B))
+   *     (ite
+   *        (>= (count e A) (count e B))
+   *        (- (count e A) (count e B))
+   *        0))))
+   */
+  InferInfo differenceSubtract(Node n, Node e);
+  /**
+   * @param n is (difference_remove A B) where A, B are bags of type (Bag E)
+   * @param e is a node of Type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   true
+   *   (= (count e (difference_remove A B))
+   *     (ite
+   *        (= (count e B) 0)
+   *        (count e A)
+   *        0))))
+   */
+  InferInfo differenceRemove(Node n, Node e);
+  /**
+   * @param n is (duplicate_removal A) where A is a bag of type (Bag E)
+   * @param e is a node of Type E
+   * @return an inference that represents the following implication
+   * (=>
+   *   true
+   *   (= (count e (duplicate_removal A))
+   *     (ite (>= (count e A) 1) 1 0))))
+   */
+  InferInfo duplicateRemoval(Node n, Node e);
+
+  /**
+   * @param element of type T
+   * @param bag of type (bag T)
+   * @param inferInfo to store new skolem
+   * @return  a skolem for (bag.count element bag)
+   */
+  Node getMultiplicitySkolem(Node element, Node bag, InferInfo& inferInfo);
+
+ private:
+  NodeManager* d_nm;
+  SkolemManager* d_sm;
+  SolverState* d_state;
+  Node d_true;
+  Node d_zero;
+  Node d_one;
+};
+
+}  // namespace bags
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* CVC4__THEORY__BAGS__INFERENCE_GENERATOR_H */
index eb695725e9e334e762042ab004a350ea370e76ff..0ccd06922c30e9ddf0dd335d30c86c70fc42c594 100644 (file)
@@ -30,6 +30,20 @@ InferenceManager::InferenceManager(Theory& t,
   d_false = NodeManager::currentNM()->mkConst(false);
 }
 
+void InferenceManager::doPending()
+{
+  doPendingFacts();
+  if (d_state.isInConflict())
+  {
+    // just clear the pending vectors, nothing else to do
+    clearPendingLemmas();
+    clearPendingPhaseRequirements();
+    return;
+  }
+  doPendingLemmas();
+  doPendingPhaseRequirements();
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace CVC4
index 90d188d333de2206390fefea8a7ef1dc61fcc9af..67025548ceb4c8b387ff1a24411237b6a464c62a 100644 (file)
@@ -38,6 +38,16 @@ class InferenceManager : public InferenceManagerBuffered
  public:
   InferenceManager(Theory& t, SolverState& s, ProofNodeManager* pnm);
 
+  /**
+   * Do pending method. This processes all pending facts, lemmas and pending
+   * phase requests based on the policy of this manager. This means that
+   * we process the pending facts first and abort if in conflict. Otherwise, we
+   * process the pending lemmas and then the pending phase requirements.
+   * Notice that we process the pending lemmas even if there were facts.
+   */
+  // TODO: refactor this before merge with theory of strings
+  void doPending();
+
  private:
   /** constants */
   Node d_true;
index 88c41a96127251ce82c9017938849a06ca14eeec..2a5c091255c762dde41c8bebf7825484459b539d 100644 (file)
@@ -158,7 +158,7 @@ Node NormalForm::evaluateBinaryOperation(const TNode& n,
   remainderOfA(elements, elementsB, itB);
 
   Trace("bags-evaluate") << "elements: " << elements << std::endl;
-  Node bag = constructBagFromElements(n.getType(), elements);
+  Node bag = constructConstantBagFromElements(n.getType(), elements);
   Trace("bags-evaluate") << "bag: " << bag << std::endl;
   return bag;
 }
@@ -187,7 +187,7 @@ std::map<Node, Rational> NormalForm::getBagElements(TNode n)
   return elements;
 }
 
-Node NormalForm::constructBagFromElements(
+Node NormalForm::constructConstantBagFromElements(
     TypeNode t, const std::map<Node, Rational>& elements)
 {
   Assert(t.isBag());
@@ -209,6 +209,26 @@ Node NormalForm::constructBagFromElements(
   return bag;
 }
 
+Node NormalForm::constructBagFromElements(TypeNode t,
+                                          const std::map<Node, Node>& elements)
+{
+  Assert(t.isBag());
+  NodeManager* nm = NodeManager::currentNM();
+  if (elements.empty())
+  {
+    return nm->mkConst(EmptyBag(t));
+  }
+  TypeNode elementType = t.getBagElementType();
+  std::map<Node, Node>::const_reverse_iterator it = elements.rbegin();
+  Node bag = nm->mkBag(elementType, it->first, it->second);
+  while (++it != elements.rend())
+  {
+    Node n = nm->mkBag(elementType, it->first, it->second);
+    bag = nm->mkNode(UNION_DISJOINT, n, bag);
+  }
+  return bag;
+}
+
 Node NormalForm::evaluateMakeBag(TNode n)
 {
   // the case where n is const should be handled earlier.
@@ -262,7 +282,7 @@ Node NormalForm::evaluateDuplicateRemoval(TNode n)
   {
     it->second = one;
   }
-  Node bag = constructBagFromElements(n[0].getType(), newElements);
+  Node bag = constructConstantBagFromElements(n[0].getType(), newElements);
   return bag;
 }
 
@@ -624,7 +644,7 @@ Node NormalForm::evaluateFromSet(TNode n)
     bagElements[element] = one;
   }
   TypeNode bagType = nm->mkBagType(n[0].getType().getSetElementType());
-  Node bag = constructBagFromElements(bagType, bagElements);
+  Node bag = constructConstantBagFromElements(bagType, bagElements);
   return bag;
 }
 
index e88fe6fefce79932fa2b7b99130c041209fa8e4a..acb13bb44a5ea9c4bb4fdfdccc91f05907773ba3 100644 (file)
@@ -61,9 +61,19 @@ class NormalForm
    *        multiplicities
    * @return a constant bag that contains
    */
-  static Node constructBagFromElements(
+  static Node constructConstantBagFromElements(
       TypeNode t, const std::map<Node, Rational>& elements);
 
+  /**
+   * construct a constant bag from node elements
+   * @param t the type of the returned bag
+   * @param elements a map whose keys are constant elements and values are
+   *        multiplicities
+   * @return a constant bag that contains
+   */
+  static Node constructBagFromElements(TypeNode t,
+                                       const std::map<Node, Node>& elements);
+
  private:
   /**
    * a high order helper function that return a constant bag that is the result
index d58b686579ab5910c64e2557afe642f7d1043727..744f6de9fde04735afda8e8509d9bc71aa734426 100644 (file)
 
 #include "theory/bags/solver_state.h"
 
+#include "expr/attribute.h"
+#include "expr/bound_var_manager.h"
+#include "expr/skolem_manager.h"
+#include "theory/uf/equality_engine.h"
+
 using namespace std;
 using namespace CVC4::kind;
 
@@ -30,6 +35,51 @@ SolverState::SolverState(context::Context* c,
   d_false = NodeManager::currentNM()->mkConst(false);
 }
 
+struct BagsCountAttributeId
+{
+};
+typedef expr::Attribute<BagsCountAttributeId, Node> BagsCountAttribute;
+
+void SolverState::registerClass(TNode n)
+{
+  TypeNode t = n.getType();
+  if (!t.isBag())
+  {
+    return;
+  }
+  d_bags.insert(n);
+}
+
+Node SolverState::registerBagElement(TNode n)
+{
+  Assert(n.getKind() == BAG_COUNT);
+  Node element = n[0];
+  TypeNode elementType = element.getType();
+  Node bag = n[1];
+  d_elements[elementType].insert(element);
+  NodeManager* nm = NodeManager::currentNM();
+  BoundVarManager* bvm = nm->getBoundVarManager();
+  Node multiplicity = bvm->mkBoundVar<BagsCountAttribute>(n, nm->integerType());
+  Node equal = n.eqNode(multiplicity);
+  SkolemManager* sm = nm->getSkolemManager();
+  Node skolem = sm->mkSkolem(
+      multiplicity,
+      equal,
+      "bag_multiplicity",
+      "an extensional lemma for multiplicity of an element in a bag");
+  d_count[bag][element] = skolem;
+  Trace("bags::SolverState::registerBagElement")
+      << "New skolem: " << skolem << " for " << n << std::endl;
+
+  return skolem;
+}
+
+std::set<Node>& SolverState::getBags() { return d_bags; }
+
+std::set<Node>& SolverState::getElements(TypeNode t) { return d_elements[t]; }
+
+std::map<Node, Node>& SolverState::getBagElements(Node B) { return d_count[B]; }
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace CVC4
index 9a4bfdae7d5b92e1bf7b81c4e78166c8b6f545df..8d70ee8f73d1f21fe97b2f6937c4f129ccfe7f59 100644 (file)
@@ -31,10 +31,24 @@ class SolverState : public TheoryState
  public:
   SolverState(context::Context* c, context::UserContext* u, Valuation val);
 
+  void registerClass(TNode n);
+
+  Node registerBagElement(TNode n);
+
+  std::set<Node>& getBags();
+
+  std::set<Node>& getElements(TypeNode t);
+
+  std::map<Node, Node>& getBagElements(Node B);
+
  private:
   /** constants */
   Node d_true;
   Node d_false;
+  std::set<Node> d_bags;
+  std::map<TypeNode, std::set<Node>> d_elements;
+  /** bag -> element -> multiplicity */
+  std::map<Node, std::map<Node, Node>> d_count;
 }; /* class SolverState */
 
 }  // namespace bags
index 6ba1ad87a19c62c0e4385331c970b58f156aaea6..b3251e464f2ee71b472fc2df1ec88917b544c430 100644 (file)
@@ -14,6 +14,8 @@
 
 #include "theory/bags/theory_bags.h"
 
+#include "theory/theory_model.h"
+
 using namespace CVC4::kind;
 
 namespace CVC4 {
@@ -28,10 +30,13 @@ TheoryBags::TheoryBags(context::Context* c,
                        ProofNodeManager* pnm)
     : Theory(THEORY_BAGS, c, u, out, valuation, logicInfo, pnm),
       d_state(c, u, valuation),
-      d_im(*this, d_state, pnm),
+      d_im(*this, d_state, nullptr),
+      d_ig(&d_state),
       d_notify(*this, d_im),
       d_statistics(),
-      d_rewriter(&d_statistics.d_rewrites)
+      d_rewriter(&d_statistics.d_rewrites),
+      d_termReg(d_state, d_im),
+      d_solver(d_state, d_im, d_termReg)
 {
   // use the official theory state and inference manager objects
   d_theoryState = &d_state;
@@ -70,7 +75,60 @@ void TheoryBags::finishInit()
   d_equalityEngine->addFunctionKind(BAG_TO_SET);
 }
 
-void TheoryBags::postCheck(Effort level) {}
+void TheoryBags::postCheck(Effort effort)
+{
+  d_im.doPendingFacts();
+  // TODO: clean this before merge Assert(d_strat.isStrategyInit());
+  if (!d_state.isInConflict() && !d_valuation.needCheck())
+  // TODO: clean this before merge && d_strat.hasStrategyEffort(e))
+  {
+    Trace("bags::TheoryBags::postCheck") << "effort: " << std::endl;
+
+    // TODO: clean this before merge ++(d_statistics.d_checkRuns);
+    bool sentLemma = false;
+    bool hadPending = false;
+    Trace("bags-check") << "Full effort check..." << std::endl;
+    do
+    {
+      d_im.reset();
+      // TODO: clean this before merge ++(d_statistics.d_strategyRuns);
+      Trace("bags-check") << "  * Run strategy..." << std::endl;
+      // TODO: clean this before merge runStrategy(e);
+
+      d_solver.postCheck();
+
+      // remember if we had pending facts or lemmas
+      hadPending = d_im.hasPending();
+      // Send the facts *and* the lemmas. We send lemmas regardless of whether
+      // we send facts since some lemmas cannot be dropped. Other lemmas are
+      // otherwise avoided by aborting the strategy when a fact is ready.
+      d_im.doPending();
+      // Did we successfully send a lemma? Notice that if hasPending = true
+      // and sentLemma = false, then the above call may have:
+      // (1) had no pending lemmas, but successfully processed pending facts,
+      // (2) unsuccessfully processed pending lemmas.
+      // In either case, we repeat the strategy if we are not in conflict.
+      sentLemma = d_im.hasSentLemma();
+      if (Trace.isOn("bags-check"))
+      {
+        // TODO: clean this Trace("bags-check") << "  ...finish run strategy: ";
+        Trace("bags-check") << (hadPending ? "hadPending " : "");
+        Trace("bags-check") << (sentLemma ? "sentLemma " : "");
+        Trace("bags-check") << (d_state.isInConflict() ? "conflict " : "");
+        if (!hadPending && !sentLemma && !d_state.isInConflict())
+        {
+          Trace("bags-check") << "(none)";
+        }
+        Trace("bags-check") << std::endl;
+      }
+      // repeat if we did not add a lemma or conflict, and we had pending
+      // facts or lemmas.
+    } while (!d_state.isInConflict() && !sentLemma && hadPending);
+  }
+  Trace("bags-check") << "Theory of bags, done check : " << effort << std::endl;
+  Assert(!d_im.hasPendingFact());
+  Assert(!d_im.hasPendingLemma());
+}
 
 void TheoryBags::notifyFact(TNode atom,
                             bool polarity,
@@ -80,8 +138,43 @@ void TheoryBags::notifyFact(TNode atom,
 }
 
 bool TheoryBags::collectModelValues(TheoryModel* m,
-                                    const std::set<Node>& termBag)
+                                    const std::set<Node>& termSet)
 {
+  Trace("bags-model") << "TheoryBags : Collect model values" << std::endl;
+
+  Trace("bags-model") << "Term set: " << termSet << std::endl;
+
+  // get the relevant bag equivalence classes
+  for (const Node& n : termSet)
+  {
+    TypeNode tn = n.getType();
+    if (!tn.isBag())
+    {
+      continue;
+    }
+    Node r = d_state.getRepresentative(n);
+    std::map<Node, Node> elements = d_state.getBagElements(r);
+    Trace("bags-model") << "Elements of bag " << n << " are: " << std::endl
+                        << elements << std::endl;
+    std::map<Node, Node> elementReps;
+    for (std::pair<Node, Node> pair : elements)
+    {
+      Node key = d_state.getRepresentative(pair.first);
+      Node value = d_state.getRepresentative(pair.second);
+      elementReps[key] = value;
+    }
+    Node rep = NormalForm::constructBagFromElements(tn, elementReps);
+    rep = Rewriter::rewrite(rep);
+
+    Trace("bags-model") << "rep of " << n << " is: " << rep << std::endl;
+    for (std::pair<Node, Node> pair : elementReps)
+    {
+      m->assertSkeleton(pair.first);
+      m->assertSkeleton(pair.second);
+    }
+    m->assertEquality(rep, n, true);
+    m->assertSkeleton(rep);
+  }
   return true;
 }
 
@@ -101,41 +194,73 @@ void TheoryBags::presolve() {}
 
 /**************************** eq::NotifyClass *****************************/
 
-void TheoryBags::eqNotifyNewClass(TNode t)
+void TheoryBags::eqNotifyNewClass(TNode n)
 {
-  Assert(false) << "Not implemented yet" << std::endl;
+  Kind k = n.getKind();
+  d_state.registerClass(n);
+  if (n.getKind() == MK_BAG)
+  {
+    // TODO: refactor this before merge
+    /*
+     * (bag x m) generates the lemma (and (= s (count x (bag x m))) (= s m))
+     * where s is a fresh skolem variable
+     */
+    NodeManager* nm = NodeManager::currentNM();
+    Node count = nm->mkNode(BAG_COUNT, n[0], n);
+    Node skolem = d_state.registerBagElement(count);
+    Node countSkolem = count.eqNode(skolem);
+    Node skolemMultiplicity = n[1].eqNode(skolem);
+    Node lemma = countSkolem.andNode(skolemMultiplicity);
+    TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
+    d_im.trustedLemma(trustedLemma);
+  }
+  if (k == BAG_COUNT)
+  {
+    /*
+     * (count x A) generates the lemma (= s (count x A))
+     * where s is a fresh skolem variable
+     */
+    Node skolem = d_state.registerBagElement(n);
+    Node lemma = n.eqNode(skolem);
+    TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
+    d_im.trustedLemma(trustedLemma);
+  }
 }
 
-void TheoryBags::eqNotifyMerge(TNode t1, TNode t2)
-{
-  Assert(false) << "Not implemented yet" << std::endl;
-}
+void TheoryBags::eqNotifyMerge(TNode n1, TNode n2) {}
 
-void TheoryBags::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
+void TheoryBags::eqNotifyDisequal(TNode n1, TNode n2, TNode reason)
 {
-  Assert(false) << "Not implemented yet" << std::endl;
+  TypeNode t1 = n1.getType();
+  if (t1.isBag())
+  {
+    InferInfo info = d_ig.bagDisequality(n1.eqNode(n2).notNode());
+    Node lemma = reason.impNode(info.d_conclusion);
+    TrustNode trustedLemma = TrustNode::mkTrustLemma(lemma, nullptr);
+    d_im.trustedLemma(trustedLemma);
+  }
 }
 
-void TheoryBags::NotifyClass::eqNotifyNewClass(TNode t)
+void TheoryBags::NotifyClass::eqNotifyNewClass(TNode n)
 {
   Debug("bags-eq") << "[bags-eq] eqNotifyNewClass:"
-                   << " t = " << t << std::endl;
-  d_theory.eqNotifyNewClass(t);
+                   << " n = " << n << std::endl;
+  d_theory.eqNotifyNewClass(n);
 }
 
-void TheoryBags::NotifyClass::eqNotifyMerge(TNode t1, TNode t2)
+void TheoryBags::NotifyClass::eqNotifyMerge(TNode n1, TNode n2)
 {
   Debug("bags-eq") << "[bags-eq] eqNotifyMerge:"
-                   << " t1 = " << t1 << " t2 = " << t2 << std::endl;
-  d_theory.eqNotifyMerge(t1, t2);
+                   << " n1 = " << n1 << " n2 = " << n2 << std::endl;
+  d_theory.eqNotifyMerge(n1, n2);
 }
 
-void TheoryBags::NotifyClass::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
+void TheoryBags::NotifyClass::eqNotifyDisequal(TNode n1, TNode n2, TNode reason)
 {
   Debug("bags-eq") << "[bags-eq] eqNotifyDisequal:"
-                   << " t1 = " << t1 << " t2 = " << t2 << " reason = " << reason
+                   << " n1 = " << n1 << " n2 = " << n2 << " reason = " << reason
                    << std::endl;
-  d_theory.eqNotifyDisequal(t1, t2, reason);
+  d_theory.eqNotifyDisequal(n1, n2, reason);
 }
 
 }  // namespace bags
index 03676c2d2c367d34b6a4ba91591588f3df4442cb..09784add9414ce30a31b1d17341791445d20064d 100644 (file)
 
 #include <memory>
 
+#include "theory/bags/bag_solver.h"
 #include "theory/bags/bags_rewriter.h"
 #include "theory/bags/bags_statistics.h"
 #include "theory/bags/inference_manager.h"
+#include "theory/bags/inference_generator.h"
 #include "theory/bags/solver_state.h"
 #include "theory/theory.h"
 #include "theory/theory_eq_notify.h"
@@ -58,7 +60,7 @@ class TheoryBags : public Theory
 
   //--------------------------------- standard check
   /** Post-check, called after the fact queue of the theory is processed. */
-  void postCheck(Effort level) override;
+  void postCheck(Effort effort) override;
   /** Notify fact */
   void notifyFact(TNode atom, bool pol, TNode fact, bool isInternal) override;
   //--------------------------------- end standard check
@@ -82,9 +84,9 @@ class TheoryBags : public Theory
         : TheoryEqNotifyClass(inferenceManager), d_theory(theory)
     {
     }
-    void eqNotifyNewClass(TNode t) override;
-    void eqNotifyMerge(TNode t1, TNode t2) override;
-    void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override;
+    void eqNotifyNewClass(TNode n) override;
+    void eqNotifyMerge(TNode n1, TNode n2) override;
+    void eqNotifyDisequal(TNode n1, TNode n2, TNode reason) override;
 
    private:
     TheoryBags& d_theory;
@@ -94,15 +96,21 @@ class TheoryBags : public Theory
   SolverState d_state;
   /** The inference manager */
   InferenceManager d_im;
+  /** The inference generator */
+  InferenceGenerator d_ig;
   /** Instance of the above class */
   NotifyClass d_notify;
   /** Statistics for the theory of bags. */
   BagsStatistics d_statistics;
   /** The theory rewriter for this theory. */
   BagsRewriter d_rewriter;
+  /** The term registry for this theory */
+  TermRegistry d_termReg;
+  /** the main solver for bags */
+  BagSolver d_solver;
 
-  void eqNotifyNewClass(TNode t);
-  void eqNotifyMerge(TNode t1, TNode t2);
+  void eqNotifyNewClass(TNode n);
+  void eqNotifyMerge(TNode n1, TNode n2);
   void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
 }; /* class TheoryBags */
 
index 929034e42f2faf5a224bd2604f3a1e0282cc8c15..66ec35277164aff8d6f410ef5e09082a07adc03a 100644 (file)
@@ -225,7 +225,7 @@ class BaseSolver
    *
    * This set contains a set of nodes that are not representatives of their
    * congruence class. This set is used to skip reasoning about terms in
-   * various inference schemas implemnted by this class.
+   * various inference schemas implemented by this class.
    */
   NodeSet d_congruent;
   /**
index 72757ef32ff007bce5043bda3cdd213271f9faf9..7ceb3b4b2f82871e0ac946dc1487bcb8f3c5425b 100644 (file)
@@ -1415,6 +1415,12 @@ set(regress_1_tests
   regress1/bug681.smt2
   regress1/bug694-Unapply1.scala-0.smt2
   regress1/bug800.smt2
+  regress1/bags/disequality.smt2
+  regress1/bags/subbag1.smt2
+  regress1/bags/subbag2.smt2
+  regress1/bags/union_disjoint.smt2
+  regress1/bags/union_max1.smt2
+  regress1/bags/union_max2.smt2
   regress1/bv/bench_38.delta.smt2
   regress1/bv/bug787.smt2
   regress1/bv/bug_extract_mult_leading_bit.smt2
diff --git a/test/regress/regress1/bags/disequality.smt2 b/test/regress/regress1/bags/disequality.smt2
new file mode 100644 (file)
index 0000000..e49935f
--- /dev/null
@@ -0,0 +1,14 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun C () (Bag Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(assert (distinct A B C))
+(assert (> (bag.count x A) 0))
+(assert (> (bag.count y B) 0))
+(assert (= (bag.count x A) (bag.count x B)))
+(assert (= (bag.count y A) (bag.count y B)))
+(assert (distinct x y))
+(check-sat)
\ No newline at end of file
diff --git a/test/regress/regress1/bags/subbag1.smt2 b/test/regress/regress1/bags/subbag1.smt2
new file mode 100644 (file)
index 0000000..055e47a
--- /dev/null
@@ -0,0 +1,11 @@
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(assert (= x 1))
+(assert (subbag A B))
+(assert (subbag B A))
+(assert (= (bag.count x A) 5))
+(assert (= (bag.count x B) 10))
+(check-sat)
\ No newline at end of file
diff --git a/test/regress/regress1/bags/subbag2.smt2 b/test/regress/regress1/bags/subbag2.smt2
new file mode 100644 (file)
index 0000000..6d5cde3
--- /dev/null
@@ -0,0 +1,13 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(assert (subbag A B))
+(assert (subbag B A))
+(assert (= (bag.count x A) x))
+(assert (= (bag.count y A) x))
+(assert (distinct x y))
+(assert (= (bag.count x B) (+ 1 y)))
+(check-sat)
\ No newline at end of file
diff --git a/test/regress/regress1/bags/union_disjoint.smt2 b/test/regress/regress1/bags/union_disjoint.smt2
new file mode 100644 (file)
index 0000000..d30ed4b
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(assert (= A (union_disjoint (bag x 1) (bag y 2))))
+(assert (= A (union_disjoint B (bag y 2))))
+(assert (= x y))
+(check-sat)
diff --git a/test/regress/regress1/bags/union_max1.smt2 b/test/regress/regress1/bags/union_max1.smt2
new file mode 100644 (file)
index 0000000..d278527
--- /dev/null
@@ -0,0 +1,10 @@
+(set-logic ALL)
+(set-info :status sat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(assert (= A (union_max (bag x 1) (bag y 2))))
+(assert (= A (union_disjoint B (bag y 2))))
+(assert (= x y))
+(check-sat)
\ No newline at end of file
diff --git a/test/regress/regress1/bags/union_max2.smt2 b/test/regress/regress1/bags/union_max2.smt2
new file mode 100644 (file)
index 0000000..dd4bcef
--- /dev/null
@@ -0,0 +1,11 @@
+(set-logic ALL)
+(set-info :status unsat)
+(declare-fun A () (Bag Int))
+(declare-fun B () (Bag Int))
+(declare-fun x () Int)
+(declare-fun y () Int)
+(assert (= A (union_max (bag x 1) (bag y 2))))
+(assert (= A (union_disjoint B (bag y 2))))
+(assert (= x y))
+(assert (distinct (as emptybag (Bag Int)) B))
+(check-sat)
\ No newline at end of file