Standardize equality engine notifications in sets (#5098)
[cvc5.git] / src / theory / sets / theory_sets_private.cpp
index 78f6fa8b5603170fce05068c5aa1db3fcc1c8363..8779ac48931f74e24697918f42852542658b0933 100644 (file)
@@ -2,9 +2,9 @@
 /*! \file theory_sets_private.cpp
  ** \verbatim
  ** Top contributors (to current version):
- **   Andrew Reynolds, Kshitij Bansal, Paul Meng
+ **   Andrew Reynolds, Mudathir Mohamed, Kshitij Bansal
  ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
  ** in the top-level source directory) and their institutional affiliations.
  ** All rights reserved.  See the file COPYING in the top-level source
  ** directory for licensing information.\endverbatim
@@ -35,35 +35,25 @@ namespace theory {
 namespace sets {
 
 TheorySetsPrivate::TheorySetsPrivate(TheorySets& external,
-                                     context::Context* c,
-                                     context::UserContext* u)
-    : d_members(c),
-      d_deq(c),
-      d_termProcessed(u),
-      d_keep(c),
+                                     SolverState& state,
+                                     InferenceManager& im,
+                                     SkolemCache& skc)
+    : d_deq(state.getSatContext()),
+      d_termProcessed(state.getUserContext()),
       d_full_check_incomplete(false),
       d_external(external),
-      d_notify(*this),
-      d_equalityEngine(d_notify, c, "theory::sets::ee", true),
-      d_state(*this, d_equalityEngine, c, u),
-      d_im(*this, d_state, d_equalityEngine, c, u),
-      d_rels(new TheorySetsRels(d_state, d_im, d_equalityEngine, u)),
-      d_cardSolver(
-          new CardinalityExtension(d_state, d_im, d_equalityEngine, c, u)),
+      d_state(state),
+      d_im(im),
+      d_skCache(skc),
+      d_treg(state, im, skc),
+      d_rels(new TheorySetsRels(state, im, skc, d_treg)),
+      d_cardSolver(new CardinalityExtension(state, im, d_treg)),
       d_rels_enabled(false),
       d_card_enabled(false)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
   d_zero = NodeManager::currentNM()->mkConst(Rational(0));
-
-  d_equalityEngine.addFunctionKind(kind::SINGLETON);
-  d_equalityEngine.addFunctionKind(kind::UNION);
-  d_equalityEngine.addFunctionKind(kind::INTERSECTION);
-  d_equalityEngine.addFunctionKind(kind::SETMINUS);
-
-  d_equalityEngine.addFunctionKind(kind::MEMBER);
-  d_equalityEngine.addFunctionKind(kind::SUBSET);
 }
 
 TheorySetsPrivate::~TheorySetsPrivate()
@@ -74,6 +64,12 @@ TheorySetsPrivate::~TheorySetsPrivate()
   }
 }
 
+void TheorySetsPrivate::finishInit()
+{
+  d_equalityEngine = d_external.getEqualityEngine();
+  Assert(d_equalityEngine != nullptr);
+}
+
 void TheorySetsPrivate::eqNotifyNewClass(TNode t)
 {
   if (t.getKind() == kind::SINGLETON || t.getKind() == kind::EMPTYSET)
@@ -83,9 +79,7 @@ void TheorySetsPrivate::eqNotifyNewClass(TNode t)
   }
 }
 
-void TheorySetsPrivate::eqNotifyPreMerge(TNode t1, TNode t2) {}
-
-void TheorySetsPrivate::eqNotifyPostMerge(TNode t1, TNode t2)
+void TheorySetsPrivate::eqNotifyMerge(TNode t1, TNode t2)
 {
   if (!d_state.isInConflict())
   {
@@ -110,16 +104,14 @@ void TheorySetsPrivate::eqNotifyPostMerge(TNode t1, TNode t2)
             // infer equality between elements of singleton
             Node exp = s1.eqNode(s2);
             Node eq = s1[0].eqNode(s2[0]);
-            d_keep.insert(exp);
-            d_keep.insert(eq);
-            assertFact(eq, exp);
+            d_im.assertInternalFact(eq, true, exp);
           }
           else
           {
             // singleton equal to emptyset, conflict
             Trace("sets-prop")
                 << "Propagate conflict : " << s1 << " == " << s2 << std::endl;
-            conflict(s1, s2);
+            d_im.conflictEqConstantMerge(s1, s2);
             return;
           }
         }
@@ -133,73 +125,27 @@ void TheorySetsPrivate::eqNotifyPostMerge(TNode t1, TNode t2)
     }
     // merge membership list
     Trace("sets-prop-debug") << "Copying membership list..." << std::endl;
-    NodeIntMap::iterator mem_i2 = d_members.find(t2);
-    if (mem_i2 != d_members.end())
+    // if s1 has a singleton or empty set and s2 does not, we may have new
+    // inferences to process.
+    Node checkSingleton = s2.isNull() ? s1 : Node::null();
+    std::vector<Node> facts;
+    // merge the membership list in the state, which may produce facts or
+    // conflicts to propagate
+    if (!d_state.merge(t1, t2, facts, checkSingleton))
     {
-      NodeIntMap::iterator mem_i1 = d_members.find(t1);
-      int n_members = 0;
-      if (mem_i1 != d_members.end())
-      {
-        n_members = (*mem_i1).second;
-      }
-      for (int i = 0; i < (*mem_i2).second; i++)
-      {
-        Assert(i < (int)d_members_data[t2].size()
-               && d_members_data[t2][i].getKind() == kind::MEMBER);
-        Node m2 = d_members_data[t2][i];
-        // check if redundant
-        bool add = true;
-        for (int j = 0; j < n_members; j++)
-        {
-          Assert(j < (int)d_members_data[t1].size()
-                 && d_members_data[t1][j].getKind() == kind::MEMBER);
-          if (d_state.areEqual(m2[0], d_members_data[t1][j][0]))
-          {
-            add = false;
-            break;
-          }
-        }
-        if (add)
-        {
-          if (!s1.isNull() && s2.isNull())
-          {
-            Assert(m2[1].getType().isComparableTo(s1.getType()));
-            Assert(d_state.areEqual(m2[1], s1));
-            Node exp = NodeManager::currentNM()->mkNode(
-                kind::AND, m2[1].eqNode(s1), m2);
-            if (s1.getKind() == kind::SINGLETON)
-            {
-              if (s1[0] != m2[0])
-              {
-                Node eq = s1[0].eqNode(m2[0]);
-                d_keep.insert(exp);
-                d_keep.insert(eq);
-                Trace("sets-prop") << "Propagate eq-mem eq inference : " << exp
-                                   << " => " << eq << std::endl;
-                assertFact(eq, exp);
-              }
-            }
-            else
-            {
-              // conflict
-              Trace("sets-prop")
-                  << "Propagate eq-mem conflict : " << exp << std::endl;
-              d_state.setConflict(exp);
-              return;
-            }
-          }
-          if (n_members < (int)d_members_data[t1].size())
-          {
-            d_members_data[t1][n_members] = m2;
-          }
-          else
-          {
-            d_members_data[t1].push_back(m2);
-          }
-          n_members++;
-        }
-      }
-      d_members[t1] = n_members;
+      // conflict case
+      Assert(facts.size() == 1);
+      Trace("sets-prop") << "Propagate eq-mem conflict : " << facts[0]
+                         << std::endl;
+      d_im.conflict(facts[0]);
+      return;
+    }
+    for (const Node& f : facts)
+    {
+      Assert(f.getKind() == kind::IMPLIES);
+      Trace("sets-prop") << "Propagate eq-mem eq inference : " << f[0] << " => "
+                         << f[1] << std::endl;
+      d_im.assertInternalFact(f[1], true, f[0]);
     }
   }
 }
@@ -240,13 +186,13 @@ TheorySetsPrivate::EqcInfo* TheorySetsPrivate::getOrMakeEqcInfo(TNode n,
 
 bool TheorySetsPrivate::areCareDisequal(Node a, Node b)
 {
-  if (d_equalityEngine.isTriggerTerm(a, THEORY_SETS)
-      && d_equalityEngine.isTriggerTerm(b, THEORY_SETS))
+  if (d_equalityEngine->isTriggerTerm(a, THEORY_SETS)
+      && d_equalityEngine->isTriggerTerm(b, THEORY_SETS))
   {
     TNode a_shared =
-        d_equalityEngine.getTriggerTermRepresentative(a, THEORY_SETS);
+        d_equalityEngine->getTriggerTermRepresentative(a, THEORY_SETS);
     TNode b_shared =
-        d_equalityEngine.getTriggerTermRepresentative(b, THEORY_SETS);
+        d_equalityEngine->getTriggerTermRepresentative(b, THEORY_SETS);
     EqualityStatus eqStatus =
         d_external.d_valuation.getEqualityStatus(a_shared, b_shared);
     if (eqStatus == EQUALITY_FALSE_AND_PROPAGATED || eqStatus == EQUALITY_FALSE
@@ -258,103 +204,9 @@ bool TheorySetsPrivate::areCareDisequal(Node a, Node b)
   return false;
 }
 
-bool TheorySetsPrivate::isMember(Node x, Node s)
-{
-  Assert(d_equalityEngine.hasTerm(s)
-         && d_equalityEngine.getRepresentative(s) == s);
-  NodeIntMap::iterator mem_i = d_members.find(s);
-  if (mem_i != d_members.end())
-  {
-    for (int i = 0; i < (*mem_i).second; i++)
-    {
-      if (d_state.areEqual(d_members_data[s][i][0], x))
-      {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
-bool TheorySetsPrivate::assertFact(Node fact, Node exp)
-{
-  Trace("sets-assert") << "TheorySets::assertFact : " << fact
-                       << ", exp = " << exp << std::endl;
-  bool polarity = fact.getKind() != kind::NOT;
-  TNode atom = polarity ? fact : fact[0];
-  if (!d_state.isEntailed(atom, polarity))
-  {
-    if (atom.getKind() == kind::EQUAL)
-    {
-      d_equalityEngine.assertEquality(atom, polarity, exp);
-    }
-    else
-    {
-      d_equalityEngine.assertPredicate(atom, polarity, exp);
-    }
-    if (!d_state.isInConflict())
-    {
-      if (atom.getKind() == kind::MEMBER && polarity)
-      {
-        // check if set has a value, if so, we can propagate
-        Node r = d_equalityEngine.getRepresentative(atom[1]);
-        EqcInfo* e = getOrMakeEqcInfo(r, true);
-        if (e)
-        {
-          Node s = e->d_singleton;
-          if (!s.isNull())
-          {
-            Node pexp = NodeManager::currentNM()->mkNode(
-                kind::AND, atom, atom[1].eqNode(s));
-            d_keep.insert(pexp);
-            if (s.getKind() == kind::SINGLETON)
-            {
-              if (s[0] != atom[0])
-              {
-                Trace("sets-prop")
-                    << "Propagate mem-eq : " << pexp << std::endl;
-                Node eq = s[0].eqNode(atom[0]);
-                d_keep.insert(eq);
-                assertFact(eq, pexp);
-              }
-            }
-            else
-            {
-              Trace("sets-prop")
-                  << "Propagate mem-eq conflict : " << pexp << std::endl;
-              d_state.setConflict(pexp);
-            }
-          }
-        }
-        // add to membership list
-        NodeIntMap::iterator mem_i = d_members.find(r);
-        int n_members = 0;
-        if (mem_i != d_members.end())
-        {
-          n_members = (*mem_i).second;
-        }
-        d_members[r] = n_members + 1;
-        if (n_members < (int)d_members_data[r].size())
-        {
-          d_members_data[r][n_members] = atom;
-        }
-        else
-        {
-          d_members_data[r].push_back(atom);
-        }
-      }
-    }
-    return true;
-  }
-  else
-  {
-    return false;
-  }
-}
-
 void TheorySetsPrivate::fullEffortReset()
 {
-  Assert(d_equalityEngine.consistent());
+  Assert(d_equalityEngine->consistent());
   d_full_check_incomplete = false;
   d_most_common_type.clear();
   d_most_common_type_term.clear();
@@ -364,6 +216,7 @@ void TheorySetsPrivate::fullEffortReset()
   d_state.reset();
   // reset the inference manager
   d_im.reset();
+  d_im.clearPendingLemmas();
   // reset the cardinality solver
   d_cardSolver->reset();
 }
@@ -373,14 +226,14 @@ void TheorySetsPrivate::fullEffortCheck()
   Trace("sets") << "----- Full effort check ------" << std::endl;
   do
   {
-    Assert(!d_im.hasPendingLemmas() || d_im.hasProcessed());
+    Assert(!d_im.hasPendingLemma() || d_im.hasSent());
 
     Trace("sets") << "...iterate full effort check..." << std::endl;
     fullEffortReset();
 
     Trace("sets-eqc") << "Equality Engine:" << std::endl;
     std::map<TypeNode, unsigned> eqcTypeCount;
-    eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(&d_equalityEngine);
+    eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(d_equalityEngine);
     while (!eqcs_i.isFinished())
     {
       Node eqc = (*eqcs_i);
@@ -398,13 +251,13 @@ void TheorySetsPrivate::fullEffortCheck()
         tnct = eqc;
       }
       Trace("sets-eqc") << "[" << eqc << "] : ";
-      eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, &d_equalityEngine);
+      eq::EqClassIterator eqc_i = eq::EqClassIterator(eqc, d_equalityEngine);
       while (!eqc_i.isFinished())
       {
         Node n = (*eqc_i);
         if (n != eqc)
         {
-          Trace("sets-eqc") << n << " ";
+          Trace("sets-eqc") << n << " (" << n.isConst() << ") ";
         }
         TypeNode tnn = n.getType();
         if (isSet)
@@ -421,15 +274,21 @@ void TheorySetsPrivate::fullEffortCheck()
         }
         // register it with the state
         d_state.registerTerm(eqc, tnn, n);
-        if (n.getKind() == kind::CARD)
+        Kind nk = n.getKind();
+        if (nk == kind::SINGLETON)
+        {
+          // ensure the proxy has been introduced
+          d_treg.getProxy(n);
+        }
+        else if (nk == kind::CARD)
         {
           d_card_enabled = true;
           // register it with the cardinality solver
           d_cardSolver->registerTerm(n);
           // if we do not handle the kind, set incomplete
-          Kind nk = n[0].getKind();
+          Kind nk0 = n[0].getKind();
           // some kinds of cardinality we cannot handle
-          if (d_rels->isRelationKind(nk))
+          if (d_rels->isRelationKind(nk0))
           {
             d_full_check_incomplete = true;
             Trace("sets-incomplete")
@@ -445,12 +304,9 @@ void TheorySetsPrivate::fullEffortCheck()
             // 4- Supporting cardinality for relations (hard)
           }
         }
-        else
+        else if (d_rels->isRelationKind(nk))
         {
-          if (d_rels->isRelationKind(n.getKind()))
-          {
-            d_rels_enabled = true;
-          }
+          d_rels_enabled = true;
         }
         ++eqc_i;
       }
@@ -478,7 +334,7 @@ void TheorySetsPrivate::fullEffortCheck()
 
     // We may have sent lemmas while registering the terms in the loop above,
     // e.g. the cardinality solver.
-    if (d_im.hasProcessed())
+    if (d_im.hasSent())
     {
       continue;
     }
@@ -507,36 +363,36 @@ void TheorySetsPrivate::fullEffortCheck()
     }
     // check subtypes
     checkSubtypes();
-    d_im.flushPendingLemmas(true);
-    if (d_im.hasProcessed())
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
     {
       continue;
     }
     // check downwards closure
     checkDownwardsClosure();
-    d_im.flushPendingLemmas();
-    if (d_im.hasProcessed())
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
     {
       continue;
     }
     // check upwards closure
     checkUpwardsClosure();
-    d_im.flushPendingLemmas();
-    if (d_im.hasProcessed())
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
     {
       continue;
     }
     // check disequalities
     checkDisequalities();
-    d_im.flushPendingLemmas();
-    if (d_im.hasProcessed())
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
     {
       continue;
     }
     // check reduce comprehensions
     checkReduceComprehensions();
-    d_im.flushPendingLemmas();
-    if (d_im.hasProcessed())
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
     {
       continue;
     }
@@ -544,7 +400,7 @@ void TheorySetsPrivate::fullEffortCheck()
     {
       // call the check method of the cardinality solver
       d_cardSolver->check();
-      if (d_im.hasProcessed())
+      if (d_im.hasSent())
       {
         continue;
       }
@@ -555,8 +411,8 @@ void TheorySetsPrivate::fullEffortCheck()
       d_rels->check(Theory::EFFORT_FULL);
     }
   } while (!d_im.hasSentLemma() && !d_state.isInConflict()
-           && d_im.hasAddedFact());
-  Assert(!d_im.hasPendingLemmas() || d_im.hasProcessed());
+           && d_im.hasSentFact());
+  Assert(!d_im.hasPendingLemma() || d_im.hasSent());
   Trace("sets") << "----- End full effort check, conflict="
                 << d_state.isInConflict() << ", lemma=" << d_im.hasSentLemma()
                 << std::endl;
@@ -588,7 +444,7 @@ void TheorySetsPrivate::checkSubtypes()
           exp.push_back(it2.second);
           Assert(d_state.areEqual(mctt, it2.second[1]));
           exp.push_back(mctt.eqNode(it2.second[1]));
-          Node tc_k = d_state.getTypeConstraintSkolem(it2.first, mct);
+          Node tc_k = d_treg.getTypeConstraintSkolem(it2.first, mct);
           if (!tc_k.isNull())
           {
             Node etc = tc_k.eqNode(it2.first);
@@ -624,7 +480,7 @@ void TheorySetsPrivate::checkDownwardsClosure()
           {
             Node mem = it2.second;
             Node eq_set = nv;
-            Assert(d_equalityEngine.areEqual(mem[1], eq_set));
+            Assert(d_equalityEngine->areEqual(mem[1], eq_set));
             if (mem[1] != eq_set)
             {
               Trace("sets-debug") << "Downwards closure based on " << mem
@@ -646,7 +502,7 @@ void TheorySetsPrivate::checkDownwardsClosure()
               else
               {
                 // use proxy set
-                Node k = d_state.getProxy(eq_set);
+                Node k = d_treg.getProxy(eq_set);
                 Node pmem =
                     NodeManager::currentNM()->mkNode(kind::MEMBER, mem[0], k);
                 Node nmem = NodeManager::currentNM()->mkNode(
@@ -761,10 +617,10 @@ void TheorySetsPrivate::checkUpwardsClosure()
                 }
                 if (valid)
                 {
-                  Node rr = d_equalityEngine.getRepresentative(term);
-                  if (!isMember(x, rr))
+                  Node rr = d_equalityEngine->getRepresentative(term);
+                  if (!d_state.isMember(x, rr))
                   {
-                    Node kk = d_state.getProxy(term);
+                    Node kk = d_treg.getProxy(term);
                     Node fact = nm->mkNode(kind::MEMBER, x, kk);
                     d_im.assertInference(fact, exp, "upc", inferType);
                     if (d_state.isInConflict())
@@ -785,13 +641,13 @@ void TheorySetsPrivate::checkUpwardsClosure()
                 for (const std::pair<const Node, Node>& itm2m : r2mem)
                 {
                   Node x = itm2m.second[0];
-                  Node rr = d_equalityEngine.getRepresentative(term);
-                  if (!isMember(x, rr))
+                  Node rr = d_equalityEngine->getRepresentative(term);
+                  if (!d_state.isMember(x, rr))
                   {
                     std::vector<Node> exp;
                     exp.push_back(itm2m.second);
                     d_state.addEqualityToExp(term[1], itm2m.second[1], exp);
-                    Node r = d_state.getProxy(term);
+                    Node r = d_treg.getProxy(term);
                     Node fact = nm->mkNode(kind::MEMBER, x, r);
                     d_im.assertInference(fact, exp, "upc2");
                     if (d_state.isInConflict())
@@ -807,7 +663,7 @@ void TheorySetsPrivate::checkUpwardsClosure()
       }
     }
   }
-  if (!d_im.hasProcessed())
+  if (!d_im.hasSent())
   {
     if (options::setsExt())
     {
@@ -837,7 +693,7 @@ void TheorySetsPrivate::checkUpwardsClosure()
               // equivalence class
               if (s != ueqc)
               {
-                u = d_state.getUnivSet(tn);
+                u = d_treg.getUnivSet(tn);
               }
               univ_set[tn] = u;
             }
@@ -882,10 +738,10 @@ void TheorySetsPrivate::checkDisequalities()
     }
     Node deq = (*it).first;
     // check if it is already satisfied
-    Assert(d_equalityEngine.hasTerm(deq[0])
-           && d_equalityEngine.hasTerm(deq[1]));
-    Node r1 = d_equalityEngine.getRepresentative(deq[0]);
-    Node r2 = d_equalityEngine.getRepresentative(deq[1]);
+    Assert(d_equalityEngine->hasTerm(deq[0])
+           && d_equalityEngine->hasTerm(deq[1]));
+    Node r1 = d_equalityEngine->getRepresentative(deq[0]);
+    Node r2 = d_equalityEngine->getRepresentative(deq[1]);
     bool is_sat = d_state.isSetDisequalityEntailed(r1, r2);
     Trace("sets-debug") << "Check disequality " << deq
                         << ", is_sat = " << is_sat << std::endl;
@@ -906,15 +762,15 @@ void TheorySetsPrivate::checkDisequalities()
     d_termProcessed.insert(deq[1].eqNode(deq[0]));
     Trace("sets") << "Process Disequality : " << deq.negate() << std::endl;
     TypeNode elementType = deq[0].getType().getSetElementType();
-    Node x = d_state.getSkolemCache().mkTypedSkolemCached(
+    Node x = d_skCache.mkTypedSkolemCached(
         elementType, deq[0], deq[1], SkolemCache::SK_DISEQUAL, "sde");
     Node mem1 = nm->mkNode(MEMBER, x, deq[0]);
     Node mem2 = nm->mkNode(MEMBER, x, deq[1]);
     Node lem = nm->mkNode(OR, deq, nm->mkNode(EQUAL, mem1, mem2).negate());
     lem = Rewriter::rewrite(lem);
-    d_im.assertInference(lem, d_emp_exp, "diseq", 1);
-    d_im.flushPendingLemmas();
-    if (d_im.hasProcessed())
+    d_im.assertInference(lem, d_true, "diseq", 1);
+    d_im.doPendingLemmas();
+    if (d_im.hasSent())
     {
       return;
     }
@@ -952,30 +808,14 @@ void TheorySetsPrivate::checkReduceComprehensions()
         nm->mkNode(FORALL, nm->mkNode(BOUND_VAR_LIST, v), body.eqNode(mem));
     Trace("sets-comprehension")
         << "Comprehension reduction: " << lem << std::endl;
-    d_im.flushLemma(lem);
+    d_im.lemma(lem);
   }
 }
 
-/**************************** TheorySetsPrivate *****************************/
-/**************************** TheorySetsPrivate *****************************/
-/**************************** TheorySetsPrivate *****************************/
+//--------------------------------- standard check
 
-void TheorySetsPrivate::check(Theory::Effort level)
+void TheorySetsPrivate::postCheck(Theory::Effort level)
 {
-  Trace("sets-check") << "Sets check effort " << level << std::endl;
-  if (level == Theory::EFFORT_LAST_CALL)
-  {
-    return;
-  }
-  while (!d_external.done() && !d_state.isInConflict())
-  {
-    // Get all the assertions
-    Assertion assertion = d_external.get();
-    TNode fact = assertion.d_assertion;
-    Trace("sets-assert") << "Assert from input " << fact << std::endl;
-    // assert the fact
-    assertFact(fact, fact);
-  }
   Trace("sets-check") << "Sets finished assertions effort " << level
                       << std::endl;
   // invoke full effort check, relations check
@@ -989,25 +829,60 @@ void TheorySetsPrivate::check(Theory::Effort level)
         if (!d_state.isInConflict() && !d_im.hasSentLemma()
             && d_full_check_incomplete)
         {
-          d_external.d_out->setIncomplete();
+          d_im.setIncomplete();
         }
       }
     }
   }
   Trace("sets-check") << "Sets finish Check effort " << level << std::endl;
