Infrastructure for string length entailments via approximations (#2514)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 4 Oct 2018 22:26:10 +0000 (17:26 -0500)
committerGitHub <noreply@github.com>
Thu, 4 Oct 2018 22:26:10 +0000 (17:26 -0500)
src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h

index de5effd233f48b2878985ee753c5168c28892a8d..b2778298ac35b6969272c5effa60ebe202f4723f 100644 (file)
@@ -3673,52 +3673,501 @@ bool TheoryStringsRewriter::checkEntailArith(Node a, Node b, bool strict)
   }
 }
 
+struct StrCheckEntailArithTag
+{
+};
+struct StrCheckEntailArithComputedTag
+{
+};
+/** Attribute true for expressions for which checkEntailArith returned true */
+typedef expr::Attribute<StrCheckEntailArithTag, bool> StrCheckEntailArithAttr;
+typedef expr::Attribute<StrCheckEntailArithComputedTag, bool>
+    StrCheckEntailArithComputedAttr;
+
 bool TheoryStringsRewriter::checkEntailArith(Node a, bool strict)
 {
   if (a.isConst())
   {
     return a.getConst<Rational>().sgn() >= (strict ? 1 : 0);
   }
-  else
+
+  Node ar =
+      strict
+          ? NodeManager::currentNM()->mkNode(
+                kind::MINUS, a, NodeManager::currentNM()->mkConst(Rational(1)))
+          : a;
+  ar = Rewriter::rewrite(ar);
+
+  if (ar.getAttribute(StrCheckEntailArithComputedAttr()))
   {
-    Node ar = strict
-                  ? NodeManager::currentNM()->mkNode(
-                        kind::MINUS,
-                        a,
-                        NodeManager::currentNM()->mkConst(Rational(1)))
-                  : a;
-    ar = Rewriter::rewrite(ar);
-    if (checkEntailArithInternal(ar))
+    return ar.getAttribute(StrCheckEntailArithAttr());
+  }
+
+  bool ret = checkEntailArithInternal(ar);
+  if (!ret)
+  {
+    // try with approximations
+    ret = checkEntailArithApprox(ar);
+  }
+  // cache the result
+  ar.setAttribute(StrCheckEntailArithAttr(), ret);
+  ar.setAttribute(StrCheckEntailArithComputedAttr(), true);
+  return ret;
+}
+
+bool TheoryStringsRewriter::checkEntailArithApprox(Node ar)
+{
+  Assert(Rewriter::rewrite(ar) == ar);
+  NodeManager* nm = NodeManager::currentNM();
+  std::map<Node, Node> msum;
+  Trace("strings-ent-approx-debug")
+      << "Setup arithmetic approximations for " << ar << std::endl;
+  if (!ArithMSum::getMonomialSum(ar, msum))
+  {
+    Trace("strings-ent-approx-debug")
+        << "...failed to get monomial sum!" << std::endl;
+    return false;
+  }
+  // for each monomial v*c, mApprox[v] a list of
+  // possibilities for how the term can be soundly approximated, that is,
+  // if mApprox[v] contains av, then v*c > av*c. Notice that if c
+  // is positive, then v > av, otherwise if c is negative, then v < av.
+  // In other words, av is an under-approximation if c is positive, and an
+  // over-approximation if c is negative.
+  bool changed = false;
+  std::map<Node, std::vector<Node> > mApprox;
+  // map from approximations to their monomial sums
+  std::map<Node, std::map<Node, Node> > approxMsums;
+  // aarSum stores each monomial that does not have multiple approximations
+  std::vector<Node> aarSum;
+  for (std::pair<const Node, Node>& m : msum)
+  {
+    Node v = m.first;
+    Node c = m.second;
+    Trace("strings-ent-approx-debug")
+        << "Get approximations " << v << "..." << std::endl;
+    if (v.isNull())
+    {
+      Node mn = c.isNull() ? nm->mkConst(Rational(1)) : c;
+      aarSum.push_back(mn);
+    }
+    else
     {
-      return true;
+      // c.isNull() means c = 1
+      bool isOverApprox = !c.isNull() && c.getConst<Rational>().sgn() == -1;
+      std::vector<Node>& approx = mApprox[v];
+      std::unordered_set<Node, NodeHashFunction> visited;
+      std::vector<Node> toProcess;
+      toProcess.push_back(v);
+      do
+      {
+        Node curr = toProcess.back();
+        Trace("strings-ent-approx-debug") << "  process " << curr << std::endl;
+        curr = Rewriter::rewrite(curr);
+        toProcess.pop_back();
+        if (visited.find(curr) == visited.end())
+        {
+          visited.insert(curr);
+          std::vector<Node> currApprox;
+          getArithApproximations(curr, currApprox, isOverApprox);
+          if (currApprox.empty())
+          {
+            Trace("strings-ent-approx-debug")
+                << "...approximation: " << curr << std::endl;
+            // no approximations, thus curr is a possibility
+            approx.push_back(curr);
+          }
+          else
+          {
+            toProcess.insert(
+                toProcess.end(), currApprox.begin(), currApprox.end());
+          }
+        }
+      } while (!toProcess.empty());
+      Assert(!approx.empty());
+      // if we have only one approximation, move it to final
+      if (approx.size() == 1)
+      {
+        changed = v != approx[0];
+        Node mn = ArithMSum::mkCoeffTerm(c, approx[0]);
+        aarSum.push_back(mn);
+        mApprox.erase(v);
+      }
+      else
+      {
+        // compute monomial sum form for each approximation, used below
+        for (const Node& aa : approx)
+        {
+          if (approxMsums.find(aa) == approxMsums.end())
+          {
+            CVC4_UNUSED bool ret =
+                ArithMSum::getMonomialSum(aa, approxMsums[aa]);
+            Assert(ret);
+          }
+        }
+        changed = true;
+      }
     }
-    // TODO (#1180) : abstract interpretation goes here
-
-    // over approximation O/U
-
-    // O( x + y ) -> O( x ) + O( y )
-    // O( c * x ) -> O( x ) if c > 0, U( x ) if c < 0
-    // O( len( x ) ) -> len( x )
-    // O( len( int.to.str( x ) ) ) -> len( int.to.str( x ) )
-    // O( len( str.substr( x, n1, n2 ) ) ) -> O( n2 ) | O( len( x ) )
-    // O( len( str.replace( x, y, z ) ) ) ->
-    //   O( len( x ) ) + O( len( z ) ) - U( len( y ) )
-    // O( indexof( x, y, n ) ) -> O( len( x ) ) - U( len( y ) )
-    // O( str.to.int( x ) ) -> str.to.int( x )
-
-    // U( x + y ) -> U( x ) + U( y )
-    // U( c * x ) -> U( x ) if c > 0, O( x ) if c < 0
-    // U( len( x ) ) -> len( x )
-    // U( len( int.to.str( x ) ) ) -> 1
-    // U( len( str.substr( x, n1, n2 ) ) ) ->
-    //   min( U( len( x ) ) - O( n1 ), U( n2 ) )
-    // U( len( str.replace( x, y, z ) ) ) ->
-    //   U( len( x ) ) + U( len( z ) ) - O( len( y ) ) | 0
-    // U( indexof( x, y, n ) ) -> -1    ?
-    // U( str.to.int( x ) ) -> -1
-
+  }
+  if (!changed)
+  {
+    // approximations had no effect, return
+    Trace("strings-ent-approx-debug") << "...no approximations" << std::endl;
     return false;
   }
+  // get the current "fixed" sum for the abstraction of ar
+  Node aar = aarSum.empty()
+                 ? nm->mkConst(Rational(0))
+                 : (aarSum.size() == 1 ? aarSum[0] : nm->mkNode(PLUS, aarSum));
+  aar = Rewriter::rewrite(aar);
+  Trace("strings-ent-approx-debug")
+      << "...processed fixed sum " << aar << " with " << mApprox.size()
+      << " approximated monomials." << std::endl;
+  // if we have a choice of how to approximate
+  if (!mApprox.empty())
+  {
+    // convert aar back to monomial sum
+    std::map<Node, Node> msumAar;
+    if (!ArithMSum::getMonomialSum(aar, msumAar))
+    {
+      return false;
+    }
+    if (Trace.isOn("strings-ent-approx"))
+    {
+      Trace("strings-ent-approx")
+          << "---- Check arithmetic entailment by under-approximation " << ar
+          << " >= 0" << std::endl;
+      Trace("strings-ent-approx") << "FIXED:" << std::endl;
+      ArithMSum::debugPrintMonomialSum(msumAar, "strings-ent-approx");
+      Trace("strings-ent-approx") << "APPROX:" << std::endl;
+      for (std::pair<const Node, std::vector<Node> >& a : mApprox)
+      {
+        Node c = msum[a.first];
+        Trace("strings-ent-approx") << "  ";
+        if (!c.isNull())
+        {
+          Trace("strings-ent-approx") << c << " * ";
+        }
+        Trace("strings-ent-approx")
+            << a.second << " ...from " << a.first << std::endl;
+      }
+      Trace("strings-ent-approx") << std::endl;
+    }
+    Rational one(1);
+    // incorporate monomials one at a time that have a choice of approximations
+    while (!mApprox.empty())
+    {
+      Node v;
+      Node vapprox;
+      int maxScore = -1;
+      // Look at each approximation, take the one with the best score.
+      // Notice that we are in the process of trying to prove
+      // ( c1*t1 + .. + cn*tn ) + ( approx_1 | ... | approx_m ) >= 0,
+      // where c1*t1 + .. + cn*tn is the "fixed" component of our sum (aar)
+      // and approx_1 ... approx_m are possible approximations. The
+      // intution here is that we want coefficients c1...cn to be positive.
+      // This is because arithmetic string terms t1...tn (which may be
+      // applications of len, indexof, str.to.int) are never entailed to be
+      // negative. Hence, we add the approx_i that contributes the "most"
+      // towards making all constants c1...cn positive and cancelling negative
+      // monomials in approx_i itself.
+      for (std::pair<const Node, std::vector<Node> >& nam : mApprox)
+      {
+        for (const Node& aa : nam.second)
+        {
+          unsigned helpsCancelCount = 0;
+          unsigned addsObligationCount = 0;
+          std::map<Node, Node>::iterator it;
+          for (std::pair<const Node, Node>& aam : approxMsums[aa])
+          {
+            // Say aar is of the form t + c1*v, and aam is the monomial c2*v
+            // where c2 != 0. We say aam:
+            // (1) helps cancel if c1 != 0 and c1>0 != c2>0
+            // (2) adds obligation if c1>=0 and c1+c2<0
+            Node v = aam.first;
+            Node c2 = aam.second;
+            int c2Sgn = c2.isNull() ? 1 : c2.getConst<Rational>().sgn();
+            it = msumAar.find(v);
+            if (it != msumAar.end())
+            {
+              Node c1 = it->second;
+              int c1Sgn = c1.isNull() ? 1 : c1.getConst<Rational>().sgn();
+              if (c1Sgn == 0)
+              {
+                addsObligationCount += (c2Sgn == -1 ? 1 : 0);
+              }
+              else if (c1Sgn != c2Sgn)
+              {
+                helpsCancelCount++;
+                Rational r1 = c1.isNull() ? one : c1.getConst<Rational>();
+                Rational r2 = c2.isNull() ? one : c2.getConst<Rational>();
+                Rational r12 = r1 + r2;
+                if (r12.sgn() == -1)
+                {
+                  addsObligationCount++;
+                }
+              }
+            }
+            else
+            {
+              addsObligationCount += (c2Sgn == -1 ? 1 : 0);
+            }
+          }
+          Trace("strings-ent-approx-debug")
+              << "counts=" << helpsCancelCount << "," << addsObligationCount
+              << " for " << aa << " into " << aar << std::endl;
+          int score = (addsObligationCount > 0 ? 0 : 2)
+                      + (helpsCancelCount > 0 ? 1 : 0);
+          // if its the best, update v and vapprox
+          if (v.isNull() || score > maxScore)
+          {
+            v = nam.first;
+            vapprox = aa;
+            maxScore = score;
+          }
+        }
+        if (!v.isNull())
+        {
+          break;
+        }
+      }
+      Trace("strings-ent-approx")
+          << "- Decide " << v << " = " << vapprox << std::endl;
+      // we incorporate v approximated by vapprox into the overall approximation
+      // for ar
+      Assert(!v.isNull() && !vapprox.isNull());
+      Assert(msum.find(v) != msum.end());
+      Node mn = ArithMSum::mkCoeffTerm(msum[v], vapprox);
+      aar = nm->mkNode(PLUS, aar, mn);
+      // update the msumAar map
+      aar = Rewriter::rewrite(aar);
+      msumAar.clear();
+      if (!ArithMSum::getMonomialSum(aar, msumAar))
+      {
+        Assert(false);
+        Trace("strings-ent-approx")
+            << "...failed to get monomial sum!" << std::endl;
+        return false;
+      }
+      // we have processed the approximation for v
+      mApprox.erase(v);
+    }
+    Trace("strings-ent-approx") << "-----------------" << std::endl;
+  }
+  if (aar == ar)
+  {
+    Trace("strings-ent-approx-debug")
+        << "...approximation had no effect" << std::endl;
+    // this should never happen, but we avoid the infinite loop for sanity here
+    Assert(false);
+    return false;
+  }
+  // Check entailment on the approximation of ar.
+  // Notice that this may trigger further reasoning by approximation. For
+  // example, len( replace( x ++ y, substr( x, 0, n ), z ) ) may be
+  // under-approximated as len( x ) + len( y ) - len( substr( x, 0, n ) ) on
+  // this call, where in the recursive call we may over-approximate
+  // len( substr( x, 0, n ) ) as len( x ). In this example, we can infer
+  // that len( replace( x ++ y, substr( x, 0, n ), z ) ) >= len( y ) in two
+  // steps.
+  if (checkEntailArith(aar))
+  {
+    Trace("strings-ent-approx")
+        << "*** StrArithApprox: showed " << ar
+        << " >= 0 using under-approximation!" << std::endl;
+    Trace("strings-ent-approx")
+        << "*** StrArithApprox: under-approximation was " << aar << std::endl;
+    return true;
+  }
+  return false;
+}
+
+void TheoryStringsRewriter::getArithApproximations(Node a,
+                                                   std::vector<Node>& approx,
+                                                   bool isOverApprox)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  // We do not handle PLUS here since this leads to exponential behavior.
+  // Instead, this is managed, e.g. during checkEntailArithApprox, where
+  // PLUS terms are expanded "on-demand" during the reasoning.
+  Trace("strings-ent-approx-debug")
+      << "Get arith approximations " << a << std::endl;
+  Kind ak = a.getKind();
+  if (ak == MULT)
+  {
+    Node c;
+    Node v;
+    if (ArithMSum::getMonomial(a, c, v))
+    {
+      bool isNeg = c.getConst<Rational>().sgn() == -1;
+      getArithApproximations(v, approx, isNeg ? !isOverApprox : isOverApprox);
+      for (unsigned i = 0, size = approx.size(); i < size; i++)
+      {
+        approx[i] = nm->mkNode(MULT, c, approx[i]);
+      }
+    }
+  }
+  else if (ak == STRING_LENGTH)
+  {
+    Kind aak = a[0].getKind();
+    if (aak == STRING_SUBSTR)
+    {
+      // over,under-approximations for len( substr( x, n, m ) )
+      Node lenx = nm->mkNode(STRING_LENGTH, a[0][0]);
+      if (isOverApprox)
+      {
+        // m >= 0 implies
+        //  m >= len( substr( x, n, m ) )
+        if (checkEntailArith(a[0][2]))
+        {
+          approx.push_back(a[0][2]);
+        }
+        if (checkEntailArith(lenx, a[0][1]))
+        {
+          // n <= len( x ) implies
+          //   len( x ) - n >= len( substr( x, n, m ) )
+          approx.push_back(nm->mkNode(MINUS, lenx, a[0][1]));
+        }
+        else
+        {
+          // len( x ) >= len( substr( x, n, m ) )
+          approx.push_back(lenx);
+        }
+      }
+      else
+      {
+        // 0 <= n and n+m <= len( x ) implies
+        //   m <= len( substr( x, n, m ) )
+        Node npm = nm->mkNode(PLUS, a[0][1], a[0][2]);
+        if (checkEntailArith(a[0][1]) && checkEntailArith(lenx, npm))
+        {
+          approx.push_back(a[0][2]);
+        }
+        // 0 <= n and n+m >= len( x ) implies
+        //   len(x)-n <= len( substr( x, n, m ) )
+        if (checkEntailArith(a[0][1]) && checkEntailArith(npm, lenx))
+        {
+          approx.push_back(nm->mkNode(MINUS, lenx, a[0][1]));
+        }
+      }
+    }
+    else if (aak == STRING_STRREPL)
+    {
+      // over,under-approximations for len( replace( x, y, z ) )
+      // notice this is either len( x ) or ( len( x ) + len( z ) - len( y ) )
+      Node lenx = nm->mkNode(STRING_LENGTH, a[0][0]);
+      Node leny = nm->mkNode(STRING_LENGTH, a[0][1]);
+      Node lenz = nm->mkNode(STRING_LENGTH, a[0][2]);
+      if (isOverApprox)
+      {
+        if (checkEntailArith(leny, lenz))
+        {
+          // len( y ) >= len( z ) implies
+          //   len( x ) >= len( replace( x, y, z ) )
+          approx.push_back(lenx);
+        }
+        else
+        {
+          // len( x ) + len( z ) >= len( replace( x, y, z ) )
+          approx.push_back(nm->mkNode(PLUS, lenx, lenz));
+        }
+      }
+      else
+      {
+        if (checkEntailArith(lenz, leny) || checkEntailArith(lenz, lenx))
+        {
+          // len( y ) <= len( z ) or len( x ) <= len( z ) implies
+          //   len( x ) <= len( replace( x, y, z ) )
+          approx.push_back(lenx);
+        }
+        else
+        {
+          // len( x ) - len( y ) <= len( replace( x, y, z ) )
+          approx.push_back(nm->mkNode(MINUS, lenx, leny));
+        }
+      }
+    }
+    else if (aak == STRING_ITOS)
+    {
+      // over,under-approximations for len( int.to.str( x ) )
+      if (isOverApprox)
+      {
+        if (checkEntailArith(a[0][0], false))
+        {
+          if (checkEntailArith(a[0][0], true))
+          {
+            // x > 0 implies
+            //   x >= len( int.to.str( x ) )
+            approx.push_back(a[0][0]);
+          }
+          else
+          {
+            // x >= 0 implies
+            //   x+1 >= len( int.to.str( x ) )
+            approx.push_back(
+                nm->mkNode(PLUS, nm->mkConst(Rational(1)), a[0][0]));
+          }
+        }
+      }
+      else
+      {
+        if (checkEntailArith(a[0][0]))
+        {
+          // x >= 0 implies
+          //   len( int.to.str( x ) ) >= 1
+          approx.push_back(nm->mkConst(Rational(1)));
+        }
+        // other crazy things are possible here, e.g.
+        // len( int.to.str( len( y ) + 10 ) ) >= 2
+      }
+    }
+  }
+  else if (ak == STRING_STRIDOF)
+  {
+    // over,under-approximations for indexof( x, y, n )
+    if (isOverApprox)
+    {
+      Node lenx = nm->mkNode(STRING_LENGTH, a[0]);
+      Node leny = nm->mkNode(STRING_LENGTH, a[1]);
+      if (checkEntailArith(lenx, leny))
+      {
+        // len( x ) >= len( y ) implies
+        //   len( x ) - len( y ) >= indexof( x, y, n )
+        approx.push_back(nm->mkNode(MINUS, lenx, leny));
+      }
+      else
+      {
+        // len( x ) >= indexof( x, y, n )
+        approx.push_back(lenx);
+      }
+    }
+    else
+    {
+      // TODO?:
+      // contains( substr( x, n, len( x ) ), y ) implies
+      //   n <= indexof( x, y, n )
+      // ...hard to test, runs risk of non-termination
+
+      // -1 <= indexof( x, y, n )
+      approx.push_back(nm->mkConst(Rational(-1)));
+    }
+  }
+  else if (ak == STRING_STOI)
+  {
+    // over,under-approximations for str.to.int( x )
+    if (isOverApprox)
+    {
+      // TODO?:
+      // y >= 0 implies
+      //   y >= str.to.int( int.to.str( y ) )
+    }
+    else
+    {
+      // -1 <= str.to.int( x )
+      approx.push_back(nm->mkConst(Rational(-1)));
+    }
+  }
+  Trace("strings-ent-approx-debug") << "Return " << approx.size() << std::endl;
 }
 
 bool TheoryStringsRewriter::checkEntailArithWithEqAssumption(Node assumption,
@@ -3940,12 +4389,11 @@ Node TheoryStringsRewriter::getConstantArithBound(Node a, bool isLower)
       << "Constant " << (isLower ? "lower" : "upper") << " bound for " << a
       << " is " << ret << std::endl;
   Assert(ret.isNull() || ret.isConst());
-  Assert(!isLower
-         || (ret.isNull() || ret.getConst<Rational>().sgn() < 0)
-                != checkEntailArith(a, false));
-  Assert(!isLower
-         || (ret.isNull() || ret.getConst<Rational>().sgn() <= 0)
-                != checkEntailArith(a, true));
+  // entailment check should be at least as powerful as computing a lower bound
+  Assert(!isLower || ret.isNull() || ret.getConst<Rational>().sgn() < 0
+         || checkEntailArith(a, false));
+  Assert(!isLower || ret.isNull() || ret.getConst<Rational>().sgn() <= 0
+         || checkEntailArith(a, true));
   return ret;
 }
 
