Move disequality list to solver state in strings (#3678)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 30 Jan 2020 17:53:54 +0000 (11:53 -0600)
committerGitHub <noreply@github.com>
Thu, 30 Jan 2020 17:53:54 +0000 (11:53 -0600)
src/theory/strings/solver_state.cpp
src/theory/strings/solver_state.h
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h

index 66ae8d6bcc4dd9e3231104c8326de3f203d25a39..664b56b17d1b2f6af7b70c2504a938c8724765ab 100644 (file)
@@ -143,6 +143,7 @@ SolverState::SolverState(context::Context* c,
                          Valuation& v)
     : d_context(c),
       d_ee(ee),
+      d_eeDisequalities(c),
       d_valuation(v),
       d_conflict(c, false),
       d_pendingConflict(c)
@@ -200,6 +201,57 @@ bool SolverState::areDisequal(Node a, Node b) const
 
 eq::EqualityEngine* SolverState::getEqualityEngine() const { return &d_ee; }
 
+const context::CDList<Node>& SolverState::getDisequalityList() const
+{
+  return d_eeDisequalities;
+}
+
+void SolverState::eqNotifyPreMerge(TNode t1, TNode t2)
+{
+  EqcInfo* e2 = getOrMakeEqcInfo(t2, false);
+  if (e2)
+  {
+    EqcInfo* e1 = getOrMakeEqcInfo(t1);
+    // add information from e2 to e1
+    if (!e2->d_lengthTerm.get().isNull())
+    {
+      e1->d_lengthTerm.set(e2->d_lengthTerm);
+    }
+    if (!e2->d_codeTerm.get().isNull())
+    {
+      e1->d_codeTerm.set(e2->d_codeTerm);
+    }
+    if (!e2->d_prefixC.get().isNull())
+    {
+      setPendingConflictWhen(
+          e1->addEndpointConst(e2->d_prefixC, Node::null(), false));
+    }
+    if (!e2->d_suffixC.get().isNull())
+    {
+      setPendingConflictWhen(
+          e1->addEndpointConst(e2->d_suffixC, Node::null(), true));
+    }
+    if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get())
+    {
+      e1->d_cardinalityLemK.set(e2->d_cardinalityLemK);
+    }
+    if (!e2->d_normalizedLength.get().isNull())
+    {
+      e1->d_normalizedLength.set(e2->d_normalizedLength);
+    }
+  }
+}
+
+void SolverState::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
+{
+  if (t1.getType().isString())
+  {
+    // store disequalities between strings, may need to check if their lengths
+    // are equal/disequal
+    d_eeDisequalities.push_back(t1.eqNode(t2));
+  }
+}
+
 EqcInfo* SolverState::getOrMakeEqcInfo(Node eqc, bool doMake)
 {
   std::map<Node, EqcInfo*>::iterator eqc_i = d_eqcInfo.find(eqc);
index 46d198d36189d17af43647e91d0d483b3ff1e3b3..cb17e6d1b052de7888b7c44cb2f8d47cd9a4d103 100644 (file)
@@ -88,6 +88,8 @@ class EqcInfo
  */
 class SolverState
 {
+  typedef context::CDList<Node> NodeList;
+
  public:
   SolverState(context::Context* c, eq::EqualityEngine& ee, Valuation& v);
   ~SolverState();
@@ -111,7 +113,18 @@ class SolverState
   bool areDisequal(Node a, Node b) const;
   /** get equality engine */
   eq::EqualityEngine* getEqualityEngine() const;
+  /**
+   * Get the list of disequalities that are currently asserted to the equality
+   * engine.
+   */
+  const context::CDList<Node>& getDisequalityList() const;
   //-------------------------------------- end equality information
+  //-------------------------------------- notifications for equalities
+  /** called when two equivalence classes will merge */
+  void eqNotifyPreMerge(TNode t1, TNode t2);
+  /** called when two equivalence classes are made disequal */
+  void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
+  //-------------------------------------- end notifications for equalities
   //------------------------------------------ conflicts
   /**
    * Set that the current state of the solver is in conflict. This should be
@@ -188,6 +201,11 @@ class SolverState
   context::Context* d_context;
   /** Reference to equality engine of the theory of strings. */
   eq::EqualityEngine& d_ee;
+  /**
+   * The (SAT-context-dependent) list of disequalities that have been asserted
+   * to the equality engine above.
+   */
+  NodeList d_eeDisequalities;
   /** Reference to the valuation of the theory of strings */
   Valuation& d_valuation;
   /** Are we in conflict? */
index 755e6b4df2f56e7a3010920bba51ebd0642c51d2..152160cde828ac0bbe5e21cfe62eeba4e46b1fab 100644 (file)
@@ -101,7 +101,6 @@ TheoryStrings::TheoryStrings(context::Context* c,
       d_registered_terms_cache(u),
       d_preproc(&d_sk_cache, u),
       d_extf_infer_cache(c),
-      d_ee_disequalities(c),
       d_congruent(c),
       d_proxy_var(u),
       d_proxy_var_to_length(u),
@@ -1075,37 +1074,7 @@ void TheoryStrings::eqNotifyNewClass(TNode t){
 
 /** called when two equivalance classes will merge */
 void TheoryStrings::eqNotifyPreMerge(TNode t1, TNode t2){
-  EqcInfo* e2 = d_state.getOrMakeEqcInfo(t2, false);
-  if( e2 ){
-    EqcInfo* e1 = d_state.getOrMakeEqcInfo(t1);
-    //add information from e2 to e1
-    if (!e2->d_lengthTerm.get().isNull())
-    {
-      e1->d_lengthTerm.set(e2->d_lengthTerm);
-    }
-    if (!e2->d_codeTerm.get().isNull())
-    {
-      e1->d_codeTerm.set(e2->d_codeTerm);
-    }
-    if (!e2->d_prefixC.get().isNull())
-    {
-      d_state.setPendingConflictWhen(
-          e1->addEndpointConst(e2->d_prefixC, Node::null(), false));
-    }
-    if (!e2->d_suffixC.get().isNull())
-    {
-      d_state.setPendingConflictWhen(
-          e1->addEndpointConst(e2->d_suffixC, Node::null(), true));
-    }
-    if (e2->d_cardinalityLemK.get() > e1->d_cardinalityLemK.get())
-    {
-      e1->d_cardinalityLemK.set(e2->d_cardinalityLemK);
-    }
-    if (!e2->d_normalizedLength.get().isNull())
-    {
-      e1->d_normalizedLength.set(e2->d_normalizedLength);
-    }
-  }
+  d_state.eqNotifyPreMerge(t1, t2);
 }
 
 /** called when two equivalance classes have merged */
@@ -1115,10 +1084,7 @@ void TheoryStrings::eqNotifyPostMerge(TNode t1, TNode t2) {
 
 /** called when two equivalance classes are disequal */
 void TheoryStrings::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) {
-  if( t1.getType().isString() ){
-    //store disequalities between strings, may need to check if their lengths are equal/disequal
-    d_ee_disequalities.push_back( t1.eqNode( t2 ) );
-  }
+  d_state.eqNotifyDisequal(t1, t2, reason);
 }
 
 void TheoryStrings::addCarePairs(TNodeTrie* t1,
@@ -4139,10 +4105,12 @@ void TheoryStrings::checkNormalFormsDeq()
   std::vector< std::vector< Node > > cols;
   std::vector< Node > lts;
   std::map< Node, std::map< Node, bool > > processed;
-  
+
+  const NodeList& deqs = d_state.getDisequalityList();
+
   //for each pair of disequal strings, must determine whether their lengths are equal or disequal
-  for( NodeList::const_iterator id = d_ee_disequalities.begin(); id != d_ee_disequalities.end(); ++id ) {
-    Node eq = *id;
+  for (const Node& eq : deqs)
+  {
     Node n[2];
     for( unsigned i=0; i<2; i++ ){
       n[i] = d_equalityEngine.getRepresentative( eq[i] );
index 990461027eeda99b90106054cbdfafce11f99a0b..ce92ada8687246a3ca9c38ec917f5097e0b7e29a 100644 (file)
@@ -243,8 +243,6 @@ class TheoryStrings : public Theory {
   // extended functions inferences cache
   NodeSet d_extf_infer_cache;
   std::vector< Node > d_empty_vec;
-  //
-  NodeList d_ee_disequalities;
 private:
   NodeSet d_congruent;
   /**