Prepare theory of sets for dynamic allocation of equality engine (#4868)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 12 Aug 2020 03:53:58 +0000 (22:53 -0500)
committerGitHub <noreply@github.com>
Wed, 12 Aug 2020 03:53:58 +0000 (22:53 -0500)
In forthcoming PRs, Theory objects will be assigned equality engine objects dynamically.

This PR prepares the theory of sets for this update, which involves refactoring of its internal members.

12 files changed:
src/theory/sets/cardinality_extension.cpp
src/theory/sets/cardinality_extension.h
src/theory/sets/inference_manager.cpp
src/theory/sets/inference_manager.h
src/theory/sets/solver_state.cpp
src/theory/sets/solver_state.h
src/theory/sets/theory_sets.cpp
src/theory/sets/theory_sets.h
src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_private.h
src/theory/sets/theory_sets_rels.cpp
src/theory/sets/theory_sets_rels.h

index dc3c773531572b4b302026be3c4c1c2e0af95337..1c12c71e42b1ef22f2705a2793c7bd5f8e38ba5c 100644 (file)
@@ -31,19 +31,15 @@ namespace sets {
 
 CardinalityExtension::CardinalityExtension(SolverState& s,
                                            InferenceManager& im,
-                                           eq::EqualityEngine& e,
                                            context::Context* c,
                                            context::UserContext* u)
     : d_state(s),
       d_im(im),
-      d_ee(e),
       d_card_processed(u),
       d_finite_type_constants_processed(false)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_zero = NodeManager::currentNM()->mkConst(Rational(0));
-  // we do congruence over cardinality
-  d_ee.addFunctionKind(CARD);
 }
 
 void CardinalityExtension::reset()
@@ -60,7 +56,7 @@ void CardinalityExtension::registerTerm(Node n)
   Assert(n.getKind() == CARD);
   TypeNode tnc = n[0].getType().getSetElementType();
   d_t_card_enabled[tnc] = true;
-  Node r = d_ee.getRepresentative(n[0]);
+  Node r = d_state.getRepresentative(n[0]);
   if (d_eqc_to_card_term.find(r) == d_eqc_to_card_term.end())
   {
     d_eqc_to_card_term[r] = n;
@@ -144,7 +140,7 @@ void CardinalityExtension::checkCardinalityExtended(TypeNode& t)
   for (Node& representative : representatives)
   {
     // the universe set is a subset of itself
-    if (representative != d_ee.getRepresentative(univ))
+    if (representative != d_state.getRepresentative(univ))
     {
       // here we only add representatives with variables to avoid adding
       // infinite equivalent generated terms to the cardinality graph
@@ -399,7 +395,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc,
       true_sib = 1;
     }
     Node u = Rewriter::rewrite(nm->mkNode(UNION, n[0], n[1]));
-    if (!d_ee.hasTerm(u))
+    if (!d_state.hasTerm(u))
     {
       u = Node::null();
     }
@@ -413,7 +409,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc,
       // parent equal siblings
       for (unsigned e = 0; e < true_sib; e++)
       {
-        if (d_ee.hasTerm(sib[e]) && !d_state.areEqual(n[e], sib[e]))
+        if (d_state.hasTerm(sib[e]) && !d_state.areEqual(n[e], sib[e]))
         {
           conc.push_back(n[e].eqNode(sib[e]));
         }
@@ -518,7 +514,7 @@ void CardinalityExtension::checkCardCyclesRec(Node eqc,
     for (unsigned k = 0, numcp = card_parents.size(); k < numcp; k++)
     {
       Node cpk = card_parents[k];
-      Node eqcc = d_ee.getRepresentative(cpk);
+      Node eqcc = d_state.getRepresentative(cpk);
       Trace("sets-debug") << "Check card parent " << k << "/"
                           << card_parents.size() << std::endl;
 
@@ -833,7 +829,7 @@ void CardinalityExtension::checkNormalForm(Node eqc,
               Trace("sets-nf") << "   Intro split : " << o0 << " against " << o1
                                << ", term is " << intro << std::endl;
               intro_sets.push_back(intro);
-              Assert(!d_ee.hasTerm(intro));
+              Assert(!d_state.hasTerm(intro));
               return;
             }
           }
index 5705301085f1a5bff92e43d036c6557398bc0b9d..b71af8a432ada1ce9bd2328aff0cdfd52d0f9b3d 100644 (file)
@@ -69,7 +69,6 @@ class CardinalityExtension
    */
   CardinalityExtension(SolverState& s,
                        InferenceManager& im,
-                       eq::EqualityEngine& e,
                        context::Context* c,
                        context::UserContext* u);
 
@@ -164,8 +163,6 @@ class CardinalityExtension
   SolverState& d_state;
   /** Reference to the inference manager for the theory of sets */
   InferenceManager& d_im;
-  /** Reference to the equality engine of theory of sets */
-  eq::EqualityEngine& d_ee;
   /** register cardinality term
    *
    * This method add lemmas corresponding to the definition of
index ac0a8205bd2a20f7db9c1fe938263d9c4ea277d9..f99dad91e652f615824d8ed10c4e8784cf64ed7c 100644 (file)
@@ -27,12 +27,10 @@ namespace sets {
 
 InferenceManager::InferenceManager(TheorySetsPrivate& p,
                                    SolverState& s,
-                                   eq::EqualityEngine& e,
                                    context::Context* c,
                                    context::UserContext* u)
     : d_parent(p),
       d_state(s),
-      d_ee(e),
       d_sentLemma(false),
       d_addedFact(false),
       d_lemmas_produced(u),
index 2a5a7703ca20ee7d023457ca8c4e759ed1b2e56b..ba6be9905707eb357d9f3e7ebe54fdae40b620b7 100644 (file)
@@ -40,7 +40,6 @@ class InferenceManager
  public:
   InferenceManager(TheorySetsPrivate& p,
                    SolverState& s,
-                   eq::EqualityEngine& e,
                    context::Context* c,
                    context::UserContext* u);
   /** reset
@@ -118,8 +117,6 @@ class InferenceManager
   TheorySetsPrivate& d_parent;
   /** Reference to the state object for the theory of sets */
   SolverState& d_state;
-  /** Reference to the equality engine of theory of sets */
-  eq::EqualityEngine& d_ee;
   /** pending lemmas */
   std::vector<Node> d_pendingLemmas;
   /** sent lemma
index d3c23454e2220f75551c6ed9678d869cab9853ba..f3371cf61e09fe1cff84ce2bacf3d3f44b2d2afe 100644 (file)
@@ -26,15 +26,20 @@ namespace theory {
 namespace sets {
 
 SolverState::SolverState(TheorySetsPrivate& p,
-                         eq::EqualityEngine& e,
                          context::Context* c,
                          context::UserContext* u)
-    : d_conflict(c), d_parent(p), d_ee(e), d_proxy(u), d_proxy_to_term(u)
+    : d_conflict(c), d_parent(p), d_ee(nullptr), d_proxy(u), d_proxy_to_term(u)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
 }
 
+void SolverState::finishInit(eq::EqualityEngine* ee)
+{
+  Assert(ee != nullptr);
+  d_ee = ee;
+}
+
 void SolverState::reset()
 {
   d_set_eqc.clear();
@@ -69,8 +74,8 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n)
   {
     if (r.isConst())
     {
-      Node s = d_ee.getRepresentative(n[1]);
-      Node x = d_ee.getRepresentative(n[0]);
+      Node s = d_ee->getRepresentative(n[1]);
+      Node x = d_ee->getRepresentative(n[0]);
       int pindex = r == d_true ? 0 : (r == d_false ? 1 : -1);
       if (pindex != -1)
       {
@@ -99,7 +104,7 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n)
     {
       // singleton lemma
       getProxy(n);
-      Node re = d_ee.getRepresentative(n[0]);
+      Node re = d_ee->getRepresentative(n[0]);
       if (d_singleton_index.find(re) == d_singleton_index.end())
       {
         d_singleton_index[re] = n;
@@ -122,8 +127,8 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n)
     }
     else
     {
-      Node r1 = d_ee.getRepresentative(n[0]);
-      Node r2 = d_ee.getRepresentative(n[1]);
+      Node r1 = d_ee->getRepresentative(n[0]);
+      Node r2 = d_ee->getRepresentative(n[1]);
       std::map<Node, Node>& binr1 = d_bop_index[nk][r1];
       std::map<Node, Node>::iterator itb = binr1.find(r2);
       if (itb == binr1.end())
@@ -164,15 +169,26 @@ void SolverState::registerTerm(Node r, TypeNode tnn, Node n)
   }
 }
 
+Node SolverState::getRepresentative(Node a) const
+{
+  if (d_ee->hasTerm(a))
+  {
+    return d_ee->getRepresentative(a);
+  }
+  return a;
+}
+
+bool SolverState::hasTerm(Node a) const { return d_ee->hasTerm(a); }
+
 bool SolverState::areEqual(Node a, Node b) const
 {
   if (a == b)
   {
     return true;
   }
-  if (d_ee.hasTerm(a) && d_ee.hasTerm(b))
+  if (d_ee->hasTerm(a) && d_ee->hasTerm(b))
   {
-    return d_ee.areEqual(a, b);
+    return d_ee->areEqual(a, b);
   }
   return false;
 }
@@ -183,13 +199,15 @@ bool SolverState::areDisequal(Node a, Node b) const
   {
     return false;
   }
-  else if (d_ee.hasTerm(a) && d_ee.hasTerm(b))
+  else if (d_ee->hasTerm(a) && d_ee->hasTerm(b))
   {
-    return d_ee.areDisequal(a, b, false);
+    return d_ee->areDisequal(a, b, false);
   }
   return a.isConst() && b.isConst();
 }
 
+eq::EqualityEngine* SolverState::getEqualityEngine() const { return d_ee; }
+
 void SolverState::setConflict() { d_conflict = true; }
 void SolverState::setConflict(Node conf)
 {
@@ -279,9 +297,9 @@ bool SolverState::isEntailed(Node n, bool polarity) const
       return true;
     }
     // check members cache
-    if (polarity && d_ee.hasTerm(n[1]))
+    if (polarity && d_ee->hasTerm(n[1]))
     {
-      Node r = d_ee.getRepresentative(n[1]);
+      Node r = d_ee->getRepresentative(n[1]);
       if (d_parent.isMember(n[0], r))
       {
         return true;
@@ -310,8 +328,8 @@ bool SolverState::isEntailed(Node n, bool polarity) const
 
 bool SolverState::isSetDisequalityEntailed(Node r1, Node r2) const
 {
-  Assert(d_ee.hasTerm(r1) && d_ee.getRepresentative(r1) == r1);
-  Assert(d_ee.hasTerm(r2) && d_ee.getRepresentative(r2) == r2);
+  Assert(d_ee->hasTerm(r1) && d_ee->getRepresentative(r1) == r1);
+  Assert(d_ee->hasTerm(r2) && d_ee->getRepresentative(r2) == r2);
   TypeNode tn = r1.getType();
   Node re = getEmptySetEqClass(tn);
   for (unsigned e = 0; e < 2; e++)
@@ -433,7 +451,7 @@ Node SolverState::getProxy(Node n)
 
 Node SolverState::getCongruent(Node n) const
 {
-  Assert(d_ee.hasTerm(n));
+  Assert(d_ee->hasTerm(n));
   std::map<Node, Node>::const_iterator it = d_congruent.find(n);
   if (it == d_congruent.end())
   {
index 0b301dbb7fe5008b8e0558daec7cd4c85afb0737..dce90c2d3b44b02c714761edf939441d825fdf5d 100644 (file)
@@ -48,17 +48,21 @@ class SolverState
 
  public:
   SolverState(TheorySetsPrivate& p,
-              eq::EqualityEngine& e,
               context::Context* c,
               context::UserContext* u);
-  //-------------------------------- initialize
+  /**
+   * Finish initialize, there ee is a pointer to the official equality engine
+   * of theory of strings.
+   */
+  void finishInit(eq::EqualityEngine* ee);
+  //-------------------------------- initialize per check
   /** reset, clears the data structures maintained by this class. */
   void reset();
   /** register equivalence class whose type is tn */
   void registerEqc(TypeNode tn, Node r);
   /** register term n of type tnn in the equivalence class of r */
   void registerTerm(Node r, TypeNode tnn, Node n);
-  //-------------------------------- end initialize
+  //-------------------------------- end initialize per check
   /** Are we currently in conflict? */
   bool isInConflict() const { return d_conflict; }
   /**
@@ -68,10 +72,19 @@ class SolverState
   void setConflict();
   /** Set conf is a conflict node to be sent on the output channel.  */
   void setConflict(Node conf);
+  /**
+   * Get the representative of a in the equality engine of this class, or a
+   * itself if it is not registered as a term.
+   */
+  Node getRepresentative(Node a) const;
+  /** Is a registered as a term in the equality engine of this class? */
+  bool hasTerm(Node a) const;
   /** Is a=b according to equality reasoning in the current context? */
   bool areEqual(Node a, Node b) const;
   /** Is a!=b according to equality reasoning in the current context? */
   bool areDisequal(Node a, Node b) const;
+  /** get equality engine */
+  eq::EqualityEngine* getEqualityEngine() const;
   /** add equality to explanation
    *
    * This adds a = b to exp if a and b are syntactically disequal. The equality
@@ -220,8 +233,8 @@ class SolverState
   context::CDO<bool> d_conflict;
   /** Reference to the parent theory of sets */
   TheorySetsPrivate& d_parent;
-  /** Reference to the equality engine of theory of sets */
-  eq::EqualityEngine& d_ee;
+  /** Pointer to the official equality engine of theory of sets */
+  eq::EqualityEngine* d_ee;
   /** The list of all equivalence classes of type set in the current context */
   std::vector<Node> d_set_eqc;
   /** Maps types to the equivalence class containing empty set of that type */
index 9c680cc64a28e18be5d7b1ea0f15bde37b754690..17caac4f7d7bd6f2be58f9794afdb44dffcadb83 100644 (file)
@@ -34,7 +34,9 @@ TheorySets::TheorySets(context::Context* c,
                        const LogicInfo& logicInfo,
                        ProofNodeManager* pnm)
     : Theory(THEORY_SETS, c, u, out, valuation, logicInfo, pnm),
-      d_internal(new TheorySetsPrivate(*this, c, u))
+      d_internal(new TheorySetsPrivate(*this, c, u)),
+      d_notify(*d_internal.get()),
+      d_equalityEngine(d_notify, c, "theory::sets::ee", true)
 {
   // Do not move me to the header.
   // The constructor + destructor are not in the header as d_internal is a
@@ -57,6 +59,27 @@ void TheorySets::finishInit()
   d_valuation.setUnevaluatedKind(COMPREHENSION);
   // choice is used to eliminate witness
   d_valuation.setUnevaluatedKind(WITNESS);
+
+  // functions we are doing congruence over
+  d_equalityEngine.addFunctionKind(kind::SINGLETON);
+  d_equalityEngine.addFunctionKind(kind::UNION);
+  d_equalityEngine.addFunctionKind(kind::INTERSECTION);
+  d_equalityEngine.addFunctionKind(kind::SETMINUS);
+  d_equalityEngine.addFunctionKind(kind::MEMBER);
+  d_equalityEngine.addFunctionKind(kind::SUBSET);
+  // relation operators
+  d_equalityEngine.addFunctionKind(PRODUCT);
+  d_equalityEngine.addFunctionKind(JOIN);
+  d_equalityEngine.addFunctionKind(TRANSPOSE);
+  d_equalityEngine.addFunctionKind(TCLOSURE);
+  d_equalityEngine.addFunctionKind(JOIN_IMAGE);
+  d_equalityEngine.addFunctionKind(IDEN);
+  d_equalityEngine.addFunctionKind(APPLY_CONSTRUCTOR);
+  // we do congruence over cardinality
+  d_equalityEngine.addFunctionKind(CARD);
+
+  // finish initialization internally
+  d_internal->finishInit();
 }
 
 void TheorySets::addSharedTerm(TNode n) {
@@ -171,14 +194,96 @@ void TheorySets::propagate(Effort e) {
   d_internal->propagate(e);
 }
 
-void TheorySets::setMasterEqualityEngine(eq::EqualityEngine* eq) {
-  d_internal->setMasterEqualityEngine(eq);
-}
-
 bool TheorySets::isEntailed( Node n, bool pol ) {
   return d_internal->isEntailed( n, pol );
 }
 
+eq::EqualityEngine* TheorySets::getEqualityEngine() 
+{
+  return &d_equalityEngine;
+}
+
+/**************************** eq::NotifyClass *****************************/
+
+bool TheorySets::NotifyClass::eqNotifyTriggerEquality(TNode equality,
+                                                      bool value)
+{
+  Debug("sets-eq") << "[sets-eq] eqNotifyTriggerEquality: equality = "
+                   << equality << " value = " << value << std::endl;
+  if (value)
+  {
+    return d_theory.propagate(equality);
+  }
+  else
+  {
+    // We use only literal triggers so taking not is safe
+    return d_theory.propagate(equality.notNode());
+  }
+}
+
+bool TheorySets::NotifyClass::eqNotifyTriggerPredicate(TNode predicate,
+                                                       bool value)
+{
+  Debug("sets-eq") << "[sets-eq] eqNotifyTriggerPredicate: predicate = "
+                   << predicate << " value = " << value << std::endl;
+  if (value)
+  {
+    return d_theory.propagate(predicate);
+  }
+  else
+  {
+    return d_theory.propagate(predicate.notNode());
+  }
+}
+
+bool TheorySets::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag,
+                                                          TNode t1,
+                                                          TNode t2,
+                                                          bool value)
+{
+  Debug("sets-eq") << "[sets-eq] eqNotifyTriggerTermEquality: tag = " << tag
+                   << " t1 = " << t1 << "  t2 = " << t2 << "  value = " << value
+                   << std::endl;
+  d_theory.propagate(value ? t1.eqNode(t2) : t1.eqNode(t2).negate());
+  return true;
+}
+
+void TheorySets::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2)
+{
+  Debug("sets-eq") << "[sets-eq] eqNotifyConstantTermMerge "
+                   << " t1 = " << t1 << " t2 = " << t2 << std::endl;
+  d_theory.conflict(t1, t2);
+}
+
+void TheorySets::NotifyClass::eqNotifyNewClass(TNode t)
+{
+  Debug("sets-eq") << "[sets-eq] eqNotifyNewClass:"
+                   << " t = " << t << std::endl;
+  d_theory.eqNotifyNewClass(t);
+}
+
+void TheorySets::NotifyClass::eqNotifyPreMerge(TNode t1, TNode t2)
+{
+  Debug("sets-eq") << "[sets-eq] eqNotifyPreMerge:"
+                   << " t1 = " << t1 << " t2 = " << t2 << std::endl;
+  d_theory.eqNotifyPreMerge(t1, t2);
+}
+
+void TheorySets::NotifyClass::eqNotifyPostMerge(TNode t1, TNode t2)
+{
+  Debug("sets-eq") << "[sets-eq] eqNotifyPostMerge:"
+                   << " t1 = " << t1 << " t2 = " << t2 << std::endl;
+  d_theory.eqNotifyPostMerge(t1, t2);
+}
+
+void TheorySets::NotifyClass::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
+{
+  Debug("sets-eq") << "[sets-eq] eqNotifyDisequal:"
+                   << " t1 = " << t1 << " t2 = " << t2 << " reason = " << reason
+                   << std::endl;
+  d_theory.eqNotifyDisequal(t1, t2, reason);
+}
+
 }/* CVC4::theory::sets namespace */
 }/* CVC4::theory namespace */
 }/* CVC4 namespace */
index e81412ba96ad9f978925ec5266cda215114aefbc..f1b59e419577aed93341329db7de9418fd7d87de 100644 (file)
@@ -33,6 +33,8 @@ class TheorySetsScrutinize;
 
 class TheorySets : public Theory
 {
+  friend class TheorySetsPrivate;
+  friend class TheorySetsRels;
  public:
   /** Constructs a new instance of TheorySets w.r.t. the provided contexts. */
   TheorySets(context::Context* c,
@@ -43,10 +45,13 @@ class TheorySets : public Theory
              ProofNodeManager* pnm);
   ~TheorySets() override;
 
+  //--------------------------------- initialization
+  /** get the official theory rewriter of this theory */
   TheoryRewriter* getTheoryRewriter() override;
-
   /** finish initialization */
   void finishInit() override;
+  //--------------------------------- end initialization
+
   void addSharedTerm(TNode) override;
   void check(Effort) override;
   bool collectModelInfo(TheoryModel* m) override;
@@ -60,15 +65,36 @@ class TheorySets : public Theory
   PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override;
   void presolve() override;
   void propagate(Effort) override;
-  void setMasterEqualityEngine(eq::EqualityEngine* eq) override;
   bool isEntailed(Node n, bool pol);
-
+  /* equality engine */
+  virtual eq::EqualityEngine* getEqualityEngine() override;
  private:
-  friend class TheorySetsPrivate;
-  friend class TheorySetsScrutinize;
-  friend class TheorySetsRels;
-
+  /** Functions to handle callbacks from equality engine */
+  class NotifyClass : public eq::EqualityEngineNotify
+  {
+   public:
+    NotifyClass(TheorySetsPrivate& theory) : d_theory(theory) {}
+    bool eqNotifyTriggerEquality(TNode equality, bool value) override;
+    bool eqNotifyTriggerPredicate(TNode predicate, bool value) override;
+    bool eqNotifyTriggerTermEquality(TheoryId tag,
+                                     TNode t1,
+                                     TNode t2,
+                                     bool value) override;
+    void eqNotifyConstantTermMerge(TNode t1, TNode t2) override;
+    void eqNotifyNewClass(TNode t) override;
+    void eqNotifyPreMerge(TNode t1, TNode t2) override;
+    void eqNotifyPostMerge(TNode t1, TNode t2) override;
+    void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override;
+    
+   private:
+    TheorySetsPrivate& d_theory;
+  };
+  /** The internal theory */
   std::unique_ptr<TheorySetsPrivate> d_internal;
+  /** Instance of the above class */
+  NotifyClass d_notify;  
+  /** Equality engine */
+  eq::EqualityEngine d_equalityEngine;
 }; /* class TheorySets */
 
 }/* CVC4::theory::sets namespace */
