Simplify handling of disequalities in strings (#8047)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 3 Feb 2022 23:48:39 +0000 (17:48 -0600)
committerGitHub <noreply@github.com>
Thu, 3 Feb 2022 23:48:39 +0000 (23:48 +0000)
This simplifies how string disequalities are handled, and fixes a caching bug in our implementation of extensionality.

The simplifications are two-fold:
- We track disequalities via assertions, not via an equality engine callback
- We process disequalities by iterating on the disequality list, not via iterating on pairs of equivalence classes

Extensionality is fixed by not *explaining* disequalities, which leads to being out of sync with the cache (as the disequality could have a different explanation in another SAT context).

This fixes the last known issues with `--seq-array=X`.

src/theory/strings/core_solver.cpp
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h
test/regress/CMakeLists.txt
test/regress/regress0/seq/wrong-model-020322.smt2 [new file with mode: 0644]
test/regress/regress0/seq/wrong-sat-020322.smt2 [new file with mode: 0644]

index f2fe97c7bdd40424203109cc03f67a96944dfa1f..3dc97ddc4c4e664c4a435856df5ea2ef56d11a51 100644 (file)
@@ -1967,15 +1967,6 @@ CoreSolver::ProcessLoopResult CoreSolver::processLoop(NormalForm& nfi,
 
 void CoreSolver::processDeq(Node ni, Node nj)
 {
-  // If using the sequence update solver, we always apply extensionality.
-  // This is required for model soundness currently, although we could
-  // investigate determining cases where the disequality is already
-  // satisfied (for optimization).
-  if (options().strings.seqArray != options::SeqArrayMode::NONE)
-  {
-    processDeqExtensionality(ni, nj);
-    return;
-  }
   NodeManager* nm = NodeManager::currentNM();
   NormalForm& nfni = getNormalForm(ni);
   NormalForm& nfnj = getNormalForm(nj);
@@ -2073,12 +2064,6 @@ void CoreSolver::processDeq(Node ni, Node nj)
     return;
   }
 
-  if (options().strings.stringsDeqExt)
-  {
-    processDeqExtensionality(ni, nj);
-    return;
-  }
-
   nfi = nfni.d_nf;
   nfj = nfnj.d_nf;
 
@@ -2485,8 +2470,9 @@ void CoreSolver::processDeqExtensionality(Node n1, Node n2)
   Node conc = nm->mkNode(OR, lenDeq, nm->mkAnd(concs));
   // A != B => ( seq.len(A) != seq.len(B) or
   //             ( seq.nth(A, d) != seq.nth(B, d) ^ 0 <= d < seq.len(A) ) )
+  // Note that we take A != B verbatim, and do not explain it.
   d_im.sendInference(
-      {deq}, conc, InferenceId::STRINGS_DEQ_EXTENSIONALITY, false, true);
+      {deq}, {deq}, conc, InferenceId::STRINGS_DEQ_EXTENSIONALITY, false, true);
 }
 
 void CoreSolver::addNormalFormPair( Node n1, Node n2 ){
@@ -2542,9 +2528,12 @@ void CoreSolver::checkNormalFormsDeq()
   const context::CDList<Node>& deqs = d_state.getDisequalityList();
 
   NodeManager* nm = NodeManager::currentNM();
+  Trace("str-deq") << "Process disequalites..." << std::endl;
+  std::vector<Node> relevantDeqs;
   //for each pair of disequal strings, must determine whether their lengths are equal or disequal
   for (const Node& eq : deqs)
   {
+    Trace("str-deq") << "- disequality " << eq << std::endl;
     Node n[2];
     for( unsigned i=0; i<2; i++ ){
       n[i] = ee->getRepresentative( eq[i] );
@@ -2559,11 +2548,26 @@ void CoreSolver::checkNormalFormsDeq()
           lt[i] = nm->mkNode(STRING_LENGTH, eq[i]);
         }
       }
-      if (!d_state.areEqual(lt[0], lt[1]) && !d_state.areDisequal(lt[0], lt[1]))
+      if (d_state.areEqual(lt[0], lt[1]))
+      {
+        // if they have equal lengths, we must process the disequality below
+        relevantDeqs.push_back(eq);
+        Trace("str-deq") << "...relevant" << std::endl;
+      }
+      else if (!d_state.areDisequal(lt[0], lt[1]))
       {
         d_im.sendSplit(lt[0], lt[1], InferenceId::STRINGS_DEQ_LENGTH_SP);
+        Trace("str-deq") << "...split" << std::endl;
+      }
+      else
+      {
+        Trace("str-deq") << "...disequal length" << std::endl;
       }
     }
