minor change: add a heuristic for preventing constant splitting.
authorTianyi Liang <tianyi-liang@uiowa.edu>
Thu, 24 Apr 2014 22:30:15 +0000 (17:30 -0500)
committerTianyi Liang <tianyi-liang@uiowa.edu>
Thu, 24 Apr 2014 22:30:15 +0000 (17:30 -0500)
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h
src/util/regexp.cpp
src/util/regexp.h

index f2172071de80a594319ac56274ed7626fc8ff558..d03faa72a11ff537f947d582e04c58a96d5120db 100644 (file)
@@ -65,7 +65,7 @@ TheoryStrings::TheoryStrings(context::Context* c, context::UserContext* u, Outpu
        d_curr_cardinality(c, 0)
 {
     // The kinds we are treating as function application in congruence
-    //d_equalityEngine.addFunctionKind(kind::STRING_IN_REGEXP);
+    d_equalityEngine.addFunctionKind(kind::STRING_IN_REGEXP);
     d_equalityEngine.addFunctionKind(kind::STRING_LENGTH);
     d_equalityEngine.addFunctionKind(kind::STRING_CONCAT);
     d_equalityEngine.addFunctionKind(kind::STRING_STRCTN);
@@ -418,7 +418,7 @@ void TheoryStrings::preRegisterTerm(TNode n) {
                break;
        case kind::STRING_IN_REGEXP:
                //do not add trigger here
-               //d_equalityEngine.addTriggerPredicate(n);
+               d_equalityEngine.addTriggerPredicate(n);
                break;
        case kind::STRING_SUBSTR_TOTAL: {
                Node lenxgti = NodeManager::currentNM()->mkNode( kind::GEQ, 
@@ -560,7 +560,7 @@ void TheoryStrings::check(Effort e) {
        //must record string in regular expressions
        if ( atom.getKind() == kind::STRING_IN_REGEXP ) {
                addMembership(assertion);
-               //d_equalityEngine.assertPredicate(atom, polarity, fact);
+               d_equalityEngine.assertPredicate(atom, polarity, fact);
        } else if (atom.getKind() == kind::STRING_STRCTN) {
                if(polarity) {
                        d_str_pos_ctn.push_back( atom );
@@ -1232,18 +1232,37 @@ bool TheoryStrings::processNEqc(std::vector< std::vector< Node > > &normal_forms
                                                                        Assert( other_str.getKind()!=kind::CONST_STRING, "Other string is not constant." );
                                                                        Assert( other_str.getKind()!=kind::STRING_CONCAT, "Other string is not CONCAT." );
                                                                        antec.insert(antec.end(), curr_exp.begin(), curr_exp.end() );
-                                                                       Node firstChar = const_str.getConst<String>().size() == 1 ? const_str :
-                                                                               NodeManager::currentNM()->mkConst( const_str.getConst<String>().substr(0, 1) );
-                                                                       //split the string
-                                                                       Node eq1 = Rewriter::rewrite( other_str.eqNode( d_emptyString ) );
-                                                                       Node eq2 = mkSplitEq( "c_spt_$$", "created for v/c split", other_str, firstChar, false );
-                                                                       d_pending_req_phase[ eq1 ] = true;
-                                                                       conc = NodeManager::currentNM()->mkNode( kind::OR, eq1, eq2 );
-                                                                       Trace("strings-solve-debug") << "Break normal form constant/variable " << std::endl;
-
-                                                                       Node ant = mkExplain( antec );
-                                                                       sendLemma( ant, conc, "CST-SPLIT" );
-                                                                       ++(d_statistics.d_eq_splits);
+                                                                       //Opt
+                                                                       bool optflag = false;
+                                                                       if( normal_forms[nconst_k].size() > nconst_index_k + 1 &&
+                                                                               normal_forms[nconst_k][nconst_index_k + 1].isConst() ) {
+                                                                               CVC4::String stra = const_str.getConst<String>();
+                                                                               CVC4::String strb = normal_forms[nconst_k][nconst_index_k + 1].getConst<String>();
+                                                                               CVC4::String fc = strb.substr(0, 1);
+                                                                               if( stra.find(fc) == std::string::npos ||
+                                                                                       (stra.find(strb) == std::string::npos &&
+                                                                                        !stra.overlap(strb)) ) {
+                                                                                       Node sk = NodeManager::currentNM()->mkSkolem( "sopt_$$", NodeManager::currentNM()->stringType(), "created for string sp" );
+                                                                                       Node eq = other_str.eqNode( mkConcat(const_str, sk) );
+                                                                                       Node ant = mkExplain( antec );
+                                                                                       sendLemma(ant, eq, "CST-EPS");
+                                                                                       optflag = true;
+                                                                               }
+                                                                       }
+                                                                       if(!optflag){
+                                                                               Node firstChar = const_str.getConst<String>().size() == 1 ? const_str :
+                                                                                       NodeManager::currentNM()->mkConst( const_str.getConst<String>().substr(0, 1) );
+                                                                               //split the string
+                                                                               Node eq1 = Rewriter::rewrite( other_str.eqNode( d_emptyString ) );
+                                                                               Node eq2 = mkSplitEq( "c_spt_$$", "created for v/c split", other_str, firstChar, false );
+                                                                               d_pending_req_phase[ eq1 ] = true;
+                                                                               conc = NodeManager::currentNM()->mkNode( kind::OR, eq1, eq2 );
+                                                                               Trace("strings-solve-debug") << "Break normal form constant/variable " << std::endl;
+
+                                                                               Node ant = mkExplain( antec );
+                                                                               sendLemma( ant, conc, "CST-SPLIT" );
+                                                                               ++(d_statistics.d_eq_splits);
+                                                                       }
                                                                        return true;
                                                                } else {
                                                                        std::vector< Node > antec_new_lits;
@@ -1785,10 +1804,7 @@ void TheoryStrings::sendSplit( Node a, Node b, const char * c, bool preq ) {
 }
 
 Node TheoryStrings::mkConcat( Node n1, Node n2 ) {
-       std::vector< Node > c;
-       c.push_back( n1 );
-       c.push_back( n2 );
-       return mkConcat( c );
+       return NodeManager::currentNM()->mkNode( kind::STRING_CONCAT, n1, n2 );
 }
 
 Node TheoryStrings::mkConcat( std::vector< Node >& c ) {
@@ -2888,6 +2904,11 @@ void TheoryStrings::addMembership(Node assertion) {
        d_regexp_memberships.push_back( assertion );
 }
 
+Node TheoryStrings::instantiateSymRegExp(Node r) {
+       //TODO:
+       return r;
+}
+
 //// Finite Model Finding
 
 Node TheoryStrings::getNextDecisionRequest() {
index 9f99012df9b0e9e26e055b7c694825d3a1c4c5d4..33283d1cfb4e9e4b6a377fe37759ae948b21fea9 100644 (file)
@@ -272,8 +272,8 @@ protected:
        void sendInfer( Node eq_exp, Node eq, const char * c );
        void sendSplit( Node a, Node b, const char * c, bool preq = true );
        /** mkConcat **/
-       Node mkConcat( Node n1, Node n2 );
-       Node mkConcat( std::vector< Node >& c );
+       inline Node mkConcat( Node n1, Node n2 );
+       inline Node mkConcat( std::vector< Node >& c );
        /** mkExplain **/
        Node mkExplain( std::vector< Node >& a );
        Node mkExplain( std::vector< Node >& a, std::vector< Node >& an );
@@ -323,6 +323,7 @@ private:
        bool splitRegExp( Node x, Node r, Node ant );
        bool addMembershipLength(Node atom);
        void addMembership(Node assertion);
+       Node instantiateSymRegExp(Node r);
 
 
        // Finite Model Finding
index b6db624d5770d7f8dee8c4f50c02d186227922e6..3bc17b050ee17df32595b252455c946a1f89025b 100644 (file)
@@ -23,6 +23,71 @@ using namespace std;
 \r
 namespace CVC4 {\r
 \r
+void String::toInternal(const std::string &s) {\r
+  d_str.clear();\r
+  unsigned i=0;\r
+  while(i < s.size()) {\r
+         if(s[i] == '\\') {\r
+                 i++;\r
+                 if(i < s.size()) {\r
+                         switch(s[i]) {\r
+                                 case 'n':  {d_str.push_back( convertCharToUnsignedInt('\n') );i++;} break;\r
+                                 case 't':  {d_str.push_back( convertCharToUnsignedInt('\t') );i++;} break;\r
+                                 case 'v':  {d_str.push_back( convertCharToUnsignedInt('\v') );i++;} break;\r
+                                 case 'b':  {d_str.push_back( convertCharToUnsignedInt('\b') );i++;} break;\r
+                                 case 'r':  {d_str.push_back( convertCharToUnsignedInt('\r') );i++;} break;\r
+                                 case 'f':  {d_str.push_back( convertCharToUnsignedInt('\f') );i++;} break;\r
+                                 case 'a':  {d_str.push_back( convertCharToUnsignedInt('\a') );i++;} break;\r
+                                 case '\\': {d_str.push_back( convertCharToUnsignedInt('\\') );i++;} break;\r
+                                 case 'x': {\r
+                                         if(i + 2 < s.size()) {\r
+                                               if((isdigit(s[i+1]) || (s[i+1] >= 'a' && s[i+1] >= 'f') || (s[i+1] >= 'A' && s[i+1] >= 'F')) &&\r
+                                                  (isdigit(s[i+2]) || (s[i+2] >= 'a' && s[i+2] >= 'f') || (s[i+2] >= 'A' && s[i+2] >= 'F'))) {\r
+                                                       d_str.push_back( convertCharToUnsignedInt( hexToDec(s[i+1]) * 16 + hexToDec(s[i+2]) ) );\r
+                                                       i += 3;\r
+                                               } else {\r
+                                                       throw CVC4::Exception( "Error String Literal: \"" + s + "\"" );\r
+                                               }\r
+                                         } else {\r
+                                               throw CVC4::Exception( "Error String Literal: \"" + s + "\"" );\r
+                                         }\r
+                                 }\r
+                                 break;\r
+                                 default: {\r
+                                         if(isdigit(s[i])) {\r
+                                                 int num = (int)s[i] - (int)'0';\r
+                                                 bool flag = num < 4;\r
+                                                 if(i+1 < s.size() && num < 8 && isdigit(s[i+1]) && s[i+1] < '8') {\r
+                                                         num = num * 8 + (int)s[i+1] - (int)'0';\r
+                                                         if(flag && i+2 < s.size() && isdigit(s[i+2]) && s[i+2] < '8') {\r
+                                                                 num = num * 8 + (int)s[i+2] - (int)'0';\r
+                                                                 d_str.push_back( convertCharToUnsignedInt((char)num) );\r
+                                                                 i += 3;\r
+                                                         } else {\r
+                                                                 d_str.push_back( convertCharToUnsignedInt((char)num) );\r
+                                                                 i += 2;\r
+                                                         }\r
+                                                 } else {\r
+                                                         d_str.push_back( convertCharToUnsignedInt((char)num) );\r
+                                                         i++;\r
+                                                 }\r
+                                         } else {\r
+                                                 d_str.push_back( convertCharToUnsignedInt(s[i]) );\r
+                                                 i++;\r
+                                         }\r
+                                 }\r
+                         }\r
+                 } else {\r
+                         throw CVC4::Exception( "should be handled by lexer: \"" + s + "\"" );\r
+                         //d_str.push_back( convertCharToUnsignedInt('\\') );\r
+                 }\r
+         } else {\r
+                 d_str.push_back( convertCharToUnsignedInt(s[i]) );\r
+                 i++;\r
+         }\r
+  }\r
+}\r
+\r
 void String::getCharSet(std::set<unsigned int> &cset) const {\r
        for(std::vector<unsigned int>::const_iterator itr = d_str.begin();\r
                itr != d_str.end(); itr++) {\r
@@ -30,6 +95,21 @@ void String::getCharSet(std::set<unsigned int> &cset) const {
                }\r
 }\r
 \r
+bool String::overlap(String &y) const {\r
+       unsigned n = y.size();\r
+       if(d_str.size() < y.size()) {\r
+               n = d_str.size();\r
+       }\r
+       for(unsigned i=1; i<n; i++) {\r
+               String s = suffix(i);\r
+               String p = y.prefix(i);\r
+               if(s == p) {\r
+                       return true;\r
+               }\r
+       }\r
+       return false;\r
+}\r
+\r
 std::string String::toString() const {\r
        std::string str;\r
        for(unsigned int i=0; i<d_str.size(); ++i) {\r
index 8c4a3922d24a768eb8402ba7ae4dd7e65a00692f..2bb2b5c4cb720cbcf9b00c58900614b602f71ab6 100644 (file)
@@ -70,70 +70,7 @@ private:
          }
   }
 
-  void toInternal(const std::string &s) {
-         d_str.clear();
-         unsigned i=0;
-         while(i < s.size()) {
-                 if(s[i] == '\\') {
-                         i++;
-                         if(i < s.size()) {
-                                 switch(s[i]) {
-                                         case 'n':  {d_str.push_back( convertCharToUnsignedInt('\n') );i++;} break;
-                                         case 't':  {d_str.push_back( convertCharToUnsignedInt('\t') );i++;} break;
-                                         case 'v':  {d_str.push_back( convertCharToUnsignedInt('\v') );i++;} break;
-                                         case 'b':  {d_str.push_back( convertCharToUnsignedInt('\b') );i++;} break;
-                                         case 'r':  {d_str.push_back( convertCharToUnsignedInt('\r') );i++;} break;
-                                         case 'f':  {d_str.push_back( convertCharToUnsignedInt('\f') );i++;} break;
-                                         case 'a':  {d_str.push_back( convertCharToUnsignedInt('\a') );i++;} break;
-                                         case '\\': {d_str.push_back( convertCharToUnsignedInt('\\') );i++;} break;
-                                         case 'x': {
-                                                 if(i + 2 < s.size()) {
-                                                       if((isdigit(s[i+1]) || (s[i+1] >= 'a' && s[i+1] >= 'f') || (s[i+1] >= 'A' && s[i+1] >= 'F')) &&
-                                                          (isdigit(s[i+2]) || (s[i+2] >= 'a' && s[i+2] >= 'f') || (s[i+2] >= 'A' && s[i+2] >= 'F'))) {
-                                                               d_str.push_back( convertCharToUnsignedInt( hexToDec(s[i+1]) * 16 + hexToDec(s[i+2]) ) );
-                                                               i += 3;
-                                                       } else {
-                                                               throw CVC4::Exception( "Error String Literal: \"" + s + "\"" );
-                                                       }
-                                                 } else {
-                                                       throw CVC4::Exception( "Error String Literal: \"" + s + "\"" );
-                                                 }
-                                         }
-                                         break;
-                                         default: {
-                                                 if(isdigit(s[i])) {
-                                                         int num = (int)s[i] - (int)'0';
-                                                         bool flag = num < 4;
-                                                         if(i+1 < s.size() && num < 8 && isdigit(s[i+1]) && s[i+1] < '8') {
-                                                                 num = num * 8 + (int)s[i+1] - (int)'0';
-                                                                 if(flag && i+2 < s.size() && isdigit(s[i+2]) && s[i+2] < '8') {
-                                                                         num = num * 8 + (int)s[i+2] - (int)'0';
-                                                                         d_str.push_back( convertCharToUnsignedInt((char)num) );
-                                                                         i += 3;
-                                                                 } else {
-                                                                         d_str.push_back( convertCharToUnsignedInt((char)num) );
-                                                                         i += 2;
-                                                                 }
-                                                         } else {
-                                                                 d_str.push_back( convertCharToUnsignedInt((char)num) );
-                                                                 i++;
-                                                         }
-                                                 } else {
-                                                         d_str.push_back( convertCharToUnsignedInt(s[i]) );
-                                                         i++;
-                                                 }
-                                         }
-                                 }
-                         } else {
-                                 throw CVC4::Exception( "should be handled by lexer: \"" + s + "\"" );
-                                 //d_str.push_back( convertCharToUnsignedInt('\\') );
-                         }
-                 } else {
-                         d_str.push_back( convertCharToUnsignedInt(s[i]) );
-                         i++;
-                 }
-         }
-  }
+  void toInternal(const std::string &s);
 
 public:
   String() {}
@@ -316,6 +253,15 @@ public:
     ret_vec.insert( ret_vec.end(), itr, itr + j );
     return String(ret_vec);
   }
+
+  String prefix(unsigned i) const {
+         return substr(0, i);
+  }
+  String suffix(unsigned i) const {
+         return substr(d_str.size() - i, i);
+  }
+  bool overlap(String &y) const;
+
   bool isNumber() const {
         if(d_str.size() == 0) return false;
         for(unsigned int i=0; i<d_str.size(); ++i) {