-} /* TheorySetsPrivate::check() */
+}
+
+void TheorySetsPrivate::notifyFact(TNode atom, bool polarity, TNode fact)
+{
+  if (d_state.isInConflict())
+  {
+    return;
+  }
+  if (atom.getKind() == kind::MEMBER && polarity)
+  {
+    // check if set has a value, if so, we can propagate
+    Node r = d_equalityEngine->getRepresentative(atom[1]);
+    EqcInfo* e = getOrMakeEqcInfo(r, true);
+    if (e)
+    {
+      Node s = e->d_singleton;
+      if (!s.isNull())
+      {
+        Node pexp = NodeManager::currentNM()->mkNode(
+            kind::AND, atom, atom[1].eqNode(s));
+        if (s.getKind() == kind::SINGLETON)
+        {
+          if (s[0] != atom[0])
+          {
+            Trace("sets-prop") << "Propagate mem-eq : " << pexp << std::endl;
+            Node eq = s[0].eqNode(atom[0]);
+            // triggers an internal inference
+            d_im.assertInternalFact(eq, true, pexp);
+          }
+        }
+        else
+        {
+          Trace("sets-prop")
+              << "Propagate mem-eq conflict : " << pexp << std::endl;
+          d_im.conflict(pexp);
+        }
+      }
+    }
+    // add to membership list
+    d_state.addMember(r, atom);
+  }
+}
+//--------------------------------- end standard check
 
 /************************ Sharing ************************/
 /************************ Sharing ************************/
 /************************ Sharing ************************/
 