index 91d87769c2611fd07662bf9aec5ad02d47e354c1..ed42ce762b23d4881ae1f815e3c43c35b13be7fa 100644 (file)
@@ -465,6 +465,39 @@ class TheoryStringsRewriter {
    * Returns true if it is always the case that a >= 0.
    */
   static bool checkEntailArith(Node a, bool strict = false);
+  /** check arithmetic entailment with approximations
+   *
+   * Returns true if it is always the case that a >= 0. We expect that a is in
+   * rewritten form.
+   *
+   * This function uses "approximation" techniques that under-approximate
+   * the value of a for the purposes of showing the entailment holds. For
+   * example, given:
+   *   len( x ) - len( substr( y, 0, len( x ) ) )
+   * Since we know that len( substr( y, 0, len( x ) ) ) <= len( x ), the above
+   * term can be under-approximated as len( x ) - len( x ) = 0, which is >= 0,
+   * and thus the entailment len( x ) - len( substr( y, 0, len( x ) ) ) >= 0
+   * holds.
+   */
+  static bool checkEntailArithApprox(Node a);
+  /** Get arithmetic approximations
+   *
+   * This gets the (set of) arithmetic approximations for term a and stores
+   * them in approx. If isOverApprox is true, these are over-approximations
+   * for the value of a, otherwise, they are underapproximations. For example,
+   * an over-approximation for len( substr( y, n, m ) ) is m; an
+   * under-approximation for indexof( x, y, n ) is -1.
+   *
+   * Notice that this function is not generally recursive (although it may make
+   * a small bounded of recursive calls). Instead, it returns the shape
+   * of the approximations for a. For example, an under-approximation
+   * for the term len( replace( substr( x, 0, n ), y, z ) ) returned by this
+   * function might be len( substr( x, 0, n ) ) - len( y ), where we don't
+   * consider (recursively) the approximations for len( substr( x, 0, n ) ).
+   */
+  static void getArithApproximations(Node a,
+                                     std::vector<Node>& approx,
+                                     bool isOverApprox = false);
 
   /**
    * Checks whether assumption |= a >= 0 (if strict is false) or