Improve rewriter for string equality (#1427)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 2 Jan 2018 17:43:00 +0000 (11:43 -0600)
committerGitHub <noreply@github.com>
Tue, 2 Jan 2018 17:43:00 +0000 (11:43 -0600)
src/theory/quantifiers/extended_rewrite.cpp
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h
src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h

index b463a319a022a53d5c477ea5b57974f5c10ad58d..95682230341d178ce8b7da996f8b2d56dc93ea33 100644 (file)
@@ -118,47 +118,6 @@ Node ExtendedRewriter::extendedRewrite(Node n)
     Node new_ret;
     if (ret.getKind() == kind::EQUAL)
     {
-      // string equalities with disequal prefix or suffix
-      if (ret[0].getType().isString())
-      {
-        std::vector<Node> c[2];
-        for (unsigned i = 0; i < 2; i++)
-        {
-          strings::TheoryStringsRewriter::getConcat(ret[i], c[i]);
-        }
-        if (c[0].empty() == c[1].empty())
-        {
-          if (!c[0].empty())
-          {
-            for (unsigned i = 0; i < 2; i++)
-            {
-              unsigned index1 = i == 0 ? 0 : c[0].size() - 1;
-              unsigned index2 = i == 0 ? 0 : c[1].size() - 1;
-              if (c[0][index1].isConst() && c[1][index2].isConst())
-              {
-                CVC4::String s = c[0][index1].getConst<String>();
-                CVC4::String t = c[1][index2].getConst<String>();
-                unsigned len_short = s.size() <= t.size() ? s.size() : t.size();
-                bool isSameFix =
-                    i == 1 ? s.rstrncmp(t, len_short) : s.strncmp(t, len_short);
-                if (!isSameFix)
-                {
-                  Trace("q-ext-rewrite") << "sygus-extr : " << ret
-                                         << " rewrites to false due to "
-                                            "disequal string prefix/suffix."
-                                         << std::endl;
-                  new_ret = d_false;
-                  break;
-                }
-              }
-            }
-          }
-        }
-        else
-        {
-          new_ret = d_false;
-        }
-      }
       if (new_ret.isNull())
       {
         // simple ITE pulling
index e6b8807e9955f2350f6a64ae7ca793616cc32a66..30a5f0fbc4b5c418c043ae5edd908e8ecbf37f42 100644 (file)
@@ -1448,7 +1448,7 @@ void TheoryStrings::checkExtfInference( Node n, Node nr, ExtfInfoTmp& in, int ef
       if( ( in.d_pol==1 && nr[1].getKind()==kind::STRING_CONCAT ) || ( in.d_pol==-1 && nr[0].getKind()==kind::STRING_CONCAT ) ){
         if( d_extf_infer_cache.find( nr )==d_extf_infer_cache.end() ){
           d_extf_infer_cache.insert( nr );
-          
+
           //one argument does (not) contain each of the components of the other argument
           int index = in.d_pol==1 ? 1 : 0;
           std::vector< Node > children;
@@ -1458,9 +1458,21 @@ void TheoryStrings::checkExtfInference( Node n, Node nr, ExtfInfoTmp& in, int ef
           for( unsigned i=0; i<nr[index].getNumChildren(); i++ ){
             children[index] = nr[index][i];
             Node conc = NodeManager::currentNM()->mkNode( kind::STRING_STRCTN, children );
-            //can mark as reduced, since model for n => model for conc
-            getExtTheory()->markReduced( conc );
-            sendInference( in.d_exp, in.d_pol==1 ? conc : conc.negate(), "CTN_Decompose" );
+            conc = Rewriter::rewrite(in.d_pol == 1 ? conc : conc.negate());
+            // check if it already (does not) hold
+            if (hasTerm(conc))
+            {
+              if (areEqual(conc, d_false))
+              {
+                // should be a conflict
+                sendInference(in.d_exp, conc, "CTN_Decompose");
+              }
+              else if (getExtTheory()->hasFunctionKind(conc.getKind()))
+              {
+                // can mark as reduced, since model for n => model for conc
+                getExtTheory()->markReduced(conc);
+              }
+            }
           }
           
         }
@@ -2978,11 +2990,11 @@ void TheoryStrings::processDeq( Node ni, Node nj ) {
                     return;
                   }else if( !areEqual( firstChar, nconst_k ) ){
                     //splitting on demand : try to make them disequal
-                    Node eq = firstChar.eqNode( nconst_k );
-                    sendSplit( firstChar, nconst_k, "S-Split(DEQL-Const)" );
-                    eq = Rewriter::rewrite( eq );
-                    d_pending_req_phase[ eq ] = false;
-                    return;
+                    if (sendSplit(
+                            firstChar, nconst_k, "S-Split(DEQL-Const)", false))
+                    {
+                      return;
+                    }
                   }
                 }else{
                   Node sk = mkSkolemCached( nconst_k, firstChar, sk_id_dc_spt, "dc_spt", 2 );
@@ -3032,18 +3044,16 @@ void TheoryStrings::processDeq( Node ni, Node nj ) {
           }else if( areEqual( li, lj ) ){
             Assert( !areDisequal( i, j ) );
             //splitting on demand : try to make them disequal
-            Node eq = i.eqNode( j );
-            sendSplit( i, j, "S-Split(DEQL)" );
-            eq = Rewriter::rewrite( eq );
-            d_pending_req_phase[ eq ] = false;
-            return;
+            if (sendSplit(i, j, "S-Split(DEQL)", false))
+            {
+              return;
+            }
           }else{
             //splitting on demand : try to make lengths equal
-            Node eq = li.eqNode( lj );
-            sendSplit( li, lj, "D-Split" );
-            eq = Rewriter::rewrite( eq );
-            d_pending_req_phase[ eq ] = true;
-            return;
+            if (sendSplit(li, lj, "D-Split"))
+            {
+              return;
+            }
           }
         }
         index++;
@@ -3361,15 +3371,22 @@ void TheoryStrings::sendInfer( Node eq_exp, Node eq, const char * c ) {
   d_infer_exp.push_back( eq_exp );
 }
 
-void TheoryStrings::sendSplit( Node a, Node b, const char * c, bool preq ) {
+bool TheoryStrings::sendSplit(Node a, Node b, const char* c, bool preq)
+{
   Node eq = a.eqNode( b );
   eq = Rewriter::rewrite( eq );
-  Node neq = NodeManager::currentNM()->mkNode( kind::NOT, eq );
-  Node lemma_or = NodeManager::currentNM()->mkNode( kind::OR, eq, neq );
-  Trace("strings-lemma") << "Strings::Lemma " << c << " SPLIT : " << lemma_or << std::endl;
-  d_lemma_cache.push_back(lemma_or);
-  d_pending_req_phase[eq] = preq;
-  ++(d_statistics.d_splits);
+  if (!eq.isConst())
+  {
+    Node neq = NodeManager::currentNM()->mkNode(kind::NOT, eq);
+    Node lemma_or = NodeManager::currentNM()->mkNode(kind::OR, eq, neq);
+    Trace("strings-lemma") << "Strings::Lemma " << c << " SPLIT : " << lemma_or
+                           << std::endl;
+    d_lemma_cache.push_back(lemma_or);
+    d_pending_req_phase[eq] = preq;
+    ++(d_statistics.d_splits);
+    return true;
+  }
+  return false;
 }
 
 
@@ -3767,8 +3784,10 @@ void TheoryStrings::checkCardinality() {
             itr2 != cols[i].end(); ++itr2) {
             if(!areDisequal( *itr1, *itr2 )) {
               // add split lemma
-              sendSplit( *itr1, *itr2, "CARD-SP" );
-              return;
+              if (sendSplit(*itr1, *itr2, "CARD-SP"))
+              {
+                return;
+              }
             }
           }
         }
index 70706bbd441ec31c50c578bae70cd5e1964d8957..f07057444eb70c0a183cc59bd5c7236bd5538ca2 100644 (file)
@@ -409,7 +409,7 @@ protected:
   void sendInference( std::vector< Node >& exp, Node eq, const char * c, bool asLemma = false );
   void sendLemma( Node ant, Node conc, const char * c );
   void sendInfer( Node eq_exp, Node eq, const char * c );
-  void sendSplit( Node a, Node b, const char * c, bool preq = true );
+  bool sendSplit(Node a, Node b, const char* c, bool preq = true);
   void sendLengthLemma( Node n );
   /** mkConcat **/
   inline Node mkConcat( Node n1, Node n2 );
index 5cb58729ee4c7c430c1ad05b6e1196b7c41c7f9b..a478667e9c24536266d6134e83a4ec517d1036ed 100644 (file)
@@ -196,11 +196,98 @@ Node TheoryStringsRewriter::simpleRegexpConsume( std::vector< Node >& mchildren,
   return Node::null();
 }
 
+Node TheoryStringsRewriter::rewriteEquality(Node node)
+{
+  Assert(node.getKind() == kind::EQUAL);
+  if (node[0] == node[1])
+  {
+    return NodeManager::currentNM()->mkConst(true);
+  }
+  else if (node[0].isConst() && node[1].isConst())
+  {
+    return NodeManager::currentNM()->mkConst(false);
+  }
+  // ( ~contains( s, t ) V ~contains( t, s ) ) => ( s == t ---> false )
+  for (unsigned r = 0; r < 2; r++)
+  {
+    Node ctn = NodeManager::currentNM()->mkNode(
+        kind::STRING_STRCTN, node[r], node[1 - r]);
+    // must call rewrite contains directly to avoid infinite loop
+    // we do a fix point since we may rewrite contains terms to simpler
+    // contains terms.
+    Node prev;
+    do
+    {
+      prev = ctn;
+      ctn = rewriteContains(ctn);
+    } while (prev != ctn && ctn.getKind() == kind::STRING_STRCTN);
+    if (ctn.isConst())
+    {
+      if (!ctn.getConst<bool>())
+      {
+        return returnRewrite(node, ctn, "eq-nctn");
+      }
+      else
+      {
+        // definitely contains but not syntactically equal
+        // We may be able to simplify, e.g.
+        //  str.++( x, "a" ) == "a"  ----> x = ""
+      }
+    }
+  }
+
+  std::vector<Node> c[2];
+  for (unsigned i = 0; i < 2; i++)
+  {
+    strings::TheoryStringsRewriter::getConcat(node[i], c[i]);
+  }
+
+  // check if the prefix, suffix mismatches
+  //   For example, str.++( x, "a", y ) == str.++( x, "bc", z ) ---> false
+  unsigned minsize = std::min(c[0].size(), c[1].size());
+  for (unsigned r = 0; r < 2; r++)
+  {
+    for (unsigned i = 0; i < minsize; i++)
+    {
+      unsigned index1 = r == 0 ? i : (c[0].size() - 1) - i;
+      unsigned index2 = r == 0 ? i : (c[1].size() - 1) - i;
+      if (c[0][index1].isConst() && c[1][index2].isConst())
+      {
+        CVC4::String s = c[0][index1].getConst<String>();
+        CVC4::String t = c[1][index2].getConst<String>();
+        unsigned len_short = s.size() <= t.size() ? s.size() : t.size();
+        bool isSameFix =
+            r == 1 ? s.rstrncmp(t, len_short) : s.strncmp(t, len_short);
+        if (!isSameFix)
+        {
+          Node ret = NodeManager::currentNM()->mkConst(false);
+          return returnRewrite(node, ret, "eq-nfix");
+        }
+      }
+      if (c[0][index1] != c[1][index2])
+      {
+        break;
+      }
+    }
+  }
+
+  // standard ordering
+  if (node[0] > node[1])
+  {
+    return NodeManager::currentNM()->mkNode(kind::EQUAL, node[1], node[0]);
+  }
+  else
+  {
+    return node;
+  }
+}
+
 // TODO (#1180) add rewrite
 //  str.++( str.substr( x, n1, n2 ), str.substr( x, n1+n2, n3 ) ) --->
 //  str.substr( x, n1, n2+n3 )
 Node TheoryStringsRewriter::rewriteConcat(Node node)
 {
+  Assert(node.getKind() == kind::STRING_CONCAT);
   Trace("strings-prerewrite") << "Strings::rewriteConcat start " << node
                               << std::endl;
   Node retNode = node;
@@ -1009,15 +1096,7 @@ RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) {
   if(node.getKind() == kind::STRING_CONCAT) {
     retNode = rewriteConcat(node);
   } else if(node.getKind() == kind::EQUAL) {
-    Node leftNode  = node[0];
-    Node rightNode = node[1];
-    if(leftNode == rightNode) {
-      retNode = NodeManager::currentNM()->mkConst(true);
-    } else if(leftNode.isConst() && rightNode.isConst()) {
-      retNode = NodeManager::currentNM()->mkConst(false);
-    } else if(leftNode > rightNode) {
-      retNode = NodeManager::currentNM()->mkNode(kind::EQUAL, rightNode, leftNode);
-    }
+    retNode = rewriteEquality(node);
   } else if(node.getKind() == kind::STRING_LENGTH) {
     if( node[0].isConst() ){
       retNode = NodeManager::currentNM()->mkConst( ::CVC4::Rational( node[0].getConst<String>().size() ) );
@@ -1320,6 +1399,7 @@ RewriteResponse TheoryStringsRewriter::preRewrite(TNode node) {
 
 Node TheoryStringsRewriter::rewriteSubstr(Node node)
 {
+  Assert(node.getKind() == kind::STRING_SUBSTR);
   if (node[0].isConst())
   {
     if (node[0].getConst<String>().size() == 0)
@@ -1535,6 +1615,7 @@ Node TheoryStringsRewriter::rewriteSubstr(Node node)
 }
 
 Node TheoryStringsRewriter::rewriteContains( Node node ) {
+  Assert(node.getKind() == kind::STRING_STRCTN);
   if( node[0] == node[1] ){
     Node ret = NodeManager::currentNM()->mkConst(true);
     return returnRewrite(node, ret, "ctn-eq");
@@ -1599,7 +1680,88 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
     Node ret = NodeManager::currentNM()->mkConst(false);
     return returnRewrite(node, ret, "ctn-len-ineq");
   }
-  else if (checkEntailArithEq(len_n1, len_n2))
+
+  // multi-set reasoning
+  //   For example, contains( str.++( x, "b" ), str.++( "a", x ) ) ---> false
+  //   since the number of a's in the second argument is greater than the number
+  //   of a's in the first argument
+  std::map<Node, unsigned> num_nconst[2];
+  std::map<Node, unsigned> num_const[2];
+  for (unsigned j = 0; j < 2; j++)
+  {
+    std::vector<Node>& ncj = j == 0 ? nc1 : nc2;
+    for (const Node& cc : ncj)
+    {
+      if (cc.isConst())
+      {
+        num_const[j][cc]++;
+      }
+      else
+      {
+        num_nconst[j][cc]++;
+      }
+    }
+  }
+  bool ms_success = true;
+  for (std::pair<const Node, unsigned>& nncp : num_nconst[0])
+  {
+    if (nncp.second > num_nconst[1][nncp.first])
+    {
+      ms_success = false;
+      break;
+    }
+  }
+  if (ms_success)
+  {
+    // count the number of constant characters in the first argument
+    std::map<Node, unsigned> count_const[2];
+    std::vector<Node> chars;
+    for (unsigned j = 0; j < 2; j++)
+    {
+      for (std::pair<const Node, unsigned>& ncp : num_const[j])
+      {
+        Node cn = ncp.first;
+        Assert(cn.isConst());
+        std::vector<unsigned> cc_vec;
+        const std::vector<unsigned>& cvec = cn.getConst<String>().getVec();
+        for (unsigned i = 0, size = cvec.size(); i < size; i++)
+        {
+          // make the character
+          cc_vec.clear();
+          cc_vec.insert(cc_vec.end(), cvec.begin() + i, cvec.begin() + i + 1);
+          Node ch = NodeManager::currentNM()->mkConst(String(cc_vec));
+          count_const[j][ch] += ncp.second;
+          if (std::find(chars.begin(), chars.end(), ch) == chars.end())
+          {
+            chars.push_back(ch);
+          }
+        }
+      }
+    }
+    Trace("strings-rewrite-multiset") << "For " << node << " : " << std::endl;
+    for (const Node& ch : chars)
+    {
+      Trace("strings-rewrite-multiset") << "  # occurrences of substring ";
+      Trace("strings-rewrite-multiset") << ch << " in arguments is ";
+      Trace("strings-rewrite-multiset") << count_const[0][ch] << " / "
+                                        << count_const[1][ch] << std::endl;
+      if (count_const[0][ch] < count_const[1][ch])
+      {
+        Node ret = NodeManager::currentNM()->mkConst(false);
+        return returnRewrite(node, ret, "ctn-mset-nss");
+      }
+    }
+    // TODO (#1180): count the number of 2,3,4,.. character substrings
+    // for example:
+    // str.contains( str.++( x, "cbabc" ), str.++( "cabbc", x ) ) ---> false
+    // since the second argument contains more occurrences of "bb".
+    // note this is orthogonal reasoning to inductive reasoning
+    // via regular membership reduction in Liang et al CAV 2015.
+  }
+  // TODO (#1180): abstract interpretation with multi-set domain
+  // to show first argument is a strict subset of second argument
+
+  if (checkEntailArithEq(len_n1, len_n2))
   {
     // len( n2 ) = len( n1 ) => contains( n1, n2 ) ---> n1 = n2
     Node ret = node[0].eqNode(node[1]);
@@ -1654,6 +1816,7 @@ Node TheoryStringsRewriter::rewriteContains( Node node ) {
 }
 
 Node TheoryStringsRewriter::rewriteIndexof( Node node ) {
+  Assert(node.getKind() == kind::STRING_STRIDOF);
   std::vector< Node > children;
   getConcat( node[0], children );
   //std::vector< Node > children1;
@@ -1759,6 +1922,7 @@ Node TheoryStringsRewriter::rewriteIndexof( Node node ) {
 }
 
 Node TheoryStringsRewriter::rewriteReplace( Node node ) {
+  Assert(node.getKind() == kind::STRING_STRREPL);
   if( node[1]==node[2] ){
     return returnRewrite(node, node[0], "rpl-id");
   }
index 64120eca04d20a2327f69824fde872c4b45b909f..194e9bbe5b58fbb21a33c9e6cc1517b52c30b57a 100644 (file)
@@ -63,6 +63,12 @@ private:
 
   static inline void init() {}
   static inline void shutdown() {}
+  /** rewrite equality
+   *
+   * This method returns a formula that is equivalent to the equality between
+   * two strings, given by node.
+   */
+  static Node rewriteEquality(Node node);
   /** rewrite concat
   * This is the entry point for post-rewriting terms node of the form
   *   str.++( t1, .., tn )