-void TheorySetsPrivate::addSharedTerm(TNode n)
-{
-  Debug("sets") << "[sets] TheorySetsPrivate::addSharedTerm( " << n << ")"
-                << std::endl;
-  d_equalityEngine.addTriggerTerm(n, THEORY_SETS);
-}
-
 void TheorySetsPrivate::addCarePairs(TNodeTrie* t1,
                                      TNodeTrie* t2,
                                      unsigned arity,
@@ -1028,21 +903,21 @@ void TheorySetsPrivate::addCarePairs(TNodeTrie* t1,
         {
           TNode x = f1[k];
           TNode y = f2[k];
-          Assert(d_equalityEngine.hasTerm(x));
-          Assert(d_equalityEngine.hasTerm(y));
+          Assert(d_equalityEngine->hasTerm(x));
+          Assert(d_equalityEngine->hasTerm(y));
           Assert(!d_state.areDisequal(x, y));
           Assert(!areCareDisequal(x, y));
-          if (!d_equalityEngine.areEqual(x, y))
+          if (!d_equalityEngine->areEqual(x, y))
           {
             Trace("sets-cg")
                 << "Arg #" << k << " is " << x << " " << y << std::endl;
-            if (d_equalityEngine.isTriggerTerm(x, THEORY_SETS)
-                && d_equalityEngine.isTriggerTerm(y, THEORY_SETS))
+            if (d_equalityEngine->isTriggerTerm(x, THEORY_SETS)
+                && d_equalityEngine->isTriggerTerm(y, THEORY_SETS))
             {
-              TNode x_shared =
-                  d_equalityEngine.getTriggerTermRepresentative(x, THEORY_SETS);
-              TNode y_shared =
-                  d_equalityEngine.getTriggerTermRepresentative(y, THEORY_SETS);
+              TNode x_shared = d_equalityEngine->getTriggerTermRepresentative(
+                  x, THEORY_SETS);
+              TNode y_shared = d_equalityEngine->getTriggerTermRepresentative(
+                  y, THEORY_SETS);
               currentPairs.push_back(make_pair(x_shared, y_shared));
             }
             else if (isCareArg(f1, k) && isCareArg(f2, k))
@@ -1092,7 +967,7 @@ void TheorySetsPrivate::addCarePairs(TNodeTrie* t1,
         ++it2;
         for (; it2 != t1->d_data.end(); ++it2)
         {
-          if (!d_equalityEngine.areDisequal(it->first, it2->first, false))
+          if (!d_equalityEngine->areDisequal(it->first, it2->first, false))
           {
             if (!areCareDisequal(it->first, it2->first))
             {
@@ -1110,7 +985,7 @@ void TheorySetsPrivate::addCarePairs(TNodeTrie* t1,
       {
         for (std::pair<const TNode, TNodeTrie>& tt2 : t2->d_data)
         {
-          if (!d_equalityEngine.areDisequal(tt1.first, tt2.first, false))
+          if (!d_equalityEngine->areDisequal(tt1.first, tt2.first, false))
           {
             if (!areCareDisequal(tt1.first, tt2.first))
             {
@@ -1140,9 +1015,9 @@ void TheorySetsPrivate::computeCareGraph()
       // populate indices
       for (TNode f1 : it.second)
       {
-        Assert(d_equalityEngine.hasTerm(f1));
+        Assert(d_equalityEngine->hasTerm(f1));
         Trace("sets-cg-debug") << "...build for " << f1 << std::endl;
-        Assert(d_equalityEngine.hasTerm(f1));
+        Assert(d_equalityEngine->hasTerm(f1));
         // break into index based on operator, and type of first argument (since
         // some operators are parametric)
         TypeNode tn = f1[0].getType();
@@ -1150,7 +1025,7 @@ void TheorySetsPrivate::computeCareGraph()
         bool hasCareArg = false;
         for (unsigned j = 0; j < f1.getNumChildren(); j++)
         {
-          reps.push_back(d_equalityEngine.getRepresentative(f1[j]));
+          reps.push_back(d_equalityEngine->getRepresentative(f1[j]));
           if (isCareArg(f1, j))
           {
             hasCareArg = true;
@@ -1184,7 +1059,7 @@ void TheorySetsPrivate::computeCareGraph()
 
 bool TheorySetsPrivate::isCareArg(Node n, unsigned a)
 {
-  if (d_equalityEngine.isTriggerTerm(n[a], THEORY_SETS))
+  if (d_equalityEngine->isTriggerTerm(n[a], THEORY_SETS))
   {
     return true;
   }
@@ -1199,37 +1074,6 @@ bool TheorySetsPrivate::isCareArg(Node n, unsigned a)
   }
 }
 
-EqualityStatus TheorySetsPrivate::getEqualityStatus(TNode a, TNode b)
-{
-  Assert(d_equalityEngine.hasTerm(a) && d_equalityEngine.hasTerm(b));
-  if (d_equalityEngine.areEqual(a, b))
-  {
-    // The terms are implied to be equal
-    return EQUALITY_TRUE;
-  }
-  if (d_equalityEngine.areDisequal(a, b, false))
-  {
-    // The terms are implied to be dis-equal
-    return EQUALITY_FALSE;
-  }
-  return EQUALITY_UNKNOWN;
-  /*
-  Node aModelValue = d_external.d_valuation.getModelValue(a);
-  if(aModelValue.isNull()) { return EQUALITY_UNKNOWN; }
-  Node bModelValue = d_external.d_valuation.getModelValue(b);
-  if(bModelValue.isNull()) { return EQUALITY_UNKNOWN; }
-  if( aModelValue == bModelValue ) {
-    // The term are true in current model
-    return EQUALITY_TRUE_IN_MODEL;
-  } else {
-    return EQUALITY_FALSE_IN_MODEL;
-  }
-  */
-  // }
-  // //TODO: can we be more precise sometimes?
-  // return EQUALITY_UNKNOWN;
-}
-
 /******************** Model generation ********************/
 /******************** Model generation ********************/
 /******************** Model generation ********************/
@@ -1264,18 +1108,10 @@ std::string traceElements(const Node& set)
 
 }  // namespace
 
-bool TheorySetsPrivate::collectModelInfo(TheoryModel* m)
+bool TheorySetsPrivate::collectModelValues(TheoryModel* m,
+                                           const std::set<Node>& termSet)
 {
-  Trace("sets-model") << "Set collect model info" << std::endl;
-  set<Node> termSet;
-  // Compute terms appearing in assertions and shared terms
-  d_external.computeRelevantTerms(termSet);
-
-  // Assert equalities and disequalities to the model
-  if (!m->assertEqualityEngine(&d_equalityEngine, &termSet))
-  {
-    return false;
-  }
+  Trace("sets-model") << "Set collect model values" << std::endl;
 
   NodeManager* nm = NodeManager::currentNM();
   std::map<Node, Node> mvals;
@@ -1398,51 +1234,8 @@ Node mkAnd(const std::vector<TNode>& conjunctions)
   return conjunction;
 } /* mkAnd() */
 
-void TheorySetsPrivate::propagate(Theory::Effort effort) {}
-
-bool TheorySetsPrivate::propagate(TNode literal)
-{
-  Debug("sets-prop") << " propagate(" << literal << ")" << std::endl;
-
-  // If already in conflict, no more propagation
-  if (d_state.isInConflict())
-  {
-    Debug("sets-prop") << "TheoryUF::propagate(" << literal
-                       << "): already in conflict" << std::endl;
-    return false;
-  }
-
-  // Propagate out
-  bool ok = d_external.d_out->propagate(literal);
-  if (!ok)
-  {
-    d_state.setConflict();
-  }
-
-  return ok;
-} /* TheorySetsPrivate::propagate(TNode) */
-
-OutputChannel* TheorySetsPrivate::getOutputChannel()
-{
-  return d_external.d_out;
-}
-
 Valuation& TheorySetsPrivate::getValuation() { return d_external.d_valuation; }
 
-void TheorySetsPrivate::setMasterEqualityEngine(eq::EqualityEngine* eq)
-{
-  d_equalityEngine.setMasterEqualityEngine(eq);
-}
-
-void TheorySetsPrivate::conflict(TNode a, TNode b)
-{
-  Node conf = explain(a.eqNode(b));
-  d_state.setConflict(conf);
-  Debug("sets") << "[sets] conflict: " << a << " iff " << b << ", explanation "
-                << conf << std::endl;
-  Trace("sets-lemma") << "Equality Conflict : " << conf << std::endl;
-}
-
 Node TheorySetsPrivate::explain(TNode literal)
 {
   Debug("sets") << "TheorySetsPrivate::explain(" << literal << ")" << std::endl;
@@ -1453,11 +1246,11 @@ Node TheorySetsPrivate::explain(TNode literal)
 
   if (atom.getKind() == kind::EQUAL)
   {
-    d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
+    d_equalityEngine->explainEquality(atom[0], atom[1], polarity, assumptions);
   }
   else if (atom.getKind() == kind::MEMBER)
   {
-    d_equalityEngine.explainPredicate(atom, polarity, assumptions);
+    d_equalityEngine->explainPredicate(atom, polarity, assumptions);
   }
   else
   {
@@ -1475,154 +1268,126 @@ void TheorySetsPrivate::preRegisterTerm(TNode node)
                 << std::endl;
   switch (node.getKind())
   {
-    case kind::EQUAL: d_equalityEngine.addTriggerEquality(node); break;
-    case kind::MEMBER: d_equalityEngine.addTriggerPredicate(node); break;
-    case kind::CARD: d_equalityEngine.addTriggerTerm(node, THEORY_SETS); break;
-    default: d_equalityEngine.addTerm(node); break;
+    case kind::EQUAL:
+    case kind::MEMBER:
+    {
+      // add trigger predicate for equality and membership
+      d_equalityEngine->addTriggerPredicate(node);
+    }
+    break;
+    case kind::CARD: d_equalityEngine->addTriggerTerm(node, THEORY_SETS); break;
+    default: d_equalityEngine->addTerm(node); break;
   }
 }
 
-Node TheorySetsPrivate::expandDefinition(Node node)
+TrustNode TheorySetsPrivate::expandDefinition(Node node)
 {
   Debug("sets-proc") << "expandDefinition : " << node << std::endl;
 
-  if (node.getKind() == kind::CHOOSE)
+  switch (node.getKind())
   {
-    // (choose A) is expanded as
-    // (witness ((x elementType))
-    //    (ite
-    //      (= A (as emptyset setType))
-    //      (= x chooseUf(A))
-    //      (and (member x A) (= x chooseUf(A)))
-
-    NodeManager* nm = NodeManager::currentNM();
-    Node set = node[0];
-    TypeNode setType = set.getType();
-    Node chooseSkolem = getChooseFunction(setType);
-    Node apply = NodeManager::currentNM()->mkNode(APPLY_UF, chooseSkolem, set);
-
-    Node witnessVariable = nm->mkBoundVar(setType.getSetElementType());
-
-    Node equal = witnessVariable.eqNode(apply);
-    Node emptySet = nm->mkConst(EmptySet(setType.toType()));
-    Node isEmpty = set.eqNode(emptySet);
-    Node member = nm->mkNode(MEMBER, witnessVariable, set);
-    Node memberAndEqual = member.andNode(equal);
-    Node ite = nm->mkNode(kind::ITE, isEmpty, equal, memberAndEqual);
-    Node witnessVariables = nm->mkNode(BOUND_VAR_LIST, witnessVariable);
-    Node witness = nm->mkNode(WITNESS, witnessVariables, ite);
-    return witness;
+    case kind::CHOOSE: return expandChooseOperator(node);
+    case kind::IS_SINGLETON: return expandIsSingletonOperator(node);
+    default: return TrustNode::null();
   }
-
-  return node;
 }
 
-Node TheorySetsPrivate::getChooseFunction(const TypeNode& setType)
+TrustNode TheorySetsPrivate::expandChooseOperator(const Node& node)
 {
-  std::map<TypeNode, Node>::iterator it = d_chooseFunctions.find(setType);
-  if (it != d_chooseFunctions.end())
+  Assert(node.getKind() == CHOOSE);
+
+  // we call the rewriter here to handle the pattern (choose (singleton x))
+  // because the rewriter is called after expansion
+  Node rewritten = Rewriter::rewrite(node);
+  if (rewritten.getKind() != CHOOSE)
   {
-    return it->second;
+    return TrustNode::mkTrustRewrite(node, rewritten, nullptr);
   }
 
+  // (choose A) is expanded as
+  // (witness ((x elementType))
+  //    (ite
+  //      (= A (as emptyset setType))
+  //      (= x chooseUf(A))
+  //      (and (member x A) (= x chooseUf(A)))
+
   NodeManager* nm = NodeManager::currentNM();
-  TypeNode chooseUf = nm->mkFunctionType(setType, setType.getSetElementType());
-  stringstream stream;
-  stream << "chooseUf" << setType.getId();
-  string name = stream.str();
-  Node chooseSkolem = nm->mkSkolem(
-      name, chooseUf, "choose function", NodeManager::SKOLEM_EXACT_NAME);
-  d_chooseFunctions[setType] = chooseSkolem;
-  return chooseSkolem;
+  Node set = rewritten[0];
+  TypeNode setType = set.getType();
+  Node chooseSkolem = getChooseFunction(setType);
+  Node apply = NodeManager::currentNM()->mkNode(APPLY_UF, chooseSkolem, set);
+
+  Node witnessVariable = nm->mkBoundVar(setType.getSetElementType());
+
+  Node equal = witnessVariable.eqNode(apply);
+  Node emptySet = nm->mkConst(EmptySet(setType));
+  Node isEmpty = set.eqNode(emptySet);
+  Node member = nm->mkNode(MEMBER, witnessVariable, set);
+  Node memberAndEqual = member.andNode(equal);
+  Node ite = nm->mkNode(ITE, isEmpty, equal, memberAndEqual);
+  Node witnessVariables = nm->mkNode(BOUND_VAR_LIST, witnessVariable);
+  Node witness = nm->mkNode(WITNESS, witnessVariables, ite);
+  return TrustNode::mkTrustRewrite(node, witness, nullptr);
 }
 
-void TheorySetsPrivate::presolve() { d_state.reset(); }
-
-/**************************** eq::NotifyClass *****************************/
-/**************************** eq::NotifyClass *****************************/
-/**************************** eq::NotifyClass *****************************/
-
-bool TheorySetsPrivate::NotifyClass::eqNotifyTriggerEquality(TNode equality,
-                                                             bool value)
+TrustNode TheorySetsPrivate::expandIsSingletonOperator(const Node& node)
 {
-  Debug("sets-eq") << "[sets-eq] eqNotifyTriggerEquality: equality = "
-                   << equality << " value = " << value << std::endl;
-  if (value)
-  {
-    return d_theory.propagate(equality);
-  }
-  else
-  {
-    // We use only literal triggers so taking not is safe
-    return d_theory.propagate(equality.notNode());
-  }
-}
+  Assert(node.getKind() == IS_SINGLETON);
 
-bool TheorySetsPrivate::NotifyClass::eqNotifyTriggerPredicate(TNode predicate,
-                                                              bool value)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyTriggerPredicate: predicate = "
-                   << predicate << " value = " << value << std::endl;
-  if (value)
+  // we call the rewriter here to handle the pattern
+  // (is_singleton (singleton x)) because the rewriter is called after expansion
+  Node rewritten = Rewriter::rewrite(node);
+  if (rewritten.getKind() != IS_SINGLETON)
   {
-    return d_theory.propagate(predicate);
+    return TrustNode::mkTrustRewrite(node, rewritten, nullptr);
   }
-  else
+
+  // (is_singleton A) is expanded as
+  // (exists ((x: T)) (= A (singleton x)))
+  // where T is the sort of elements of A
+
+  NodeManager* nm = NodeManager::currentNM();
+  Node set = rewritten[0];
+
+  std::map<Node, Node>::iterator it = d_isSingletonNodes.find(rewritten);
+
+  if (it != d_isSingletonNodes.end())
   {
-    return d_theory.propagate(predicate.notNode());
+    return TrustNode::mkTrustRewrite(rewritten, it->second, nullptr);
   }
-}
-
-bool TheorySetsPrivate::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag,
-                                                                 TNode t1,
-                                                                 TNode t2,
-                                                                 bool value)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyTriggerTermEquality: tag = " << tag
-                   << " t1 = " << t1 << "  t2 = " << t2 << "  value = " << value
-                   << std::endl;
-  d_theory.propagate(value ? t1.eqNode(t2) : t1.eqNode(t2).negate());
-  return true;
-}
 
-void TheorySetsPrivate::NotifyClass::eqNotifyConstantTermMerge(TNode t1,
-                                                               TNode t2)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyConstantTermMerge "
-                   << " t1 = " << t1 << " t2 = " << t2 << std::endl;
-  d_theory.conflict(t1, t2);
-}
+  TypeNode setType = set.getType();
+  Node boundVar = nm->mkBoundVar(setType.getSetElementType());
+  Node singleton = nm->mkNode(kind::SINGLETON, boundVar);
+  Node equal = set.eqNode(singleton);
+  std::vector<Node> variables = {boundVar};
+  Node boundVars = nm->mkNode(BOUND_VAR_LIST, variables);
+  Node exists = nm->mkNode(kind::EXISTS, boundVars, equal);
+  d_isSingletonNodes[rewritten] = exists;
 
-void TheorySetsPrivate::NotifyClass::eqNotifyNewClass(TNode t)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyNewClass:"
-                   << " t = " << t << std::endl;
-  d_theory.eqNotifyNewClass(t);
+  return TrustNode::mkTrustRewrite(node, exists, nullptr);
 }
 
-void TheorySetsPrivate::NotifyClass::eqNotifyPreMerge(TNode t1, TNode t2)
+Node TheorySetsPrivate::getChooseFunction(const TypeNode& setType)
 {
-  Debug("sets-eq") << "[sets-eq] eqNotifyPreMerge:"
-                   << " t1 = " << t1 << " t2 = " << t2 << std::endl;
-  d_theory.eqNotifyPreMerge(t1, t2);
-}
+  std::map<TypeNode, Node>::iterator it = d_chooseFunctions.find(setType);
+  if (it != d_chooseFunctions.end())
+  {
+    return it->second;
+  }
 
-void TheorySetsPrivate::NotifyClass::eqNotifyPostMerge(TNode t1, TNode t2)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyPostMerge:"
-                   << " t1 = " << t1 << " t2 = " << t2 << std::endl;
-  d_theory.eqNotifyPostMerge(t1, t2);
+  NodeManager* nm = NodeManager::currentNM();
+  TypeNode chooseUf = nm->mkFunctionType(setType, setType.getSetElementType());
+  stringstream stream;
+  stream << "chooseUf" << setType.getId();
+  string name = stream.str();
+  Node chooseSkolem = nm->mkSkolem(
+      name, chooseUf, "choose function", NodeManager::SKOLEM_EXACT_NAME);
+  d_chooseFunctions[setType] = chooseSkolem;
+  return chooseSkolem;
 }
 
-void TheorySetsPrivate::NotifyClass::eqNotifyDisequal(TNode t1,
-                                                      TNode t2,
-                                                      TNode reason)
-{
-  Debug("sets-eq") << "[sets-eq] eqNotifyDisequal:"
-                   << " t1 = " << t1 << " t2 = " << t2 << " reason = " << reason
-                   << std::endl;
-  d_theory.eqNotifyDisequal(t1, t2, reason);
-}
+void TheorySetsPrivate::presolve() { d_state.reset(); }
 
 }  // namespace sets
 }  // namespace theory