Move sets member propagation to SolverState (#5045)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 15 Sep 2020 16:07:21 +0000 (11:07 -0500)
committerGitHub <noreply@github.com>
Tue, 15 Sep 2020 16:07:21 +0000 (11:07 -0500)
This eliminates the parent relationship from solver state to theory sets.

src/theory/sets/solver_state.cpp
src/theory/sets/solver_state.h
src/theory/sets/theory_sets.cpp
src/theory/sets/theory_sets_private.cpp
src/theory/sets/theory_sets_private.h

index 941f59bc6e1517683fffab26afc31204434e11e3..79c7bc1c8b22490e9bc6b4abd0528842c53f8b58 100644 (file)
@@ -29,14 +29,12 @@ SolverState::SolverState(context::Context* c,
                          context::UserContext* u,
                          Valuation val,
                          SkolemCache& skc)
-    : TheoryState(c, u, val), d_skCache(skc), d_parent(nullptr)
+    : TheoryState(c, u, val), d_skCache(skc), d_members(c)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
 }
 
-void SolverState::setParent(TheorySetsPrivate* p) { d_parent = p; }
-
 void SolverState::reset()
 {
   d_set_eqc.clear();
@@ -249,7 +247,7 @@ bool SolverState::isEntailed(Node n, bool polarity) const
     if (polarity && d_ee->hasTerm(n[1]))
     {
       Node r = d_ee->getRepresentative(n[1]);
-      if (d_parent->isMember(n[0], r))
+      if (isMember(n[0], r))
       {
         return true;
       }
@@ -469,6 +467,125 @@ const vector<Node> SolverState::getSetsEqClasses(const TypeNode& t) const
   return representatives;
 }
 
+bool SolverState::isMember(TNode x, TNode s) const
+{
+  Assert(hasTerm(s) && getRepresentative(s) == s);
+  NodeIntMap::const_iterator mem_i = d_members.find(s);
+  if (mem_i != d_members.end())
+  {
+    std::map<Node, std::vector<Node> >::const_iterator itd =
+        d_members_data.find(s);
+    Assert(itd != d_members_data.end());
+    const std::vector<Node>& members = itd->second;
+    Assert((*mem_i).second <= members.size());
+    for (size_t i = 0, nmem = (*mem_i).second; i < nmem; i++)
+    {
+      if (areEqual(members[i][0], x))
+      {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+void SolverState::addMember(TNode r, TNode atom)
+{
+  NodeIntMap::iterator mem_i = d_members.find(r);
+  size_t n_members = 0;
+  if (mem_i != d_members.end())
+  {
+    n_members = (*mem_i).second;
+  }
+  d_members[r] = n_members + 1;
+  if (n_members < d_members_data[r].size())
+  {
+    d_members_data[r][n_members] = atom;
+  }
+  else
+  {
+    d_members_data[r].push_back(atom);
+  }
+}
+
+bool SolverState::merge(TNode t1,
+                        TNode t2,
+                        std::vector<Node>& facts,
+                        TNode cset)
+{
+  NodeIntMap::iterator mem_i2 = d_members.find(t2);
+  if (mem_i2 == d_members.end())
+  {
+    // no members in t2, we are done
+    return true;
+  }
+  NodeIntMap::iterator mem_i1 = d_members.find(t1);
+  size_t n_members = 0;
+  if (mem_i1 != d_members.end())
+  {
+    n_members = (*mem_i1).second;
+  }
+  for (size_t i = 0, nmem2 = (*mem_i2).second; i < nmem2; i++)
+  {
+    Assert(i < d_members_data[t2].size()
+           && d_members_data[t2][i].getKind() == MEMBER);
+    Node m2 = d_members_data[t2][i];
+    // check if redundant
+    bool add = true;
+    for (size_t j = 0; j < n_members; j++)
+    {
+      Assert(j < d_members_data[t1].size()
+             && d_members_data[t1][j].getKind() == MEMBER);
+      if (areEqual(m2[0], d_members_data[t1][j][0]))
+      {
+        add = false;
+        break;
+      }
+    }
+    if (add)
+    {
+      // if there is a concrete set in t1, propagate new facts or conflicts
+      if (!cset.isNull())
+      {
+        NodeManager* nm = NodeManager::currentNM();
+        Assert(areEqual(m2[1], cset));
+        Node exp = nm->mkNode(AND, m2[1].eqNode(cset), m2);
+        if (cset.getKind() == SINGLETON)
+        {
+          if (cset[0] != m2[0])
+          {
+            Node eq = cset[0].eqNode(m2[0]);
+            Trace("sets-prop") << "Propagate eq-mem eq inference : " << exp
+                               << " => " << eq << std::endl;
+            Node fact = nm->mkNode(IMPLIES, exp, eq);
+            facts.push_back(fact);
+          }
+        }
+        else
+        {
+          // conflict
+          Assert(facts.empty());
+          Trace("sets-prop")
+              << "Propagate eq-mem conflict : " << exp << std::endl;
+          facts.push_back(exp);
+          return false;
+        }
+      }
+      if (n_members < d_members_data[t1].size())
+      {
+        d_members_data[t1][n_members] = m2;
+      }
+      else
+      {
+        d_members_data[t1].push_back(m2);
+      }
+      n_members++;
+    }
+  }
+  d_members[t1] = n_members;
+  return true;
+}
+
 }  // namespace sets
 }  // namespace theory
 }  // namespace CVC4
index 245ad93f605dd09fdfa2f6fd56df9e0308883fd9..29c13b2a2e767f27dfc37976d1eeac96e0ed2069 100644 (file)
@@ -42,13 +42,13 @@ class TheorySetsPrivate;
  */
 class SolverState : public TheoryState
 {
+  typedef context::CDHashMap<Node, size_t, NodeHashFunction> NodeIntMap;
+
  public:
   SolverState(context::Context* c,
               context::UserContext* u,
               Valuation val,
               SkolemCache& skc);
-  /** Set parent */
-  void setParent(TheorySetsPrivate* p);
   //-------------------------------- initialize per check
   /** reset, clears the data structures maintained by this class. */
   void reset();
@@ -156,6 +156,30 @@ class SolverState : public TheoryState
   /** Get the list of all comprehension sets in the current context */
   const std::vector<Node>& getComprehensionSets() const;
 
+  /**
+   * Is x entailed to be a member of set s in the current context?
+   */
+  bool isMember(TNode x, TNode s) const;
+  /**
+   * Add member, called when atom is of the form (member x s) where s is in the
+   * equivalence class of r.
+   */
+  void addMember(TNode r, TNode atom);
+  /**
+   * Called when equivalence classes t1 and t2 merge. This updates the
+   * membership lists, adding members of t2 into t1.
+   *
+   * If cset is non-null, then this is a singleton or empty set in the
+   * equivalence class of t1 where moreover t2 has no singleton or empty sets.
+   * When this is the case, notice that all members of t2 should be made equal
+   * to the element that cset contains, or we are in conflict if cset is the
+   * empty set. These conclusions are added to facts.
+   *
+   * This method returns false if a (single) conflict was added to facts, and
+   * true otherwise.
+   */
+  bool merge(TNode t1, TNode t2, std::vector<Node>& facts, TNode cset);
+
  private:
   /** constants */
   Node d_true;
@@ -165,8 +189,6 @@ class SolverState : public TheoryState
   std::map<Node, Node> d_emptyMap;
   /** Reference to skolem cache */
   SkolemCache& d_skCache;
-  /** Pointer to the parent theory of sets */
-  TheorySetsPrivate* d_parent;
   /** The list of all equivalence classes of type set in the current context */
   std::vector<Node> d_set_eqc;
   /** Maps types to the equivalence class containing empty set of that type */
@@ -209,7 +231,17 @@ class SolverState : public TheoryState
   /** A list of comprehension sets */
   std::vector<Node> d_allCompSets;
   // -------------------------------- end term indices
+  /** List of operators per kind */
   std::map<Kind, std::vector<Node> > d_op_list;
+  //--------------------------------- SAT-context-dependent member list
+  /**
+   * Map from representatives r of set equivalence classes to atoms of the form
+   * (member x s) where s is in the equivalence class of r.
+   */
+  std::map<Node, std::vector<Node> > d_members_data;
+  /** A (SAT-context-dependent) number of members in the above map */
+  NodeIntMap d_members;
+  //--------------------------------- end
   /** is set disequality entailed internal
    *
    * This returns true if disequality between sets a and b is entailed in the
index fe5e56aa6c14edb4c73b40e6e8f083330ce4a640..66f35f24fcabcd1bd0677c41602484c7144857d5 100644 (file)
@@ -43,9 +43,6 @@ TheorySets::TheorySets(context::Context* c,
   // use the official theory state and inference manager objects
   d_theoryState = &d_state;
   d_inferManager = &d_im;
-
-  // TODO: remove this
-  d_state.setParent(d_internal.get());
 }
 
 TheorySets::~TheorySets()
index b1831f26199e49844442dcbf43e10c23b48568b2..c928718982763f819d3e14be69732eaa84a5ca1c 100644 (file)
@@ -38,8 +38,7 @@ TheorySetsPrivate::TheorySetsPrivate(TheorySets& external,
                                      SolverState& state,
                                      InferenceManager& im,
                                      SkolemCache& skc)
-    : d_members(state.getSatContext()),
-      d_deq(state.getSatContext()),
+    : d_deq(state.getSatContext()),
       d_termProcessed(state.getUserContext()),
       d_full_check_incomplete(false),
       d_external(external),
@@ -126,71 +125,27 @@ void TheorySetsPrivate::eqNotifyMerge(TNode t1, TNode t2)
     }
     // merge membership list
     Trace("sets-prop-debug") << "Copying membership list..." << std::endl;
-    NodeIntMap::iterator mem_i2 = d_members.find(t2);
-    if (mem_i2 != d_members.end())
+    // if s1 has a singleton or empty set and s2 does not, we may have new
+    // inferences to process.
+    Node checkSingleton = s2.isNull() ? s1 : Node::null();
+    std::vector<Node> facts;
+    // merge the membership list in the state, which may produce facts or
+    // conflicts to propagate
+    if (!d_state.merge(t1, t2, facts, checkSingleton))
     {
-      NodeIntMap::iterator mem_i1 = d_members.find(t1);
-      int n_members = 0;
-      if (mem_i1 != d_members.end())
-      {
-        n_members = (*mem_i1).second;
-      }
-      for (int i = 0; i < (*mem_i2).second; i++)
-      {
-        Assert(i < (int)d_members_data[t2].size()
-               && d_members_data[t2][i].getKind() == kind::MEMBER);
-        Node m2 = d_members_data[t2][i];
-        // check if redundant
-        bool add = true;
-        for (int j = 0; j < n_members; j++)
-        {
-          Assert(j < (int)d_members_data[t1].size()
-                 && d_members_data[t1][j].getKind() == kind::MEMBER);
-          if (d_state.areEqual(m2[0], d_members_data[t1][j][0]))
-          {
-            add = false;
-            break;
-          }
-        }
-        if (add)
-        {
-          if (!s1.isNull() && s2.isNull())
-          {
-            Assert(m2[1].getType().isComparableTo(s1.getType()));
-            Assert(d_state.areEqual(m2[1], s1));
-            Node exp = NodeManager::currentNM()->mkNode(
-                kind::AND, m2[1].eqNode(s1), m2);
-            if (s1.getKind() == kind::SINGLETON)
-            {
-              if (s1[0] != m2[0])
-              {
-                Node eq = s1[0].eqNode(m2[0]);
-                Trace("sets-prop") << "Propagate eq-mem eq inference : " << exp
-                                   << " => " << eq << std::endl;
-                d_im.assertInternalFact(eq, true, exp);
-              }
-            }
-            else
-            {
-              // conflict
-              Trace("sets-prop")
-                  << "Propagate eq-mem conflict : " << exp << std::endl;
-              d_im.conflict(exp);
-              return;
-            }
-          }
-          if (n_members < (int)d_members_data[t1].size())
-          {
-            d_members_data[t1][n_members] = m2;
-          }
-          else
-          {
-            d_members_data[t1].push_back(m2);
-          }
-          n_members++;
-        }
-      }
-      d_members[t1] = n_members;
+      // conflict case
+      Assert(facts.size() == 1);
+      Trace("sets-prop") << "Propagate eq-mem conflict : " << facts[0]
+                         << std::endl;
+      d_im.conflict(facts[0]);
+      return;
+    }
+    for (const Node& f : facts)
+    {
+      Assert(f.getKind() == kind::IMPLIES);
+      Trace("sets-prop") << "Propagate eq-mem eq inference : " << f[0] << " => "
+                         << f[1] << std::endl;
+      d_im.assertInternalFact(f[1], true, f[0]);
     }
   }
 }
@@ -249,24 +204,6 @@ bool TheorySetsPrivate::areCareDisequal(Node a, Node b)
   return false;
 }
 
-bool TheorySetsPrivate::isMember(Node x, Node s)
-{
-  Assert(d_equalityEngine->hasTerm(s)
-         && d_equalityEngine->getRepresentative(s) == s);
-  NodeIntMap::iterator mem_i = d_members.find(s);
-  if (mem_i != d_members.end())
-  {
-    for (int i = 0; i < (*mem_i).second; i++)
-    {
-      if (d_state.areEqual(d_members_data[s][i][0], x))
-      {
-        return true;
-      }
-    }
-  }
-  return false;
-}
-
 void TheorySetsPrivate::fullEffortReset()
 {
   Assert(d_equalityEngine->consistent());
@@ -681,7 +618,7 @@ void TheorySetsPrivate::checkUpwardsClosure()
                 if (valid)
                 {
                   Node rr = d_equalityEngine->getRepresentative(term);
-                  if (!isMember(x, rr))
+                  if (!d_state.isMember(x, rr))
                   {
                     Node kk = d_treg.getProxy(term);
                     Node fact = nm->mkNode(kind::MEMBER, x, kk);
@@ -705,7 +642,7 @@ void TheorySetsPrivate::checkUpwardsClosure()
                 {
                   Node x = itm2m.second[0];
                   Node rr = d_equalityEngine->getRepresentative(term);
-                  if (!isMember(x, rr))
+                  if (!d_state.isMember(x, rr))
                   {
                     std::vector<Node> exp;
                     exp.push_back(itm2m.second);
@@ -937,21 +874,7 @@ void TheorySetsPrivate::notifyFact(TNode atom, bool polarity, TNode fact)
       }
     }
     // add to membership list
-    NodeIntMap::iterator mem_i = d_members.find(r);
-    int n_members = 0;
-    if (mem_i != d_members.end())
-    {
-      n_members = (*mem_i).second;
-    }
-    d_members[r] = n_members + 1;
-    if (n_members < (int)d_members_data[r].size())
-    {
-      d_members_data[r][n_members] = atom;
-    }
-    else
-    {
-      d_members_data[r].push_back(atom);
-    }
+    d_state.addMember(r, atom);
   }
 }
 //--------------------------------- end standard check
index 71ad3781d6c08d4c7a37812d4131be4cef81926b..d58967c5f78e4643b65ad731a69fd983044aa724 100644 (file)
@@ -40,7 +40,6 @@ class TheorySets;
 
 class TheorySetsPrivate {
   typedef context::CDHashMap< Node, bool, NodeHashFunction> NodeBoolMap;
-  typedef context::CDHashMap< Node, int, NodeHashFunction> NodeIntMap;
   typedef context::CDHashSet<Node, NodeHashFunction> NodeSet;
 
  public:
@@ -51,8 +50,6 @@ class TheorySetsPrivate {
  private:
   /** Are a and b trigger terms in the equality engine that may be disequal? */
   bool areCareDisequal(Node a, Node b);
-  NodeIntMap d_members;
-  std::map< Node, std::vector< Node > > d_members_data;
   /**
    * Invoke the decision procedure for this theory, which is run at
    * full effort. This will either send a lemma or conflict on the output
@@ -246,8 +243,6 @@ class TheorySetsPrivate {
  public:
   /** Is formula n entailed to have polarity pol in the current context? */
   bool isEntailed(Node n, bool pol) { return d_state.isEntailed(n, pol); }
-  /** Is x entailed to be a member of set s in the current context? */
-  bool isMember(Node x, Node s);
 
  private:
   /** get choose function