Merge branch '1.4.x'
[cvc5.git] / src / theory / strings / regexp_operation.cpp
index 092c7e16643fcdf126ef39bb46b88eb2c3e91846..e769eb712dfae991246eb4a17cdf5a407899132c 100644 (file)
@@ -1,12 +1,11 @@
 /*********************                                                        */
-
 /*! \file regexp_operation.cpp
  ** \verbatim
  ** Original author: Tianyi Liang
- ** Major contributors: none
+ ** Major contributors: Morgan Deters
  ** Minor contributors (to current version): none
  ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2013  New York University and The University of Iowa
+ ** Copyright (c) 2009-2014  New York University and The University of Iowa
  ** See the file COPYING in the top-level source directory for licensing
  ** information.\endverbatim
  **
@@ -1184,8 +1183,19 @@ void RegExpOpr::getCharSet( Node r, std::set<unsigned> &pcset, SetNodes &pvset )
   }
 }
 
+bool RegExpOpr::isPairNodesInSet(std::set< PairNodes > &s, Node n1, Node n2) {
+  for(std::set< PairNodes >::const_iterator itr = s.begin();
+      itr != s.end(); ++itr) {
+    if(itr->first == n1 && itr->second == n2 ||
+       itr->first == n2 && itr->second == n1) {
+      return true;
+    }
+  }
+  return false;
+}
 
 Node RegExpOpr::intersectInternal( Node r1, Node r2, std::map< unsigned, std::set< PairNodes > > cache, bool &spflag ) {
+  Trace("regexp-intersect") << "Starting INTERSECT:\n  "<< mkString(r1) << ",\n  " << mkString(r2) << std::endl;
   if(spflag) {
     //TODO: var
     return Node::null();
@@ -1231,11 +1241,18 @@ Node RegExpOpr::intersectInternal( Node r1, Node r2, std::map< unsigned, std::se
             spflag = true;
           }
         }
+        if(Trace.isOn("regexp-debug")) {
+          Trace("regexp-debug") << "Try CSET( " << cset.size() << " ) = ";
+          for(std::set<unsigned>::const_iterator itr = cset.begin();
+            itr != cset.end(); itr++) {
+            Trace("regexp-debug") << *itr << ", ";
+          }
+          Trace("regexp-debug") << std::endl;
+        }
         for(std::set<unsigned>::const_iterator itr = cset.begin();
           itr != cset.end(); itr++) {
           CVC4::String c( CVC4::String::convertUnsignedIntToChar(*itr) );
-          std::pair< Node, Node > p(r1, r2);
-          if(cache[ *itr ].find(p) == cache[ *itr ].end()) {
+          if(!isPairNodesInSet(cache[ *itr ], r1, r2)) {
             Node r1l = derivativeSingle(r1, c);
             Node r2l = derivativeSingle(r2, c);
             std::map< unsigned, std::set< PairNodes > > cache2(cache);
@@ -1264,10 +1281,208 @@ Node RegExpOpr::intersectInternal( Node r1, Node r2, std::map< unsigned, std::se
   Trace("regexp-intersect") << "INTERSECT( " << mkString(r1) << ", " << mkString(r2) << " ) = " << mkString(rNode) << std::endl;
   return rNode;
 }
+
+bool RegExpOpr::containC2(unsigned cnt, Node n) {
+  if(n.getKind() == kind::REGEXP_RV) {
+    unsigned y = n[0].getConst<Rational>().getNumerator().toUnsignedInt();
+    return cnt == y;
+  } else if(n.getKind() == kind::REGEXP_CONCAT) {
+    for( unsigned i=0; i<n.getNumChildren(); i++ ) {
+      if(containC2(cnt, n[i])) {
+        return true;
+      }
+    }
+  } else if(n.getKind() == kind::REGEXP_STAR) {
+    return containC2(cnt, n[0]);
+  } else if(n.getKind() == kind::REGEXP_UNION) {
+    for( unsigned i=0; i<n.getNumChildren(); i++ ) {
+      if(containC2(cnt, n[i])) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+Node RegExpOpr::convert1(unsigned cnt, Node n) {
+  Trace("regexp-debug") << "Converting " << n << " at " << cnt << "... " << std::endl;
+  Node r1, r2;
+  convert2(cnt, n, r1, r2);
+  Trace("regexp-debug") << "... getting r1=" << r1 << ", and r2=" << r2 << std::endl;
+  Node ret = NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, 
+     NodeManager::currentNM()->mkNode(kind::REGEXP_STAR, r1), r2);
+  ret = Rewriter::rewrite( ret );
+  Trace("regexp-debug") << "... done convert at " << cnt << ", with return " << ret << std::endl;
+  return ret;
+}
+void RegExpOpr::convert2(unsigned cnt, Node n, Node &r1, Node &r2) {
+  if(n == d_emptyRegexp) {
+    r1 = d_emptyRegexp;
+    r2 = d_emptyRegexp;
+  } else if(n == d_emptySingleton) {
+    r1 = d_emptySingleton;
+    r2 = d_emptySingleton;
+  } else if(n.getKind() == kind::REGEXP_RV) {
+    unsigned y = n[0].getConst<Rational>().getNumerator().toUnsignedInt();
+    r1 = d_emptySingleton;
+    if(cnt == y) {
+      r2 = d_emptyRegexp;
+    } else {
+      r2 = n;
+    }
+  } else if(n.getKind() == kind::REGEXP_CONCAT) {
+    //TODO
+    //convert2 x (r@(Seq l r1))
+    //   | contains x r1 = let (r2,r3) = convert2 x r1
+    //                     in (Seq l r2, r3)
+    //   | otherwise = (Empty, r)
+    bool flag = true;
+    std::vector<Node> vr1, vr2;
+    for( unsigned i=0; i<n.getNumChildren(); i++ ) {
+      if(containC2(cnt, n[i])) {
+        Node t1, t2;
+        convert2(cnt, n[i], t1, t2);
+        vr1.push_back(t1);
+        r1 = vr1.size()==0 ? d_emptyRegexp : vr1.size()==1 ? vr1[0] :
+             NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, vr1);
+        vr2.push_back(t2);
+        for( unsigned j=i+1; j<n.getNumChildren(); j++ ) {
+          vr2.push_back(n[j]);
+        }
+        r2 = vr2.size()==0 ? d_emptyRegexp : vr2.size()==1 ? vr2[0] :
+             NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT, vr2);
+        flag = false;
+        break;
+      } else {
+        vr1.push_back(n[i]);
+      }
+    }
+    if(flag) {
+      r1 = d_emptySingleton;
+      r2 = n;
+    }
+  } else if(n.getKind() == kind::REGEXP_UNION) {
+    std::vector<Node> vr1, vr2;
+    for( unsigned i=0; i<n.getNumChildren(); i++ ) {
+      Node t1, t2;
+      convert2(cnt, n[i], t1, t2);
+      vr1.push_back(t1);
+      vr2.push_back(t2);
+    }
+    r1 = NodeManager::currentNM()->mkNode(kind::REGEXP_UNION, vr1);
+    r2 = NodeManager::currentNM()->mkNode(kind::REGEXP_UNION, vr2);
+  } else if(n.getKind() == kind::STRING_TO_REGEXP) {
+      r1 = d_emptySingleton;
+      r2 = n;
+  } else {
+    //is it possible?
+  }
+}
+Node RegExpOpr::intersectInternal2( Node r1, Node r2, std::map< PairNodes, Node > cache, bool &spflag, unsigned cnt ) {
+  Trace("regexp-intersect") << "Starting INTERSECT:\n  "<< mkString(r1) << ",\n  " << mkString(r2) << std::endl;
+  //if(Trace.isOn("regexp-debug")) {
+  //  Trace("regexp-debug") << "... with cache:\n";
+  //  for(std::map< PairNodes, Node >::const_iterator itr=cache.begin();
+  //      itr!=cache.end();itr++) {
+  //        Trace("regexp-debug") << "(" << itr->first.first << "," << itr->first.second << ")->" << itr->second << std::endl;
+  //      }
+  //}
+  if(spflag) {
+    //TODO: var
+    return Node::null();
+  }
+  std::pair < Node, Node > p(r1, r2);
+  std::map < std::pair< Node, Node >, Node >::const_iterator itr = d_inter_cache.find(p);
+  Node rNode;
+  if(itr != d_inter_cache.end()) {
+    rNode = itr->second;
+  } else {
+    if(r1 == d_emptyRegexp || r2 == d_emptyRegexp) {
+      rNode = d_emptyRegexp;
+    } else if(r1 == d_emptySingleton || r2 == d_emptySingleton) {
+      Node exp;
+      int r = delta((r1 == d_emptySingleton ? r2 : r1), exp);
+      if(r == 0) {
+        //TODO: variable
+        spflag = true;
+      } else if(r == 1) {
+        rNode = d_emptySingleton;
+      } else {
+        rNode = d_emptyRegexp;
+      }
+    } else if(r1 == r2) {
+      rNode = convert1(cnt, r1);
+    } else {
+      PairNodes p(r1, r2);
+      std::map< PairNodes, Node >::const_iterator itrcache = cache.find(p);
+      if(itrcache != cache.end()) {
+        rNode = itrcache->second;
+      } else {
+        if(checkConstRegExp(r1) && checkConstRegExp(r2)) {
+          std::vector< unsigned > cset;
+          std::set< unsigned > cset1, cset2;
+          std::set< Node > vset1, vset2;
+          firstChars(r1, cset1, vset1);
+          firstChars(r2, cset2, vset2);
+          std::set_intersection(cset1.begin(), cset1.end(), cset2.begin(), cset1.end(),
+               std::inserter(cset, cset.begin()));
+          std::vector< Node > vec_nodes;
+          Node delta_exp;
+          int flag = delta(r1, delta_exp);
+          int flag2 = delta(r2, delta_exp);
+          if(flag != 2 && flag2 != 2) {
+            if(flag == 1 && flag2 == 1) {
+              vec_nodes.push_back(d_emptySingleton);
+            } else {
+              //TODO
+              spflag = true;
+            }
+          }
+          if(Trace.isOn("regexp-debug")) {
+            Trace("regexp-debug") << "Try CSET( " << cset.size() << " ) = ";
+            for(std::vector<unsigned>::const_iterator itr = cset.begin();
+              itr != cset.end(); itr++) {
+              CVC4::String c( CVC4::String::convertUnsignedIntToChar(*itr) );
+              Trace("regexp-debug") << c << ", ";
+            }
+            Trace("regexp-debug") << std::endl;
+          }
+          for(std::vector<unsigned>::const_iterator itr = cset.begin();
+            itr != cset.end(); itr++) {
+            CVC4::String c( CVC4::String::convertUnsignedIntToChar(*itr) );
+            Node r1l = derivativeSingle(r1, c);
+            Node r2l = derivativeSingle(r2, c);
+            std::map< PairNodes, Node > cache2(cache);
+            PairNodes p(r1, r2);
+            cache2[ p ] = NodeManager::currentNM()->mkNode(kind::REGEXP_RV, NodeManager::currentNM()->mkConst(CVC4::Rational(cnt)));
+            Node rt = intersectInternal2(r1l, r2l, cache2, spflag, cnt+1);
+            rt = convert1(cnt, rt);
+            if(spflag) {
+              //TODO:
+              return Node::null();
+            }
+            rt = Rewriter::rewrite( NodeManager::currentNM()->mkNode(kind::REGEXP_CONCAT,
+              NodeManager::currentNM()->mkNode(kind::STRING_TO_REGEXP, NodeManager::currentNM()->mkConst(c)), rt) );
+            vec_nodes.push_back(rt);
+          }
+          rNode = vec_nodes.size()==0 ? d_emptyRegexp : vec_nodes.size()==1 ? vec_nodes[0] :
+              NodeManager::currentNM()->mkNode(kind::REGEXP_UNION, vec_nodes);
+          rNode = Rewriter::rewrite( rNode );
+        } else {
+          //TODO: non-empty var set
+          spflag = true;
+        }
+      }
+    }
+    d_inter_cache[p] = rNode;
+  }
+  Trace("regexp-intersect") << "End of INTERSECT( " << mkString(r1) << ", " << mkString(r2) << " ) = " << mkString(rNode) << std::endl;
+  return rNode;
+}
 Node RegExpOpr::intersect(Node r1, Node r2, bool &spflag) {
-  std::map< unsigned, std::set< PairNodes > > cache;
+  //std::map< unsigned, std::set< PairNodes > > cache;
+  std::map< PairNodes, Node > cache;
   if(checkConstRegExp(r1) && checkConstRegExp(r2)) {
-    return intersectInternal(r1, r2, cache, spflag);
+    return intersectInternal2(r1, r2, cache, spflag, 1);
   } else {
     spflag = true;
     return Node::null();
@@ -1517,6 +1732,12 @@ std::string RegExpOpr::mkString( Node r ) {
         retStr += "]";
         break;
       }
+      case kind::REGEXP_RV: {
+        retStr += "<";
+        retStr += r[0].getConst<Rational>().getNumerator().toString();
+        retStr += ">";
+        break;
+      }
       default:
         Trace("strings-error") << "Unsupported term: " << r << " in RegExp." << std::endl;
         //Assert( false );