Introduce best content heuristic for strings (#4382)
authorAndres Noetzli <andres.noetzli@gmail.com>
Thu, 23 Apr 2020 12:58:00 +0000 (05:58 -0700)
committerGitHub <noreply@github.com>
Thu, 23 Apr 2020 12:58:00 +0000 (07:58 -0500)
* Introduce best content heuristic for strings

This commit introduces a "best content heuristic" to perform
context-dependent simplifications. The high-level idea is that for each
equivalence class for strings, we compute a representation that is a
string concatentation of constants and other string terms. For this
representation, we try to get as many letters in the string constants as
we can (i.e. the best approximation of the content). This "best content"
representation is then used by `EXTF_EVAL` to perform simplifications.

Co-authored-by: Andrew Reynolds <andrew.j.reynolds@gmail.com>
src/theory/strings/base_solver.cpp
src/theory/strings/base_solver.h
src/theory/strings/extf_solver.cpp

index 1f8d2f49cf6485faba85b9b3a7021bc776e9a6da..8711973f4ba982bb249e4c5f5b300aaf74cee65a 100644 (file)
@@ -17,6 +17,7 @@
 
 #include "options/strings_options.h"
 #include "theory/strings/theory_strings_utils.h"
+#include "theory/strings/word.h"
 
 using namespace std;
 using namespace CVC4::context;
@@ -42,9 +43,7 @@ BaseSolver::~BaseSolver() {}
 void BaseSolver::checkInit()
 {
   // build term index
-  d_eqcToConst.clear();
-  d_eqcToConstBase.clear();
-  d_eqcToConstExp.clear();
+  d_eqcInfo.clear();
   d_termIndex.clear();
   d_stringsEqc.clear();
 
@@ -70,9 +69,9 @@ void BaseSolver::checkInit()
         Node n = *eqc_i;
         if (n.isConst())
         {
-          d_eqcToConst[eqc] = n;
-          d_eqcToConstBase[eqc] = n;
-          d_eqcToConstExp[eqc] = Node::null();
+          d_eqcInfo[eqc].d_bestContent = n;
+          d_eqcInfo[eqc].d_base = n;
+          d_eqcInfo[eqc].d_exp = Node::null();
         }
         else if (tn.isInteger())
         {
@@ -241,20 +240,33 @@ void BaseSolver::checkConstantEquivalenceClasses()
     vecc.clear();
     Trace("strings-process-debug")
         << "Check constant equivalence classes..." << std::endl;
-    prevSize = d_eqcToConst.size();
-    checkConstantEquivalenceClasses(&d_termIndex[STRING_CONCAT], vecc);
-  } while (!d_im.hasProcessed() && d_eqcToConst.size() > prevSize);
+    prevSize = d_eqcInfo.size();
+    checkConstantEquivalenceClasses(&d_termIndex[STRING_CONCAT], vecc, true);
+  } while (!d_im.hasProcessed() && d_eqcInfo.size() > prevSize);
+
+  if (!d_im.hasProcessed())
+  {
+    // now, go back and set "most content" terms
+    vecc.clear();
+    checkConstantEquivalenceClasses(&d_termIndex[STRING_CONCAT], vecc, false);
+  }
 }
 
 void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti,
