patch for regular expression intersection caching
authorTianyi Liang <tianyi-liang@uiowa.edu>
Fri, 25 Jul 2014 18:43:44 +0000 (13:43 -0500)
committerTianyi Liang <tianyi-liang@uiowa.edu>
Fri, 25 Jul 2014 18:43:44 +0000 (13:43 -0500)
src/theory/strings/kinds
src/theory/strings/regexp_operation.cpp
src/theory/strings/regexp_operation.h
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_type_rules.h
test/regress/regress0/strings/Makefile.am

index 4266c02f5b132ced41db6d2d461e7d63c80e4ee7..0f68d120765c9b7674c5116d1e519392aa4a989b 100644 (file)
@@ -88,6 +88,11 @@ operator REGEXP_LOOP 2:3 "regexp loop"
 operator REGEXP_EMPTY 0 "regexp empty"
 operator REGEXP_SIGMA 0 "regexp all characters"
 
+#internal
+operator REGEXP_RV 1 "regexp rv (internal use only)"
+typerule REGEXP_RV ::CVC4::theory::strings::RegExpRVTypeRule
+
+#typerules
 typerule REGEXP_CONCAT ::CVC4::theory::strings::RegExpConcatTypeRule
 typerule REGEXP_UNION ::CVC4::theory::strings::RegExpUnionTypeRule
 typerule REGEXP_INTER ::CVC4::theory::strings::RegExpInterTypeRule
index 369278994bab3ec09d2ab34fd9952e93826681e8..20fbf28702400ad3c36830444d140999533c7ecc 100644 (file)
@@ -1183,8 +1183,19 @@ void RegExpOpr::getCharSet( Node r, std::set<unsigned> &pcset, SetNodes &pvset )
   }
 }
 