+    else
+    {
+      Trace("str-deq") << "...congruent" << std::endl;
+    }
   }
 
   if (d_im.hasProcessed())
@@ -2571,55 +2575,37 @@ void CoreSolver::checkNormalFormsDeq()
     // added splitting lemma above
     return;
   }
-  // otherwise, look at pairs of equivalence classes with equal lengths
-  std::map<TypeNode, std::vector<std::vector<Node> > > colsT;
-  std::map<TypeNode, std::vector<Node> > ltsT;
-  d_state.separateByLength(d_strings_eqc, colsT, ltsT);
-  for (std::pair<const TypeNode, std::vector<std::vector<Node> > >& ct : colsT)
+  for (const Node& eq : relevantDeqs)
   {
-    std::vector<std::vector<Node> >& cols = ct.second;
-    for( unsigned i=0; i<cols.size(); i++ ){
-      if (cols[i].size() > 1 && !d_im.hasPendingLemma())
-      {
-        if (Trace.isOn("strings-solve"))
-        {
-          Trace("strings-solve") << "- Verify disequalities are processed for "
-                                 << cols[i][0] << ", normal form : ";
-          utils::printConcatTrace(getNormalForm(cols[i][0]).d_nf, "strings-solve");
-          Trace("strings-solve")
-              << "... #eql = " << cols[i].size() << std::endl;
-        }
-        //must ensure that normal forms are disequal
-        for( unsigned j=0; j<cols[i].size(); j++ ){
-          for( unsigned k=(j+1); k<cols[i].size(); k++ ){
-            //for strings that are disequal, but have the same length
-            if (cols[i][j].isConst() && cols[i][k].isConst())
-            {
-              // if both are constants, they should be distinct, and its trivial
-              Assert(cols[i][j] != cols[i][k]);
-            }
-            else if (d_state.areDisequal(cols[i][j], cols[i][k]))
-            {
-              Assert(!d_state.isInConflict());
-              if (Trace.isOn("strings-solve"))
-              {
-                Trace("strings-solve") << "- Compare " << cols[i][j] << ", nf ";
-                utils::printConcatTrace(getNormalForm(cols[i][j]).d_nf,
-                                        "strings-solve");
-                Trace("strings-solve") << " against " << cols[i][k] << ", nf ";
-                utils::printConcatTrace(getNormalForm(cols[i][k]).d_nf,
-                                        "strings-solve");
-                Trace("strings-solve") << "..." << std::endl;
-              }
-              processDeq(cols[i][j], cols[i][k]);
-              if (d_im.hasProcessed())
-              {
-                return;
-              }
-            }
-          }
-        }
-      }
+    Assert(!d_state.isInConflict());
+    // If using the sequence update solver, we always apply extensionality.
+    // This is required for model soundness currently, although we could
+    // investigate determining cases where the disequality is already
+    // satisfied (for optimization).
+    if (options().strings.stringsDeqExt
+        || options().strings.seqArray != options::SeqArrayMode::NONE)
+    {
+      processDeqExtensionality(eq[0], eq[1]);
+      continue;
+    }
+    // the method below requires representatives
+    Node n[2];
+    for (size_t i = 0; i < 2; i++)
+    {
+      n[i] = ee->getRepresentative(eq[i]);
+    }
+    if (Trace.isOn("strings-solve"))
+    {
+      Trace("strings-solve") << "- Compare " << n[0] << ", nf ";
+      utils::printConcatTrace(getNormalForm(n[0]).d_nf, "strings-solve");
+      Trace("strings-solve") << " against " << n[1] << ", nf ";
+      utils::printConcatTrace(getNormalForm(n[1]).d_nf, "strings-solve");
+      Trace("strings-solve") << "..." << std::endl;
+    }
+    processDeq(n[0], n[1]);
+    if (d_im.hasProcessed())
+    {
+      return;
     }
   }
 }