index 4c3affe99781be6369d0f8681df867a7a1aeb87f..25ee3167e3a4703989888464df10613819792622 100644 (file)
@@ -43,27 +43,16 @@ TheorySetsPrivate::TheorySetsPrivate(TheorySets& external,
       d_keep(c),
       d_full_check_incomplete(false),
       d_external(external),
-      d_notify(*this),
-      d_equalityEngine(d_notify, c, "theory::sets::ee", true),
-      d_state(*this, d_equalityEngine, c, u),
-      d_im(*this, d_state, d_equalityEngine, c, u),
-      d_rels(new TheorySetsRels(d_state, d_im, d_equalityEngine, u)),
-      d_cardSolver(
-          new CardinalityExtension(d_state, d_im, d_equalityEngine, c, u)),
+      d_state(*this, c, u),
+      d_im(*this, d_state, c, u),
+      d_rels(new TheorySetsRels(d_state, d_im, u)),
+      d_cardSolver(new CardinalityExtension(d_state, d_im, c, u)),
       d_rels_enabled(false),
       d_card_enabled(false)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
   d_zero = NodeManager::currentNM()->mkConst(Rational(0));
-
-  d_equalityEngine.addFunctionKind(kind::SINGLETON);
-  d_equalityEngine.addFunctionKind(kind::UNION);
-  d_equalityEngine.addFunctionKind(kind::INTERSECTION);
-  d_equalityEngine.addFunctionKind(kind::SETMINUS);
-
-  d_equalityEngine.addFunctionKind(kind::MEMBER);
-  d_equalityEngine.addFunctionKind(kind::SUBSET);
 }
 
 TheorySetsPrivate::~TheorySetsPrivate()