+bool RegExpOpr::isPairNodesInSet(std::set< PairNodes > &s, Node n1, Node n2) {
+  for(std::set< PairNodes >::const_iterator itr = s.begin();
+      itr != s.end(); ++itr) {
+    if(itr->first == n1 && itr->second == n2 ||
+       itr->first == n2 && itr->second == n1) {
+      return true;
+    }
+  }
+  return false;
+}
 
 Node RegExpOpr::intersectInternal( Node r1, Node r2, std::map< unsigned, std::set< PairNodes > > cache, bool &spflag ) {
+  Trace("regexp-intersect") << "Starting INTERSECT:\n  "<< mkString(r1) << ",\n  " << mkString(r2) << std::endl;
   if(spflag) {
     //TODO: var
     return Node::null();
@@ -1230,11 +1241,18 @@ Node RegExpOpr::intersectInternal( Node r1, Node r2, std::map< unsigned, std::se
             spflag = true;
           }
         }
+        if(Trace.isOn("regexp-debug")) {
+          Trace("regexp-debug") << "Try CSET( " << cset.size() << " ) = ";
+          for(std::set<unsigned>::const_iterator itr = cset.begin();
+            itr != cset.end(); itr++) {
+            Trace("regexp-debug") << *itr << ", ";
+          }
+          Trace("regexp-debug") << std::endl;
+        }
         for(std::set<unsigned>::const_iterator itr = cset.begin();
           itr != cset.end(); itr++) {
           CVC4::String c( CVC4::String::convertUnsignedIntToChar(*itr) );
-          std::pair< Node, Node > p(r1, r2);
-          if(cache[ *itr ].find(p) == cache[ *itr ].end()) {
+          if(!isPairNodesInSet(cache[ *itr ], r1, r2)) {
             Node r1l = derivativeSingle(r1, c);
             Node r2l = derivativeSingle(r2, c);
             std::map< unsigned, std::set< PairNodes > > cache2(cache);
@@ -1263,10 +1281,201 @@ Node RegExpOpr::intersectInternal( Node r1, Node r2, std::map< unsigned, std::se
   Trace("regexp-intersect") << "INTERSECT( " << mkString(r1) << ", " << mkString(r2) << " ) = " << mkString(rNode) << std::endl;
   return rNode;
 }
+
+bool RegExpOpr::containC2(unsigned cnt, Node n) {
+  if(n.getKind() == kind::REGEXP_RV) {
+    unsigned y = n[0].getConst<Rational>().getNumerator().toUnsignedInt();
+    return cnt == y;
+  } else if(n.getKind() == kind::REGEXP_CONCAT) {
+    for( unsigned i=0; i<n.getNumChildren(); i++ ) {
+      if(containC2(cnt, n[i])) {
+        return true;
+      }
+    }
+  } else if(n.getKind() == kind::REGEXP_STAR) {
+    return containC2(cnt, n[0]);
+  } else if(n.getKind() == kind::REGEXP_UNION) {
+    for( unsigned i=0; i<n.getNumChildren(); i++ ) {
+      if(containC2(cnt, n[i])) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+Node RegExpOpr::convert1(unsigned cnt, Node n) {
+  Trace("regexp-debug") << "Converting " << n << " ... " << std::endl;
+  Node r1, r2;
+  convert2(cnt, n, r1, r2);
+  Trace("regexp-debug") << "... getting r1=" << r1 << ", and r2=" << r2 << std::endl;
+  Node ret = NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, 
+     NodeManager::currentNM()->mkNode(kind::REGEXP_STAR, r1), r2);
+  ret = Rewriter::rewrite( ret );
+  Trace("regexp-debug") << "... done convert, with return " << ret << std::endl;
+  return ret;
+}
+void RegExpOpr::convert2(unsigned cnt, Node n, Node &r1, Node &r2) {
+  if(n == d_emptyRegexp) {
+    r1 = d_emptyRegexp;
+    r2 = d_emptyRegexp;
+  } else if(n == d_emptySingleton) {
+    r1 = d_emptySingleton;
+    r2 = d_emptySingleton;
+  } else if(n.getKind() == kind::REGEXP_RV) {
+    unsigned y = n[0].getConst<Rational>().getNumerator().toUnsignedInt();
+    r1 = d_emptySingleton;
+    if(cnt == y) {
+      r2 = d_emptyRegexp;
+    } else {
+      r2 = n;
+    }
+  } else if(n.getKind() == kind::REGEXP_CONCAT) {
+    //TODO
+    //convert2 x (r@(Seq l r1))
+    //   | contains x r1 = let (r2,r3) = convert2 x r1
+    //                     in (Seq l r2, r3)
+    //   | otherwise = (Empty, r)
+    bool flag = true;
+    std::vector<Node> vr1, vr2;
+    for( unsigned i=0; i<n.getNumChildren(); i++ ) {
+      if(containC2(cnt, n[i])) {
+        Node t1, t2;
+        convert2(cnt, n[i], t1, t2);
+        vr1.push_back(t1);
+        r1 = vr1.size()==0 ? d_emptyRegexp : vr1.size()==1 ? vr1[0] :
+             NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, vr1);
+        vr2.push_back(t2);
+        for( unsigned j=i+1; j<n.getNumChildren(); j++ ) {
+          vr2.push_back(n[j]);
+        }
+        r2 = vr2.size()==0 ? d_emptyRegexp : vr2.size()==1 ? vr2[0] :
+             NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, vr2);
+        flag = false;
+        break;
+      } else {
+        vr1.push_back(n[i]);
+      }
+    }
+    if(flag) {
+      r1 = d_emptySingleton;
+      r2 = n;
+    }
+  } else if(n.getKind() == kind::REGEXP_UNION) {
+    std::vector<Node> vr1, vr2;
+    for( unsigned i=0; i<n.getNumChildren(); i++ ) {
+      Node t1, t2;
+      convert2(cnt, n[i], t1, t2);
+      vr1.push_back(t1);
+      vr2.push_back(t2);
+    }
+    r1 = NodeManager::currentNM()->mkNode(kind::REGEXP_UNION, vr1);
+    r2 = NodeManager::currentNM()->mkNode(kind::REGEXP_UNION, vr2);
+  } else if(n.getKind() == kind::STRING_TO_REGEXP) {
+      r1 = d_emptySingleton;
+      r2 = n;
+  } else {
+    //is it possible?
+  }
+}
+Node RegExpOpr::intersectInternal2( Node r1, Node r2, std::map< PairNodes, Node > cache, bool &spflag, unsigned cnt ) {
+  Trace("regexp-intersect") << "Starting INTERSECT:\n  "<< mkString(r1) << ",\n  " << mkString(r2) << std::endl;
+  if(spflag) {
+    //TODO: var
+    return Node::null();
+  }
+  std::pair < Node, Node > p(r1, r2);
+  std::map < std::pair< Node, Node >, Node >::const_iterator itr = d_inter_cache.find(p);
+  Node rNode;
+  if(itr != d_inter_cache.end()) {
+    rNode = itr->second;
+  } else {
+    if(r1 == d_emptyRegexp || r2 == d_emptyRegexp) {
+      rNode = d_emptyRegexp;
+    } else if(r1 == d_emptySingleton || r2 == d_emptySingleton) {
+      Node exp;
+      int r = delta((r1 == d_emptySingleton ? r2 : r1), exp);
+      if(r == 0) {
+        //TODO: variable
+        spflag = true;
+      } else if(r == 1) {
+        rNode = d_emptySingleton;
+      } else {
+        rNode = d_emptyRegexp;
+      }
+    } else if(r1 == r2) {
+      rNode = convert1(cnt, r1);
+    } else {
+      PairNodes p(r1, r2);
+      std::map< PairNodes, Node >::const_iterator itrcache = cache.find(p);
+      if(itrcache != cache.end()) {
+        rNode = convert1(cnt, itrcache->second);
+      } else {
+        if(checkConstRegExp(r1) && checkConstRegExp(r2)) {
+          std::vector< unsigned > cset;
+          std::set< unsigned > cset1, cset2;
+          std::set< Node > vset1, vset2;
+          firstChars(r1, cset1, vset1);
+          firstChars(r2, cset2, vset2);
+          std::set_intersection(cset1.begin(), cset1.end(), cset2.begin(), cset1.end(),
+               std::inserter(cset, cset.begin()));
+          std::vector< Node > vec_nodes;
+          Node delta_exp;
+          int flag = delta(r1, delta_exp);
+          int flag2 = delta(r2, delta_exp);
+          if(flag != 2 && flag2 != 2) {
+            if(flag == 1 && flag2 == 1) {
+              vec_nodes.push_back(d_emptySingleton);
+            } else {
+              //TODO
+              spflag = true;
+            }
+          }
+          if(Trace.isOn("regexp-debug")) {
+            Trace("regexp-debug") << "Try CSET( " << cset.size() << " ) = ";
+            for(std::vector<unsigned>::const_iterator itr = cset.begin();
+              itr != cset.end(); itr++) {
+              CVC4::String c( CVC4::String::convertUnsignedIntToChar(*itr) );
+              Trace("regexp-debug") << c << ", ";
+            }
+            Trace("regexp-debug") << std::endl;
+          }
+          for(std::vector<unsigned>::const_iterator itr = cset.begin();
+            itr != cset.end(); itr++) {
+            CVC4::String c( CVC4::String::convertUnsignedIntToChar(*itr) );
+            Node r1l = derivativeSingle(r1, c);
+            Node r2l = derivativeSingle(r2, c);
+            std::map< PairNodes, Node > cache2(cache);
+            PairNodes p(r1l, r2l);
+            cache2[ p ] = NodeManager::currentNM()->mkNode(kind::REGEXP_RV, NodeManager::currentNM()->mkConst(CVC4::Rational(cnt)));
+            Node rt = intersectInternal2(r1l, r2l, cache2, spflag, cnt+1);
+            rt = convert1(cnt, rt);
+            if(spflag) {
+              //TODO:
+              return Node::null();
+            }
+            rt = Rewriter::rewrite( NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT,
+              NodeManager::currentNM()->mkNode(kind::STRING_TO_REGEXP, NodeManager::currentNM()->mkConst(c)), rt) );
+            vec_nodes.push_back(rt);
+          }
+          rNode = vec_nodes.size()==0 ? d_emptyRegexp : vec_nodes.size()==1 ? vec_nodes[0] :
+              NodeManager::currentNM()->mkNode(kind::REGEXP_UNION, vec_nodes);
+          rNode = Rewriter::rewrite( rNode );
+        } else {
+          //TODO: non-empty var set
+          spflag = true;
+        }
+      }
+    }
+    d_inter_cache[p] = rNode;
+  }
+  Trace("regexp-intersect") << "End of INTERSECT( " << mkString(r1) << ", " << mkString(r2) << " ) = " << mkString(rNode) << std::endl;
+  return rNode;
+}
 Node RegExpOpr::intersect(Node r1, Node r2, bool &spflag) {
-  std::map< unsigned, std::set< PairNodes > > cache;
+  //std::map< unsigned, std::set< PairNodes > > cache;
+  std::map< PairNodes, Node > cache;
   if(checkConstRegExp(r1) && checkConstRegExp(r2)) {
-    return intersectInternal(r1, r2, cache, spflag);
+    return intersectInternal2(r1, r2, cache, spflag, 1);
   } else {
     spflag = true;
     return Node::null();
index e4ae1208d96fd8891fca515dc4546153d31a2c5c..2ae578cd61b2338672f5e6415385e36c317c80cd 100644 (file)
@@ -69,9 +69,14 @@ private:
   std::string niceChar( Node r );
   int gcd ( int a, int b );
   Node mkAllExceptOne( char c );
+  bool isPairNodesInSet(std::set< PairNodes > &s, Node n1, Node n2);
 
   void getCharSet( Node r, std::set<unsigned> &pcset, SetNodes &pvset );
   Node intersectInternal( Node r1, Node r2, std::map< unsigned, std::set< PairNodes > > cache, bool &spflag );
+  bool containC2(unsigned cnt, Node n);
+  Node convert1(unsigned cnt, Node n);
+  void convert2(unsigned cnt, Node n, Node &r1, Node &r2);
+  Node intersectInternal2( Node r1, Node r2, std::map< PairNodes, Node > cache, bool &spflag, unsigned cnt );
   void firstChars( Node r, std::set<unsigned> &pcset, SetNodes &pvset );
 
   //TODO: for intersection
index 2856ce1e09159bcce72992b9d7a477771454848c..30e52966302072b28b40de81188d15b46e11e5f2 100644 (file)
@@ -2500,6 +2500,7 @@ bool TheoryStrings::checkMemberships() {
   std::vector< Node > processed;
   std::vector< Node > cprocessed;
 
+  Trace("regexp-debug") << "Checking Memberships ... " << std::endl;
   //if(options::stringEIT()) {
     //TODO: Opt for normal forms
     for(NodeListMap::const_iterator itr_xr = d_str_re_map.begin();
@@ -2507,6 +2508,7 @@ bool TheoryStrings::checkMemberships() {
       bool spflag = false;
       Node x = (*itr_xr).first;
       NodeList* lst = (*itr_xr).second;
+      Trace("regexp-debug") << "Checking Memberships for " << x << std::endl;
       if(d_inter_index.find(x) == d_inter_index.end()) {
         d_inter_index[x] = 0;
       }
@@ -2515,6 +2517,7 @@ bool TheoryStrings::checkMemberships() {
         if(lst->size() == 1) {
           d_inter_cache[x] = (*lst)[0];
           d_inter_index[x] = 1;
+          Trace("regexp-debug") << "... only one choice " << std::endl;
         } else if(lst->size() > 1) {
           Node r;
           if(d_inter_cache.find(x) != d_inter_cache.end()) {
@@ -2528,6 +2531,7 @@ bool TheoryStrings::checkMemberships() {
           for(int i=0; i<cur_inter_idx; i++) {
             ++itr_lst;
           }
+          Trace("regexp-debug") << "... staring from : " << cur_inter_idx << ", we have " << lst->size() << std::endl;
           for(;itr_lst != lst->end(); ++itr_lst) {
             Node r2 = *itr_lst;
             r = d_regexp_opr.intersect(r, r2, spflag);
@@ -2561,6 +2565,7 @@ bool TheoryStrings::checkMemberships() {
     }
   //}
 
+  Trace("regexp-debug") << "... No Intersec Conflict in Memberships " << std::endl;
   if(!addedLemma) {
     for( unsigned i=0; i<d_regexp_memberships.size(); i++ ) {
       //check regular expression membership
index 12ff92b5e3c1784296e8d7d719a8b9d83c9dc2aa..1014a95db89c0db87f617b3a2b96bbb4a83d06d9 100644 (file)
@@ -302,7 +302,7 @@ Node TheoryStringsRewriter::rewriteMembership(TNode node) {
 
   if(node[1].getKind() == kind::REGEXP_EMPTY) {
     retNode = NodeManager::currentNM()->mkConst( false );
-  } else if( x.getKind() == kind::CONST_STRING && checkConstRegExp(node[1]) ) {
+  } else if(x.getKind()==kind::CONST_STRING && checkConstRegExp(node[1])) {
     //test whether x in node[1]
     CVC4::String s = x.getConst<String>();
     retNode = NodeManager::currentNM()->mkConst( testConstStringInRegExp( s, 0, node[1] ) );
@@ -311,10 +311,12 @@ Node TheoryStringsRewriter::rewriteMembership(TNode node) {
     retNode = one.eqNode(NodeManager::currentNM()->mkNode(kind::STRING_LENGTH, x));
   } else if(node[1].getKind() == kind::REGEXP_STAR && node[1][0].getKind() == kind::REGEXP_SIGMA) {
     retNode = NodeManager::currentNM()->mkConst( true );
-  } else if( x != node[0] ) {
+  } else if(node[1].getKind() == kind::STRING_TO_REGEXP) {
+    retNode = x.eqNode(node[1][0]);
+  } else if(x != node[0]) {
     retNode = NodeManager::currentNM()->mkNode( kind::STRING_IN_REGEXP, x, node[1] );
   }
-    return retNode;
+  return retNode;
 }
 
 RewriteResponse TheoryStringsRewriter::postRewrite(TNode node) {
index 6d1bb1c98ddfcaaf2afa7000ac809f73621fe865..8a51ea36ca85c79c56a8769dff3a4a0c08539c55 100644 (file)
@@ -466,6 +466,21 @@ public:
   }
 };
 
+class RegExpRVTypeRule {
+public:
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check)
+      throw (TypeCheckingExceptionPrivate, AssertionException) {
+    if( check ) {
+      TypeNode t = n[0].getType(check);
+      if (!t.isInteger()) {
+        throw TypeCheckingExceptionPrivate(n, "expecting an integer term in RV");
+      }
+    }
+    return nodeManager->regexpType();
+  }
+};
+
+
 }/* CVC4::theory::strings namespace */
 }/* CVC4::theory namespace */
 }/* CVC4 namespace */
index ddc0eae7c10a29e805ff5c4a8ddf24fe3c36b920..233962d723bda682d501a51c256ab82ebac159e9 100644 (file)
@@ -30,7 +30,6 @@ TESTS =       \
   str005.smt2 \
   str006.smt2 \
   str007.smt2 \
-  fmf001.smt2 \
   fmf002.smt2 \
   type001.smt2 \
   type003.smt2 \
@@ -53,6 +52,7 @@ TESTS =       \
 FAILING_TESTS =
 
 EXTRA_DIST = $(TESTS) \
+  fmf001.smt2 \
   regexp002.smt2 \
   type002.smt2