index 2f8b3b8ce27cd69c907272dcb2a1e2b96d1b5517..94393c41b55fbcc32856821139ddaef852552f9a 100644 (file)
@@ -768,15 +768,23 @@ void TheoryStrings::preRegisterTerm(TNode n)
 bool TheoryStrings::preNotifyFact(
     TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal)
 {
-  // this is only required for internal facts, others are already registered
-  if (isInternal && atom.getKind() == EQUAL)
+  if (atom.getKind() == EQUAL)
   {
-    // We must ensure these terms are registered. We register eagerly here for
-    // performance reasons. Alternatively, terms could be registered at full
-    // effort in e.g. BaseSolver::init.
-    for (const Node& t : atom)
+    // this is only required for internal facts, others are already registered
+    if (isInternal)
+    {
+      // We must ensure these terms are registered. We register eagerly here for
+      // performance reasons. Alternatively, terms could be registered at full
+      // effort in e.g. BaseSolver::init.
+      for (const Node& t : atom)
+      {
+        d_termReg.registerTerm(t, 0);
+      }
+    }
+    // store disequalities between strings that occur as literals
+    if (!pol && atom[0].getType().isStringLike())
     {
-      d_termReg.registerTerm(t, 0);
+      d_state.addDisequality(atom[0], atom[1]);
     }
   }
   return false;
@@ -948,16 +956,6 @@ void TheoryStrings::eqNotifyMerge(TNode t1, TNode t2)
   }
 }
 
-void TheoryStrings::eqNotifyDisequal(TNode t1, TNode t2, TNode reason)
-{
-  if (t1.getType().isStringLike())
-  {
-    // store disequalities between strings, may need to check if their lengths
-    // are equal/disequal
-    d_state.addDisequality(t1, t2);
-  }
-}
-
 void TheoryStrings::addCarePairs(TNodeTrie* t1,
                                  TNodeTrie* t2,
                                  unsigned arity,
index 612fc7b54e3df5735a80fc580652977093443b70..6337e164b01f4c1609db8250f452ff633427e800 100644 (file)
@@ -109,8 +109,6 @@ class TheoryStrings : public Theory {
   void eqNotifyNewClass(TNode t);
   /** Called just after the merge of two equivalence classes */
   void eqNotifyMerge(TNode t1, TNode t2);
-  /** called a disequality is added */
-  void eqNotifyDisequal(TNode t1, TNode t2, TNode reason);
   /** preprocess rewrite */
   TrustNode ppRewrite(TNode atom, std::vector<SkolemLemma>& lems) override;
   /** Collect model values in m based on the relevant terms given by termSet */
@@ -161,8 +159,6 @@ class TheoryStrings : public Theory {
     }
     void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override
     {
-      Debug("strings") << "NotifyClass::eqNotifyDisequal(" << t1 << ", " << t2 << ", " << reason << std::endl;
-      d_str.eqNotifyDisequal(t1, t2, reason);
     }
 
    private:
index 3bdecf5561b2d3b2d51f1c7bec898cab1a10a011..fbe859eeb2c4e7e269c917309c3942f5d08ed1d5 100644 (file)
@@ -1151,6 +1151,8 @@ set(regress_0_tests
   regress0/seq/update-concat-non-atomic2.smt2
   regress0/seq/update-eq.smt2
   regress0/seq/update-eq-unsat.smt2
+  regress0/seq/wrong-model-020322.smt2
+  regress0/seq/wrong-sat-020322.smt2
   regress0/sets/abt-min.smt2
   regress0/sets/abt-te-exh.smt2
   regress0/sets/abt-te-exh2.smt2
diff --git a/test/regress/regress0/seq/wrong-model-020322.smt2 b/test/regress/regress0/seq/wrong-model-020322.smt2
new file mode 100644 (file)
index 0000000..683761e
--- /dev/null
@@ -0,0 +1,10 @@
+; COMMAND-LINE: --strings-exp --seq-array=lazy -q
+; EXPECT: sat
+(set-logic ALL)
+(set-info :status sat)
+(declare-sort E 0)
+(declare-fun k () E)
+(declare-fun s () (Seq E))
+(declare-fun j () Int)
+(assert (distinct (distinct s (str.update s j (seq.unit (seq.nth s 1)))) (distinct s (str.update (str.update s 0 (seq.unit k)) j (seq.unit (seq.nth s 1))))))
+(check-sat)
diff --git a/test/regress/regress0/seq/wrong-sat-020322.smt2 b/test/regress/regress0/seq/wrong-sat-020322.smt2
new file mode 100644 (file)
index 0000000..bce6053
--- /dev/null
@@ -0,0 +1,14 @@
+; COMMAND-LINE: --strings-exp --seq-array=lazy
+; EXPECT: unsat
+(set-logic ALL)
+(set-info :status unsat)
+(declare-sort E 0)
+(declare-fun k () E)
+(declare-fun s () (Seq E))
+(assert (distinct
+                (distinct s (str.update s 0 (seq.unit (seq.nth s 0))))
+                (distinct s
+                               (str.update (str.update s 0 (seq.unit k))
+                                                   0
+                                                   (seq.unit (seq.nth s 0))))))
+(check-sat)