-                                                 std::vector<Node>& vecc)
+                                                 std::vector<Node>& vecc,
+                                                 bool ensureConst,
+                                                 bool isConst)
 {
   Node n = ti->d_data;
   if (!n.isNull())
   {
-    // construct the constant
-    Node c = utils::mkNConcat(vecc, n.getType());
-    if (!d_state.areEqual(n, c))
+    // construct the constant if applicable
+    Node c;
+    if (isConst)
+    {
+      c = utils::mkNConcat(vecc, n.getType());
+    }
+    if (!isConst || !d_state.areEqual(n, c))
     {
       if (Trace.isOn("strings-debug"))
       {
@@ -270,8 +282,12 @@ void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti,
       size_t count = 0;
       size_t countc = 0;
       std::vector<Node> exp;
+      // non-constant vector
+      std::vector<Node> vecnc;
+      size_t contentSize = 0;
       while (count < n.getNumChildren())
       {
+        // Add explanations for the empty children
         while (count < n.getNumChildren()
                && d_state.areEqual(n[count], d_emptyString))
         {
@@ -280,26 +296,65 @@ void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti,
         }
         if (count < n.getNumChildren())
         {
-          Trace("strings-debug")
-              << "...explain " << n[count] << " " << vecc[countc] << std::endl;
-          if (!d_state.areEqual(n[count], vecc[countc]))
+          if (vecc[countc].isNull())
           {
-            Node nrr = d_state.getRepresentative(n[count]);
-            Assert(!d_eqcToConstExp[nrr].isNull());
-            d_im.addToExplanation(n[count], d_eqcToConstBase[nrr], exp);
-            exp.push_back(d_eqcToConstExp[nrr]);
+            Assert(!isConst);
+            // no constant for this component, leave it as is
+            vecnc.push_back(n[count]);
           }
           else
           {
-            d_im.addToExplanation(n[count], vecc[countc], exp);
+            if (!isConst)
+            {
+              // use the constant
+              vecnc.push_back(vecc[countc]);
+              Assert(vecc[countc].isConst());
+              contentSize += Word::getLength(vecc[countc]);
+            }
+            Trace("strings-debug") << "...explain " << n[count] << " "
+                                   << vecc[countc] << std::endl;
+            if (!d_state.areEqual(n[count], vecc[countc]))
+            {
+              Node nrr = d_state.getRepresentative(n[count]);
+              Assert(!d_eqcInfo[nrr].d_bestContent.isNull()
+                     && d_eqcInfo[nrr].d_bestContent.isConst());
+              d_im.addToExplanation(n[count], d_eqcInfo[nrr].d_base, exp);
+              exp.push_back(d_eqcInfo[nrr].d_exp);
+            }
+            else
+            {
+              d_im.addToExplanation(n[count], vecc[countc], exp);
+            }
+            countc++;
           }
-          countc++;
           count++;
         }
       }
       // exp contains an explanation of n==c
-      Assert(countc == vecc.size());
-      if (d_state.hasTerm(c))
+      Assert(!isConst || countc == vecc.size());
+      if (!isConst)
+      {
+        // no use storing something with no content
+        if (contentSize > 0)
+        {
+          Node nr = d_state.getRepresentative(n);
+          BaseEqcInfo& bei = d_eqcInfo[nr];
+          if (!bei.d_bestContent.isConst()
+              && (bei.d_bestContent.isNull() || contentSize > bei.d_bestScore))
+          {
+            // The equivalence class is not entailed to be equal to a constant
+            // and we found a better concatenation
+            Node nct = utils::mkNConcat(vecnc, n.getType());
+            Assert(!nct.isConst());
+            bei.d_bestContent = nct;
+            bei.d_base = n;
+            bei.d_exp = utils::mkAnd(exp);
+            Trace("strings-debug")
+                << "Set eqc best content " << n << " to " << nct << std::endl;
+          }
+        }
+      }
+      else if (d_state.hasTerm(c))
       {
         d_im.sendInference(exp, n.eqNode(c), Inference::I_CONST_MERGE);
         return;
@@ -307,31 +362,31 @@ void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti,
       else if (!d_im.hasProcessed())
       {
         Node nr = d_state.getRepresentative(n);
-        std::map<Node, Node>::iterator it = d_eqcToConst.find(nr);
-        if (it == d_eqcToConst.end())
+        BaseEqcInfo& bei = d_eqcInfo[nr];
+        if (!bei.d_bestContent.isConst())
         {
           Trace("strings-debug")
               << "Set eqc const " << n << " to " << c << std::endl;
-          d_eqcToConst[nr] = c;
-          d_eqcToConstBase[nr] = n;
-          d_eqcToConstExp[nr] = utils::mkAnd(exp);
+          bei.d_bestContent = c;
+          bei.d_base = n;
+          bei.d_exp = utils::mkAnd(exp);
         }
-        else if (c != it->second)
+        else if (c != bei.d_bestContent)
         {
           // conflict
           Trace("strings-debug")
-              << "Conflict, other constant was " << it->second
+              << "Conflict, other constant was " << bei.d_bestContent
               << ", this constant was " << c << std::endl;
-          if (d_eqcToConstExp[nr].isNull())
+          if (bei.d_exp.isNull())
           {
             // n==c ^ n == c' => false
-            d_im.addToExplanation(n, it->second, exp);
+            d_im.addToExplanation(n, bei.d_bestContent, exp);
           }
           else
           {
-            // n==c ^ n == d_eqcToConstBase[nr] == c' => false
-            exp.push_back(d_eqcToConstExp[nr]);
-            d_im.addToExplanation(n, d_eqcToConstBase[nr], exp);
+            // n==c ^ n == d_base == c' => false
+            exp.push_back(bei.d_exp);
+            d_im.addToExplanation(n, bei.d_base, exp);
           }
           d_im.sendInference(exp, d_false, Inference::I_CONST_CONFLICT);
           return;
@@ -345,16 +400,23 @@ void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti,
   }
   for (std::pair<const TNode, TermIndex>& p : ti->d_children)
   {
-    std::map<Node, Node>::iterator itc = d_eqcToConst.find(p.first);
-    if (itc != d_eqcToConst.end())
+    std::map<Node, BaseEqcInfo>::const_iterator it = d_eqcInfo.find(p.first);
+    if (it != d_eqcInfo.end() && it->second.d_bestContent.isConst())
     {
-      vecc.push_back(itc->second);
-      checkConstantEquivalenceClasses(&p.second, vecc);
+      vecc.push_back(it->second.d_bestContent);
+      checkConstantEquivalenceClasses(&p.second, vecc, ensureConst, isConst);
+      vecc.pop_back();
+    }
+    else if (!ensureConst)
+    {
+      // can still proceed, with null
+      vecc.push_back(Node::null());
+      checkConstantEquivalenceClasses(&p.second, vecc, ensureConst, false);
       vecc.pop_back();
-      if (d_im.hasProcessed())
-      {
-        break;
-      }
+    }
+    if (d_im.hasProcessed())
+    {
+      break;
     }
   }
 }
@@ -499,29 +561,55 @@ bool BaseSolver::isCongruent(Node n)
 
 Node BaseSolver::getConstantEqc(Node eqc)
 {
-  std::map<Node, Node>::iterator it = d_eqcToConst.find(eqc);
-  if (it != d_eqcToConst.end())
+  std::map<Node, BaseEqcInfo>::const_iterator it = d_eqcInfo.find(eqc);
+  if (it != d_eqcInfo.end() && it->second.d_bestContent.isConst())
   {
-    return it->second;
+    return it->second.d_bestContent;
   }
   return Node::null();
 }
 
 Node BaseSolver::explainConstantEqc(Node n, Node eqc, std::vector<Node>& exp)
 {
-  std::map<Node, Node>::iterator it = d_eqcToConst.find(eqc);
-  if (it != d_eqcToConst.end())
+  std::map<Node, BaseEqcInfo>::const_iterator it = d_eqcInfo.find(eqc);
+  if (it != d_eqcInfo.end())
+  {
+    BaseEqcInfo& bei = d_eqcInfo[eqc];
+    if (!bei.d_bestContent.isConst())
+    {
+      return Node::null();
+    }
+    if (!bei.d_exp.isNull())
+    {
+      exp.push_back(bei.d_exp);
+    }
+    if (!bei.d_base.isNull())
+    {
+      d_im.addToExplanation(n, bei.d_base, exp);
+    }
+    return bei.d_bestContent;
+  }
+  return Node::null();
+}
+
+Node BaseSolver::explainBestContentEqc(Node n, Node eqc, std::vector<Node>& exp)
+{
+  std::map<Node, BaseEqcInfo>::const_iterator it = d_eqcInfo.find(eqc);
+  if (it != d_eqcInfo.end())
   {
-    if (!d_eqcToConstExp[eqc].isNull())
+    BaseEqcInfo& bei = d_eqcInfo[eqc];
+    Assert(!bei.d_bestContent.isNull());
+    if (!bei.d_exp.isNull())
     {
-      exp.push_back(d_eqcToConstExp[eqc]);
+      exp.push_back(bei.d_exp);
     }
-    if (!d_eqcToConstBase[eqc].isNull())
+    if (!bei.d_base.isNull())
     {
-      d_im.addToExplanation(n, d_eqcToConstBase[eqc], exp);
+      d_im.addToExplanation(n, bei.d_base, exp);
     }
-    return it->second;
+    return bei.d_bestContent;
   }
+
   return Node::null();
 }
 
index 3681b49a4a1d7e5310efd3c86bcd068416dc947e..1960b83521a3386535fd3ea37b6f7f95a214156c 100644 (file)
@@ -99,6 +99,10 @@ class BaseSolver
    * equivalence class of eqc.
    */
   Node explainConstantEqc(Node n, Node eqc, std::vector<Node>& exp);
+  /**
+   * Same as above, for "best content" terms.
+   */
+  Node explainBestContentEqc(Node n, Node eqc, std::vector<Node>& exp);
   /**
    * Get the set of equivalence classes of type string.
    */
@@ -106,6 +110,48 @@ class BaseSolver
   //-----------------------end query functions
 
  private:
+  /**
+   * The information that we associated with each equivalence class.
+   *
+   * Example 1. Consider the equivalence class { r, x++"a"++y, x++z }, and
+   * assume x = "" and y = "bb" in the current context. We have that
+   *   d_bestContent = "abb",
+   *   d_base = x++"a"++y
+   *   d_exp = ( x = "" AND y = "bb" )
+   *
+   * Example 2. Consider the equivalence class { r, x++"a"++w++y, x++z }, and
+   * assume x = "" and y = "bb" in the current context. We have that
+   *   d_bestContent = "a" ++ w ++ "bb",
+   *   d_bestScore = 3
+   *   d_base = x++"a"++w++y
+   *   d_exp = ( x = "" AND y = "bb" )
+   *
+   * This information is computed during checkInit and is used during various
+   * inference schemas for deriving inferences.
+   */
+  struct BaseEqcInfo
+  {
+    /**
+     * Either a constant or a concatentation of constants and variables that
+     * this equivalence class is entailed to be equal to. If it is a
+     * concatenation, this is the concatenation that is currently known to have
+     * the highest score (see `d_bestScore`).
+     */
+    Node d_bestContent;
+    /**
+     * The sum of the number of characters in the string literals of
+     * `d_bestContent`.
+     */
+    size_t d_bestScore;
+    /**
+     * The term in the equivalence class that is entailed to be equal to
+     * `d_bestContent`.
+     */
+    Node d_base;
+    /** This term explains why `d_bestContent` is equal to `d_base`. */
+    Node d_exp;
+  };
+
   /**
    * A term index that considers terms modulo flattening and constant merging
    * for concatenation terms.
@@ -143,8 +189,17 @@ class BaseSolver
    * accumulates the list of constants in the path to ti. If ti has a non-null
    * data n, then we have inferred that d_data is equivalent to the
    * constant specified by vecc.
+   *
+   * @param ti The term index for string concatenations
+   * @param vecc The list of constants in the path to ti
+   * @param ensureConst If true, require that each element in the path is
+   *                    constant
+   * @param isConst If true, the path so far only includes constants
    */
-  void checkConstantEquivalenceClasses(TermIndex* ti, std::vector<Node>& vecc);
+  void checkConstantEquivalenceClasses(TermIndex* ti,
+                                       std::vector<Node>& vecc,
+                                       bool ensureConst = true,
+                                       bool isConst = true);
   /** The solver state object */
   SolverState& d_state;
   /** The (custom) output channel of the theory of strings */
@@ -164,27 +219,10 @@ class BaseSolver
    */
   NodeSet d_congruent;
   /**
-   * The following three vectors are used for tracking constants that each
-   * equivalence class is entailed to be equal to.
-   * - The map d_eqcToConst maps (representatives) r of equivalence classes to
-   * the constant that that equivalence class is entailed to be equal to,
-   * - The term d_eqcToConstBase[r] is the term in the equivalence class r
-   * that is entailed to be equal to the constant d_eqcToConst[r],
-   * - The term d_eqcToConstExp[r] is the explanation of why
-   * d_eqcToConstBase[r] is equal to d_eqcToConst[r].
-   *
-   * For example, consider the equivalence class { r, x++"a"++y, x++z }, and
-   * assume x = "" and y = "bb" in the current context. We have that
-   *   d_eqcToConst[r] = "abb",
-   *   d_eqcToConstBase[r] = x++"a"++y
-   *   d_eqcToConstExp[r] = ( x = "" AND y = "bb" )
-   *
-   * This information is computed during checkInit and is used during various
-   * inference schemas for deriving inferences.
+   * Maps equivalence classes to their info, see description of `BaseEqcInfo`
+   * for more information.
    */
-  std::map<Node, Node> d_eqcToConst;
-  std::map<Node, Node> d_eqcToConstBase;
-  std::map<Node, Node> d_eqcToConstExp;
+  std::map<Node, BaseEqcInfo> d_eqcInfo;
   /** The list of equivalence classes of type string */
   std::vector<Node> d_stringsEqc;
   /** A term index for each function kind */
index 775b4a7966e0b22a6d02be2ed35c018fedd01fcd..55985406ed78135d7bb2d0a5eb80563c0443ba8c 100644 (file)
@@ -650,7 +650,7 @@ Node ExtfSolver::getCurrentSubstitutionFor(int effort,
     return mv;
   }
   Node nr = d_state.getRepresentative(n);
-  Node c = d_bsolver.explainConstantEqc(n, nr, exp);
+  Node c = d_bsolver.explainBestContentEqc(n, nr, exp);
   if (!c.isNull())
   {
     return c;