Fix and refactor TheoryStrings::checkFlatForms() (#3326)
author Andres Noetzli <andres.noetzli@gmail.com>
Fri, 1 Nov 2019 18:58:57 +0000 (11:58 -0700)
committer GitHub <noreply@github.com>
Fri, 1 Nov 2019 18:58:57 +0000 (11:58 -0700)
This commit fixes some minor (performance) issues in
`TheoryStrings::checkFlatForms()`: the `inelig` vector was initialized
with copies of the `start` element instead of all the elements up to and
including `start`, and the `else` branch of the `count == asize` check
looped over all elements from `1` instead of from `start + 1`.
Additionally, this commit refactors the code to be a bit more readable.

src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h

index 750710769c7264f0d0c9d26f7e9010947f07bf09..e34130c3f78876bfb90ecdec781265673ab1e8f5 100644 (file)
@@ -2165,21 +2165,56 @@ void TheoryStrings::checkFlatForms()
         {
           return;
         }
+
+        for (const Node& n : it->second)
+        {
+          std::reverse(d_flat_form[n].begin(), d_flat_form[n].end());
+          std::reverse(d_flat_form_index[n].begin(),
+                       d_flat_form_index[n].end());
+        }
       }
     }
   }
 }
 
+namespace {
+
+enum class FlatFormInfer
+{
+  NONE,
+  CONST,
+  UNIFY,
+  ENDPOINT_EMP,
+  ENDPOINT_EQ,
+};
+
+std::ostream& operator<<(std::ostream& os, FlatFormInfer inf)
+{
+  switch (inf)
+  {
+    case FlatFormInfer::NONE: os << "<None>"; break;
+    case FlatFormInfer::CONST: os << "F_Const"; break;
+    case FlatFormInfer::UNIFY: os << "F_Unify"; break;
+    case FlatFormInfer::ENDPOINT_EMP: os << "F_EndpointEmp"; break;
+    case FlatFormInfer::ENDPOINT_EQ: os << "F_EndpointEq"; break;
+    default: os << "<Unknown>"; break;
+  }
+  return os;
+}
+
+}  // namespace
+
 void TheoryStrings::checkFlatForm(std::vector<Node>& eqc,
-                                  unsigned start,
+                                  size_t start,
                                   bool isRev)
 {
-  unsigned count = 0;
-  std::vector<Node> inelig;
-  for (unsigned i = 0; i <= start; i++)
-  {
-    inelig.push_back(eqc[start]);
-  }
+  size_t count = 0;
+  // We check for flat form inferences involving `eqc[start]` and terms past
+  // `start`. If there was a flat form inference involving `eqc[start]` and a
+  // term at a smaller index `i`, we would have found it when `start` was
+  // `i`. Thus, we mark the preceding terms in the equivalence class as
+  // ineligible.
+  std::vector<Node> inelig(eqc.begin(), eqc.begin() + start + 1);
   Node a = eqc[start];
   Trace("strings-ff-debug")
       << "Check flat form for a = " << a << ", whose flat form is "
@@ -2189,40 +2224,41 @@ void TheoryStrings::checkFlatForm(std::vector<Node>& eqc,
   {
     std::vector<Node> exp;
     Node conc;
-    int inf_type = -1;
-    unsigned eqc_size = eqc.size();
-    unsigned asize = d_flat_form[a].size();
+    FlatFormInfer infType = FlatFormInfer::NONE;
+    size_t eqc_size = eqc.size();
+    size_t asize = d_flat_form[a].size();
     if (count == asize)
     {
-      for (unsigned i = start + 1; i < eqc_size; i++)
+      for (size_t i = start + 1; i < eqc_size; i++)
       {
         b = eqc[i];
-        if (std::find(inelig.begin(), inelig.end(), b) == inelig.end())
+        if (std::find(inelig.begin(), inelig.end(), b) != inelig.end())
+        {
+          continue;
+        }
+
+        size_t bsize = d_flat_form[b].size();
+        if (count < bsize)
         {
-          unsigned bsize = d_flat_form[b].size();
-          if (count < bsize)
+          Trace("strings-ff-debug")
+              << "Found endpoint (in a) with non-empty b = " << b
+              << ", whose flat form is " << d_flat_form[b] << std::endl;
+          // endpoint
+          std::vector<Node> conc_c;
+          for (unsigned j = count; j < bsize; j++)
           {
-            Trace("strings-ff-debug")
-                << "Found endpoint (in a) with non-empty b = " << b
-                << ", whose flat form is " << d_flat_form[b] << std::endl;
-            // endpoint
-            std::vector<Node> conc_c;
-            for (unsigned j = count; j < bsize; j++)
-            {
-              conc_c.push_back(
-                  b[d_flat_form_index[b][j]].eqNode(d_emptyString));
-            }
-            Assert(!conc_c.empty());
-            conc = utils::mkAnd(conc_c);
-            inf_type = 2;
-            Assert(count > 0);
-            // swap, will enforce is empty past current
-            a = eqc[i];
-            b = eqc[start];
-            break;
+            conc_c.push_back(b[d_flat_form_index[b][j]].eqNode(d_emptyString));
           }
-          inelig.push_back(eqc[i]);
+          Assert(!conc_c.empty());
+          conc = utils::mkAnd(conc_c);
+          infType = FlatFormInfer::ENDPOINT_EMP;
+          Assert(count > 0);
+          // swap, will enforce is empty past current
+          a = eqc[i];
+          b = eqc[start];
+          break;
         }
+        inelig.push_back(eqc[i]);
       }
     }
     else
@@ -2232,94 +2268,97 @@ void TheoryStrings::checkFlatForm(std::vector<Node>& eqc,
       Node ac = a[d_flat_form_index[a][count]];
       std::vector<Node> lexp;
       Node lcurr = d_state.getLength(ac, lexp);
-      for (unsigned i = 1; i < eqc_size; i++)
+      for (size_t i = start + 1; i < eqc_size; i++)
       {
         b = eqc[i];
-        if (std::find(inelig.begin(), inelig.end(), b) == inelig.end())
+        if (std::find(inelig.begin(), inelig.end(), b) != inelig.end())
+        {
+          continue;
+        }
+
+        if (count == d_flat_form[b].size())
         {
-          if (count == d_flat_form[b].size())
+          inelig.push_back(b);
+          Trace("strings-ff-debug")
+              << "Found endpoint in b = " << b << ", whose flat form is "
+              << d_flat_form[b] << std::endl;
+          // endpoint
+          std::vector<Node> conc_c;
+          for (size_t j = count; j < asize; j++)
           {
-            inelig.push_back(b);
-            Trace("strings-ff-debug")
-                << "Found endpoint in b = " << b << ", whose flat form is "
-                << d_flat_form[b] << std::endl;
-            // endpoint
-            std::vector<Node> conc_c;
-            for (unsigned j = count; j < asize; j++)
-            {
-              conc_c.push_back(
-                  a[d_flat_form_index[a][j]].eqNode(d_emptyString));
-            }
-            Assert(!conc_c.empty());
-            conc = utils::mkAnd(conc_c);
-            inf_type = 2;
-            Assert(count > 0);
-            break;
+            conc_c.push_back(a[d_flat_form_index[a][j]].eqNode(d_emptyString));
           }
-          else
+          Assert(!conc_c.empty());
+          conc = utils::mkAnd(conc_c);
+          infType = FlatFormInfer::ENDPOINT_EMP;
+          Assert(count > 0);
+          break;
+        }
+        else
+        {
+          Node cc = d_flat_form[b][count];
+          if (cc != curr)
           {
-            Node cc = d_flat_form[b][count];
-            if (cc != curr)
+            Node bc = b[d_flat_form_index[b][count]];
+            inelig.push_back(b);
+            Assert(!d_state.areEqual(curr, cc));
+            Node cc_c = getConstantEqc(cc);
+            if (!curr_c.isNull() && !cc_c.isNull())
             {
-              Node bc = b[d_flat_form_index[b][count]];
-              inelig.push_back(b);
-              Assert(!d_state.areEqual(curr, cc));
-              Node cc_c = getConstantEqc(cc);
-              if (!curr_c.isNull() && !cc_c.isNull())
-              {
-                // check for constant conflict
-                int index;
-                Node s = TheoryStringsRewriter::splitConstant(
-                    cc_c, curr_c, index, isRev);
-                if (s.isNull())
-                {
-                  d_im.addToExplanation(ac, d_eqc_to_const_base[curr], exp);
-                  d_im.addToExplanation(d_eqc_to_const_exp[curr], exp);
-                  d_im.addToExplanation(bc, d_eqc_to_const_base[cc], exp);
-                  d_im.addToExplanation(d_eqc_to_const_exp[cc], exp);
-                  conc = d_false;
-                  inf_type = 0;
-                  break;
-                }
-              }
-              else if ((d_flat_form[a].size() - 1) == count
-                       && (d_flat_form[b].size() - 1) == count)
+              // check for constant conflict
+              int index;
+              Node s = TheoryStringsRewriter::splitConstant(
+                  cc_c, curr_c, index, isRev);
+              if (s.isNull())
               {
-                conc = ac.eqNode(bc);
-                inf_type = 3;
+                d_im.addToExplanation(ac, d_eqc_to_const_base[curr], exp);
+                d_im.addToExplanation(d_eqc_to_const_exp[curr], exp);
+                d_im.addToExplanation(bc, d_eqc_to_const_base[cc], exp);
+                d_im.addToExplanation(d_eqc_to_const_exp[cc], exp);
+                conc = d_false;
+                infType = FlatFormInfer::CONST;
                 break;
               }
-              else
+            }
+            else if ((d_flat_form[a].size() - 1) == count
+                     && (d_flat_form[b].size() - 1) == count)
+            {
+              conc = ac.eqNode(bc);
+              infType = FlatFormInfer::ENDPOINT_EQ;
+              break;
+            }
+            else
+            {
+              // if lengths are the same, apply LengthEq
+              std::vector<Node> lexp2;
+              Node lcc = d_state.getLength(bc, lexp2);
+              if (d_state.areEqual(lcurr, lcc))
               {
-                // if lengths are the same, apply LengthEq
-                std::vector<Node> lexp2;
-                Node lcc = d_state.getLength(bc, lexp2);
-                if (d_state.areEqual(lcurr, lcc))
+                if (Trace.isOn("strings-ff-debug"))
                 {
-                  Trace("strings-ff-debug") << "Infer " << ac << " == " << bc
-                                            << " since " << lcurr
-                                            << " == " << lcc << std::endl;
-                  // exp_n.push_back( getLength( curr, true ).eqNode(
-                  // getLength( cc, true ) ) );
-                  Trace("strings-ff-debug") << "Explanation for " << lcurr
-                                            << " is ";
-                  for (unsigned j = 0; j < lexp.size(); j++)
+                  Trace("strings-ff-debug")
+                      << "Infer " << ac << " == " << bc << " since " << lcurr
+                      << " == " << lcc << std::endl;
+                  Trace("strings-ff-debug")
+                      << "Explanation for " << lcurr << " is ";
+                  for (size_t j = 0; j < lexp.size(); j++)
                   {
                     Trace("strings-ff-debug") << lexp[j] << std::endl;
                   }
-                  Trace("strings-ff-debug") << "Explanation for " << lcc
-                                            << " is ";
-                  for (unsigned j = 0; j < lexp2.size(); j++)
+                  Trace("strings-ff-debug")
+                      << "Explanation for " << lcc << " is ";
+                  for (size_t j = 0; j < lexp2.size(); j++)
                   {
                     Trace("strings-ff-debug") << lexp2[j] << std::endl;
                   }
-                  exp.insert(exp.end(), lexp.begin(), lexp.end());
-                  exp.insert(exp.end(), lexp2.begin(), lexp2.end());
-                  d_im.addToExplanation(lcurr, lcc, exp);
-                  conc = ac.eqNode(bc);
-                  inf_type = 1;
-                  break;
                 }
+
+                exp.insert(exp.end(), lexp.begin(), lexp.end());
+                exp.insert(exp.end(), lexp2.begin(), lexp2.end());
+                d_im.addToExplanation(lcurr, lcc, exp);
+                conc = ac.eqNode(bc);
+                infType = FlatFormInfer::UNIFY;
+                break;
               }
             }
           }
@@ -2328,12 +2367,12 @@ void TheoryStrings::checkFlatForm(std::vector<Node>& eqc,
     }
     if (!conc.isNull())
     {
-      Trace("strings-ff-debug")
-          << "Found inference : " << conc << " based on equality " << a
-          << " == " << b << ", " << isRev << " " << inf_type << std::endl;
+      Trace("strings-ff-debug") << "Found inference (" << infType
+                                << "): " << conc << " based on equality " << a
+                                << " == " << b << ", " << isRev << std::endl;
       d_im.addToExplanation(a, b, exp);
       // explain why prefixes up to now were the same
-      for (unsigned j = 0; j < count; j++)
+      for (size_t j = 0; j < count; j++)
       {
         Trace("strings-ff-debug") << "Add at " << d_flat_form_index[a][j] << " "
                                   << d_flat_form_index[b][j] << std::endl;
@@ -2344,8 +2383,9 @@ void TheoryStrings::checkFlatForm(std::vector<Node>& eqc,
       for (unsigned t = 0; t < 2; t++)
       {
         Node c = t == 0 ? a : b;
-        int jj;
-        if (inf_type == 3 || (t == 1 && inf_type == 2))
+        ssize_t jj;
+        if (infType == FlatFormInfer::ENDPOINT_EQ
+            || (t == 1 && infType == FlatFormInfer::ENDPOINT_EMP))
         {
           // explain all the empty components for F_EndpointEq, all for
           // the short end for F_EndpointEmp
@@ -2356,9 +2396,9 @@ void TheoryStrings::checkFlatForm(std::vector<Node>& eqc,
           jj = t == 0 ? d_flat_form_index[a][count]
                       : d_flat_form_index[b][count];
         }
-        int startj = isRev ? jj + 1 : 0;
-        int endj = isRev ? c.getNumChildren() : jj;
-        for (int j = startj; j < endj; j++)
+        ssize_t startj = isRev ? jj + 1 : 0;
+        ssize_t endj = isRev ? c.getNumChildren() : jj;
+        for (ssize_t j = startj; j < endj; j++)
         {
           if (d_state.areEqual(c[j], d_emptyString))
           {
@@ -2371,13 +2411,9 @@ void TheoryStrings::checkFlatForm(std::vector<Node>& eqc,
       // is conflicting by arithmetic len(a.b)=len(a)+len(b)!=len(a)
       // when len(b)!=0. Although if we do not infer this conflict eagerly,
       // it may be applied (see #3272).
-      d_im.sendInference(
-          exp,
-          conc,
-          inf_type == 0 ? "F_Const"
-                        : (inf_type == 1 ? "F_Unify"
-                                         : (inf_type == 2 ? "F_EndpointEmp"
-                                                          : "F_EndpointEq")));
+      std::stringstream ss;
+      ss << infType;
+      d_im.sendInference(exp, conc, ss.str().c_str());
       if (d_conflict)
       {
         return;
@@ -2386,12 +2422,6 @@ void TheoryStrings::checkFlatForm(std::vector<Node>& eqc,
     }
     count++;
   } while (inelig.size() < eqc.size());
-
-  for (const Node& n : eqc)
-  {
-    std::reverse(d_flat_form[n].begin(), d_flat_form[n].end());
-    std::reverse(d_flat_form_index[n].begin(), d_flat_form_index[n].end());
-  }
 }
 
 Node TheoryStrings::checkCycles( Node eqc, std::vector< Node >& curr, std::vector< Node >& exp ){
index 9db40f8fe58c52851b8339020febfd1f6fb11704..54ea0d1584db6e68bae3c4ae4ccad901af61ae85 100644 (file)
@@ -453,10 +453,11 @@ private:
   /**
    * This checks whether there are flat form inferences between eqc[start] and
    * eqc[j] for some j>start. If the flag isRev is true, we check for flat form
-   * interferences in the reverse direction of the flat forms. For more details,
-   * see checkFlatForms below.
+   * inferences in the reverse direction of the flat forms (note:
+   * `d_flat_form` and `d_flat_form_index` must be in reverse order if `isRev`
+   * is true). For more details, see checkFlatForms below.
    */
-  void checkFlatForm(std::vector<Node>& eqc, unsigned start, bool isRev);
+  void checkFlatForm(std::vector<Node>& eqc, size_t start, bool isRev);
   //--------------------------end for checkFlatForm
 
   //--------------------------for checkCycles