@@ -74,6 +63,13 @@ TheorySetsPrivate::~TheorySetsPrivate()
   }
 }
 
+void TheorySetsPrivate::finishInit()
+{
+  d_equalityEngine = d_external.getEqualityEngine();
+  Assert(d_equalityEngine != nullptr);
+  d_state.finishInit(d_equalityEngine);
+}
+
 void TheorySetsPrivate::eqNotifyNewClass(TNode t)
 {
   if (t.getKind() == kind::SINGLETON || t.getKind() == kind::EMPTYSET)
@@ -240,13 +236,13 @@ TheorySetsPrivate::EqcInfo* TheorySetsPrivate::getOrMakeEqcInfo(TNode n,
 
 bool TheorySetsPrivate::areCareDisequal(Node a, Node b)
 {
-  if (d_equalityEngine.isTriggerTerm(a, THEORY_SETS)
-      && d_equalityEngine.isTriggerTerm(b, THEORY_SETS))
+  if (d_equalityEngine->isTriggerTerm(a, THEORY_SETS)
+      && d_equalityEngine->isTriggerTerm(b, THEORY_SETS))
   {
     TNode a_shared =
-        d_equalityEngine.getTriggerTermRepresentative(a, THEORY_SETS);
+        d_equalityEngine->getTriggerTermRepresentative(a, THEORY_SETS);
     TNode b_shared =
-        d_equalityEngine.getTriggerTermRepresentative(b, THEORY_SETS);
+        d_equalityEngine->getTriggerTermRepresentative(b, THEORY_SETS);
     EqualityStatus eqStatus =
         d_external.d_valuation.getEqualityStatus(a_shared, b_shared);
     if (eqStatus == EQUALITY_FALSE_AND_PROPAGATED || eqStatus == EQUALITY_FALSE
@@ -260,8 +256,8 @@ bool TheorySetsPrivate::areCareDisequal(Node a, Node b)
 
 bool TheorySetsPrivate::isMember(Node x, Node s)
 {
-  Assert(d_equalityEngine.hasTerm(s)
-         && d_equalityEngine.getRepresentative(s) == s);
+  Assert(d_equalityEngine->hasTerm(s)
+         && d_equalityEngine->getRepresentative(s) == s);
   NodeIntMap::iterator mem_i = d_members.find(s);
   if (mem_i != d_members.end())
   {
@@ -286,18 +282,18 @@ bool TheorySetsPrivate::assertFact(Node fact, Node exp)
   {
     if (atom.getKind() == kind::EQUAL)
     {
-      d_equalityEngine.assertEquality(atom, polarity, exp);
+      d_equalityEngine->assertEquality(atom, polarity, exp);
     }
     else
     {
-      d_equalityEngine.assertPredicate(atom, polarity, exp);
+      d_equalityEngine->assertPredicate(atom, polarity, exp);
     }
     if (!d_state.isInConflict())
     {
       if (atom.getKind() == kind::MEMBER && polarity)
       {
         // check if set has a value, if so, we can propagate
-        Node r = d_equalityEngine.getRepresentative(atom[1]);
+        Node r = d_equalityEngine->getRepresentative(atom[1]);
         EqcInfo* e = getOrMakeEqcInfo(r, true);
         if (e)
         {
@@ -354,7 +350,7 @@ bool TheorySetsPrivate::assertFact(Node fact, Node exp)
 
 void TheorySetsPrivate::fullEffortReset()
 {
-  Assert(d_equalityEngine.consistent());
+  Assert(d_equalityEngine->consistent());
   d_full_check_incomplete = false;
   d_most_common_type.clear();
   d_most_common_type_term.clear();
@@ -380,7 +376,7 @@ void TheorySetsPrivate::fullEffortCheck()
 
     Trace("sets-eqc") << "Equality Engine:" << std::endl;
     std::map<TypeNode, unsigned> eqcTypeCount;
-    eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(&d_equalityEngine);
+    eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine);
     while (!eqcs_i.isFinished())
     {
       Node eqc = (*eqcs_i);
@@ -398,7 +394,7 @@ void TheorySetsPrivate::fullEffortCheck()
         tnct = eqc;
       }
       Trace("sets-eqc") << "[" << eqc << "] : ";
-      eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, &d_equalityEngine);
+      eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine);
       while (!eqc_i.isFinished())
       {
         Node n = (*eqc_i);
@@ -624,7 +620,7 @@ void TheorySetsPrivate::checkDownwardsClosure()
           {
             Node mem = it2.second;
             Node eq_set = nv;
-            Assert(d_equalityEngine.areEqual(mem[1], eq_set));
+            Assert(d_equalityEngine->areEqual(mem[1], eq_set));
             if (mem[1] != eq_set)
             {
               Trace("sets-debug") << "Downwards closure based on " << mem
@@ -761,7 +757,7 @@ void TheorySetsPrivate::checkUpwardsClosure()
                 }
                 if (valid)
                 {
-                  Node rr = d_equalityEngine.getRepresentative(term);
+                  Node rr = d_equalityEngine->getRepresentative(term);
                   if (!isMember(x, rr))
                   {
                     Node kk = d_state.getProxy(term);
@@ -785,7 +781,7 @@ void TheorySetsPrivate::checkUpwardsClosure()
                 for (const std::pair<const Node, Node>& itm2m : r2mem)
                 {
                   Node x = itm2m.second[0];
-                  Node rr = d_equalityEngine.getRepresentative(term);
+                  Node rr = d_equalityEngine->getRepresentative(term);
                   if (!isMember(x, rr))
                   {
                     std::vector<Node> exp;
@@ -882,10 +878,10 @@ void TheorySetsPrivate::checkDisequalities()
     }
     Node deq = (*it).first;
     // check if it is already satisfied
-    Assert(d_equalityEngine.hasTerm(deq[0])
-           && d_equalityEngine.hasTerm(deq[1]));
-    Node r1 = d_equalityEngine.getRepresentative(deq[0]);
-    Node r2 = d_equalityEngine.getRepresentative(deq[1]);
+    Assert(d_equalityEngine->hasTerm(deq[0])
+           && d_equalityEngine->hasTerm(deq[1]));
+    Node r1 = d_equalityEngine->getRepresentative(deq[0]);
+    Node r2 = d_equalityEngine->getRepresentative(deq[1]);
     bool is_sat = d_state.isSetDisequalityEntailed(r1, r2);
     Trace("sets-debug") << "Check disequality " << deq
                         << ", is_sat = " << is_sat << std::endl;
@@ -1005,7 +1001,7 @@ void TheorySetsPrivate::addSharedTerm(TNode n)
 {
   Debug("sets") << "[sets] TheorySetsPrivate::addSharedTerm( " << n << ")"
                 << std::endl;
-  d_equalityEngine.addTriggerTerm(n, THEORY_SETS);
+  d_equalityEngine->addTriggerTerm(n, THEORY_SETS);
 }
 
 void TheorySetsPrivate::addCarePairs(TNodeTrie* t1,
@@ -1028,21 +1024,21 @@ void TheorySetsPrivate::addCarePairs(TNodeTrie* t1,
         {
           TNode x = f1[k];
           TNode y = f2[k];
-          Assert(d_equalityEngine.hasTerm(x));
-          Assert(d_equalityEngine.hasTerm(y));
+          Assert(d_equalityEngine->hasTerm(x));
+          Assert(d_equalityEngine->hasTerm(y));
           Assert(!d_state.areDisequal(x, y));
           Assert(!areCareDisequal(x, y));
-          if (!d_equalityEngine.areEqual(x, y))
+          if (!d_equalityEngine->areEqual(x, y))
           {
             Trace("sets-cg")
                 << "Arg #" << k << " is " << x << " " << y << std::endl;
-            if (d_equalityEngine.isTriggerTerm(x, THEORY_SETS)
-                && d_equalityEngine.isTriggerTerm(y, THEORY_SETS))
+            if (d_equalityEngine->isTriggerTerm(x, THEORY_SETS)
+                && d_equalityEngine->isTriggerTerm(y, THEORY_SETS))
             {
-              TNode x_shared =
-                  d_equalityEngine.getTriggerTermRepresentative(x, THEORY_SETS);
-              TNode y_shared =
-                  d_equalityEngine.getTriggerTermRepresentative(y, THEORY_SETS);
+              TNode x_shared = d_equalityEngine->getTriggerTermRepresentative(
+                  x, THEORY_SETS);
+              TNode y_shared = d_equalityEngine->getTriggerTermRepresentative(
+                  y, THEORY_SETS);
               currentPairs.push_back(make_pair(x_shared, y_shared));
             }
             else if (isCareArg(f1, k) && isCareArg(f2, k))
@@ -1092,7 +1088,7 @@ void TheorySetsPrivate::addCarePairs(TNodeTrie* t1,
         ++it2;
         for (; it2 != t1->d_data.end(); ++it2)
         {
-          if (!d_equalityEngine.areDisequal(it->first, it2->first, false))
+          if (!d_equalityEngine->areDisequal(it->first, it2->first, false))
           {
             if (!areCareDisequal(it->first, it2->first))
             {
@@ -1110,7 +1106,7 @@ void TheorySetsPrivate::addCarePairs(TNodeTrie* t1,
       {
         for (std::pair<const TNode, TNodeTrie>& tt2 : t2->d_data)
         {
-          if (!d_equalityEngine.areDisequal(tt1.first, tt2.first, false))
+          if (!d_equalityEngine->areDisequal(tt1.first, tt2.first, false))
           {
             if (!areCareDisequal(tt1.first, tt2.first))
             {
@@ -1140,9 +1136,9 @@ void TheorySetsPrivate::computeCareGraph()
       // populate indices
       for (TNode f1 : it.second)
       {
-        Assert(d_equalityEngine.hasTerm(f1));
+        Assert(d_equalityEngine->hasTerm(f1));
         Trace("sets-cg-debug") << "...build for " << f1 << std::endl;
-        Assert(d_equalityEngine.hasTerm(f1));
+        Assert(d_equalityEngine->hasTerm(f1));
         // break into index based on operator, and type of first argument (since
         // some operators are parametric)
         TypeNode tn = f1[0].getType();
@@ -1150,7 +1146,7 @@ void TheorySetsPrivate::computeCareGraph()
         bool hasCareArg = false;
         for (unsigned j = 0; j < f1.getNumChildren(); j++)
         {
-          reps.push_back(d_equalityEngine.getRepresentative(f1[j]));
+          reps.push_back(d_equalityEngine->getRepresentative(f1[j]));
           if (isCareArg(f1, j))
           {
             hasCareArg = true;
@@ -1184,7 +1180,7 @@ void TheorySetsPrivate::computeCareGraph()
 
 bool TheorySetsPrivate::isCareArg(Node n, unsigned a)
 {
-  if (d_equalityEngine.isTriggerTerm(n[a], THEORY_SETS))
+  if (d_equalityEngine->isTriggerTerm(n[a], THEORY_SETS))
   {
     return true;
   }
@@ -1201,13 +1197,13 @@ bool TheorySetsPrivate::isCareArg(Node n, unsigned a)
 
 EqualityStatus TheorySetsPrivate::getEqualityStatus(TNode a, TNode b)
 {
-  Assert(d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b));
-  if (d_equalityEngine.areEqual(a, b))
+  Assert(d_equalityEngine->hasTerm(a) && d_equalityEngine->hasTerm(b));
+  if (d_equalityEngine->areEqual(a, b))
   {
     // The terms are implied to be equal
     return EQUALITY_TRUE;
   }
-  if (d_equalityEngine.areDisequal(a, b, false))
+  if (d_equalityEngine->areDisequal(a, b, false))
   {
     // The terms are implied to be dis-equal
     return EQUALITY_FALSE;
@@ -1272,7 +1268,7 @@ bool TheorySetsPrivate::collectModelInfo(TheoryModel* m)
   d_external.computeRelevantTerms(termSet);
 
   // Assert equalities and disequalities to the model
-  if (!m->assertEqualityEngine(&d_equalityEngine, &termSet))
+  if (!m->assertEqualityEngine(d_equalityEngine, &termSet))
   {
     return false;
   }
@@ -1429,11 +1425,6 @@ OutputChannel* TheorySetsPrivate::getOutputChannel()
 
 Valuation& TheorySetsPrivate::getValuation() { return d_external.d_valuation; }
 
-void TheorySetsPrivate::setMasterEqualityEngine(eq::EqualityEngine* eq)
-{
-  d_equalityEngine.setMasterEqualityEngine(eq);
-}
-
 void TheorySetsPrivate::conflict(TNode a, TNode b)
 {
   Node conf = explain(a.eqNode(b));
@@ -1453,11 +1444,11 @@ Node TheorySetsPrivate::explain(TNode literal)
 
   if (atom.getKind() == kind::EQUAL)
   {
-    d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
+    d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions);
   }
   else if (atom.getKind() == kind::MEMBER)
   {
-    d_equalityEngine.explainPredicate(atom, polarity, assumptions);
+    d_equalityEngine->explainPredicate(atom, polarity, assumptions);
   }
   else
   {
@@ -1475,10 +1466,10 @@ void TheorySetsPrivate::preRegisterTerm(TNode node)
                 << std::endl;
   switch (node.getKind())
   {
-    case kind::EQUAL: d_equalityEngine.addTriggerEquality(node); break;
-    case kind::MEMBER: d_equalityEngine.addTriggerPredicate(node); break;
-    case kind::CARD: d_equalityEngine.addTriggerTerm(node, THEORY_SETS); break;
-    default: d_equalityEngine.addTerm(node); break;
+    case kind::EQUAL: d_equalityEngine->addTriggerEquality(node); break;
+    case kind::MEMBER: d_equalityEngine->addTriggerPredicate(node); break;
+    case kind::CARD: d_equalityEngine->addTriggerTerm(node, THEORY_SETS); break;
+    default: d_equalityEngine->addTerm(node); break;
   }
 }
 
@@ -1538,92 +1529,6 @@ Node TheorySetsPrivate::getChooseFunction(const TypeNode& setType)
 
 void TheorySetsPrivate::presolve() { d_state.reset(); }
 
-/**************************** eq::NotifyClass *****************************/
-/**************************** eq::NotifyClass *****************************/
-/**************************** eq::NotifyClass *****************************/
-
-bool TheorySetsPrivate::NotifyClass::eqNotifyTriggerEquality(TNode equality,
-                                                             bool value)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyTriggerEquality: equality = "
-                   << equality << " value = " << value << std::endl;
-  if (value)
-  {
-    return d_theory.propagate(equality);
-  }
-  else
-  {
-    // We use only literal triggers so taking not is safe
-    return d_theory.propagate(equality.notNode());
-  }
-}
-
-bool TheorySetsPrivate::NotifyClass::eqNotifyTriggerPredicate(TNode predicate,
-                                                              bool value)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyTriggerPredicate: predicate = "
-                   << predicate << " value = " << value << std::endl;
-  if (value)
-  {
-    return d_theory.propagate(predicate);
-  }
-  else
-  {
-    return d_theory.propagate(predicate.notNode());
-  }
-}
-
-bool TheorySetsPrivate::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag,
-                                                                 TNode t1,
-                                                                 TNode t2,
-                                                                 bool value)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyTriggerTermEquality: tag = " << tag
-                   << " t1 = " << t1 << "  t2 = " << t2 << "  value = " << value
-                   << std::endl;
-  d_theory.propagate(value ? t1.eqNode(t2) : t1.eqNode(t2).negate());
-  return true;
-}
-
-void TheorySetsPrivate::NotifyClass::eqNotifyConstantTermMerge(TNode t1,
-                                                               TNode t2)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyConstantTermMerge "
-                   << " t1 = " << t1 << " t2 = " << t2 << std::endl;
-  d_theory.conflict(t1, t2);
-}
-
-void TheorySetsPrivate::NotifyClass::eqNotifyNewClass(TNode t)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyNewClass:"
-                   << " t = " << t << std::endl;
-  d_theory.eqNotifyNewClass(t);
-}
-
-void TheorySetsPrivate::NotifyClass::eqNotifyPreMerge(TNode t1, TNode t2)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyPreMerge:"
-                   << " t1 = " << t1 << " t2 = " << t2 << std::endl;
-  d_theory.eqNotifyPreMerge(t1, t2);
-}
-
-void TheorySetsPrivate::NotifyClass::eqNotifyPostMerge(TNode t1, TNode t2)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyPostMerge:"
-                   << " t1 = " << t1 << " t2 = " << t2 << std::endl;
-  d_theory.eqNotifyPostMerge(t1, t2);
-}
-
-void TheorySetsPrivate::NotifyClass::eqNotifyDisequal(TNode t1,
-                                                      TNode t2,
-                                                      TNode reason)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyDisequal:"
-                   << " t1 = " << t1 << " t2 = " << t2 << " reason = " << reason
-                   << std::endl;
-  d_theory.eqNotifyDisequal(t1, t2, reason);
-}
-
 }  // namespace sets
 }  // namespace theory
 }  // namespace CVC4
index c65c867959e2d89fe0ef41a37c4d6ebfcb9bb76c..2779a42b7c905545f21285fcec74b5f998e0e181 100644 (file)
@@ -163,7 +163,11 @@ class TheorySetsPrivate {
 
   TheoryRewriter* getTheoryRewriter() { return &d_rewriter; }
 
-  void setMasterEqualityEngine(eq::EqualityEngine* eq);
+  /**
+   * Finish initialize, called after the equality engine of theory sets has
+   * been determined.
+   */
+  void finishInit();
 
   void addSharedTerm(TNode);
 
@@ -219,37 +223,18 @@ class TheorySetsPrivate {
   /** get the valuation */
   Valuation& getValuation();
 
- private:
-  TheorySets& d_external;
-
-  /** Functions to handle callbacks from equality engine */
-  class NotifyClass : public eq::EqualityEngineNotify {
-    TheorySetsPrivate& d_theory;
-
-  public:
-    NotifyClass(TheorySetsPrivate& theory): d_theory(theory) {}
-    bool eqNotifyTriggerEquality(TNode equality, bool value) override;
-    bool eqNotifyTriggerPredicate(TNode predicate, bool value) override;
-    bool eqNotifyTriggerTermEquality(TheoryId tag,
-                                     TNode t1,
-                                     TNode t2,
-                                     bool value) override;
-    void eqNotifyConstantTermMerge(TNode t1, TNode t2) override;
-    void eqNotifyNewClass(TNode t) override;
-    void eqNotifyPreMerge(TNode t1, TNode t2) override;
-    void eqNotifyPostMerge(TNode t1, TNode t2) override;
-    void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override;
-  } d_notify;
-
-  /** Equality engine */
-  eq::EqualityEngine d_equalityEngine;
-
   /** Proagate out to output channel */
   bool propagate(TNode);
 
   /** generate and send out conflict node */
   void conflict(TNode, TNode);
-  
+
+ private:
+  TheorySets& d_external;
+
+  /** Pointer to the equality engine of theory of sets */
+  eq::EqualityEngine* d_equalityEngine;
+
   bool isCareArg( Node n, unsigned a );
 
  public:
index f4b1d3be95afb5a549231e475f45f6c4f371f2be..82f32337e89fdcb256679aa927d8c413a3309579 100644 (file)
@@ -34,19 +34,11 @@ typedef std::map< Node, std::map< Node, std::unordered_set< Node, NodeHashFuncti
 
 TheorySetsRels::TheorySetsRels(SolverState& s,
                                InferenceManager& im,
-                               eq::EqualityEngine& e,
                                context::UserContext* u)
-    : d_state(s), d_im(im), d_ee(e), d_shared_terms(u)
+    : d_state(s), d_im(im), d_shared_terms(u)
 {
   d_trueNode = NodeManager::currentNM()->mkConst(true);
   d_falseNode = NodeManager::currentNM()->mkConst(false);
-  d_ee.addFunctionKind(PRODUCT);
-  d_ee.addFunctionKind(JOIN);
-  d_ee.addFunctionKind(TRANSPOSE);
-  d_ee.addFunctionKind(TCLOSURE);
-  d_ee.addFunctionKind(JOIN_IMAGE);
-  d_ee.addFunctionKind(IDEN);
-  d_ee.addFunctionKind(APPLY_CONSTRUCTOR);
 }
 
 TheorySetsRels::~TheorySetsRels() {}
@@ -185,10 +177,11 @@ void TheorySetsRels::check(Theory::Effort level)
 
   void TheorySetsRels::collectRelsInfo() {
     Trace("rels") << "[sets-rels] Start collecting relational terms..." << std::endl;
-    eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(&d_ee);
+    eq::EqualityEngine* ee = d_state.getEqualityEngine();
+    eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
     while( !eqcs_i.isFinished() ){
       Node                      eqc_rep  = (*eqcs_i);
-      eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc_rep, &d_ee);
+      eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc_rep, ee);
 
       TypeNode erType = eqc_rep.getType();
       Trace("rels-ee") << "[sets-rels-ee] Eqc term representative: " << eqc_rep << " with type " << eqc_rep.getType() << std::endl;
@@ -1139,24 +1132,17 @@ void TheorySetsRels::check(Theory::Effort level)
   }
 
   Node TheorySetsRels::getRepresentative( Node t ) {
-    if (d_ee.hasTerm(t))
-    {
-      return d_ee.getRepresentative(t);
-    }
-    else
-    {
-      return t;
-    }
+    return d_state.getRepresentative(t);
   }
 
-  bool TheorySetsRels::hasTerm(Node a) { return d_ee.hasTerm(a); }
+  bool TheorySetsRels::hasTerm(Node a) { return d_state.hasTerm(a); }
   bool TheorySetsRels::areEqual( Node a, Node b ){
     Assert(a.getType() == b.getType());
     Trace("rels-eq") << "[sets-rels]**** checking equality between " << a << " and " << b << std::endl;
     if(a == b) {
       return true;
     } else if( hasTerm( a ) && hasTerm( b ) ){
-      return d_ee.areEqual(a, b);
+      return d_state.areEqual(a, b);
     } else if(a.getType().isTuple()) {
       bool equal = true;
       for(unsigned int i = 0; i < a.getType().getTupleLength(); i++) {
index 0d8ace50f263fba4d743915f8314d024b928c3f2..60715ff57b036fa04e8567c4e48d98985d31565c 100644 (file)
@@ -67,7 +67,6 @@ class TheorySetsRels {
 public:
  TheorySetsRels(SolverState& s,
                 InferenceManager& im,
-                eq::EqualityEngine& e,
                 context::UserContext* u);
 
  ~TheorySetsRels();
@@ -90,8 +89,6 @@ private:
   SolverState& d_state;
   /** Reference to the inference manager for the theory of sets */
   InferenceManager& d_im;
-  /** Reference to the equality engine of theory of sets */
-  eq::EqualityEngine& d_ee;
   /** A list of pending inferences to process */
   std::vector<Node> d_pending;
   NodeSet                       d_shared_terms;