rev diseq
authorTianyi Liang <tianyi-liang@uiowa.edu>
Fri, 24 Jan 2014 21:39:13 +0000 (15:39 -0600)
committerTianyi Liang <tianyi-liang@uiowa.edu>
Fri, 24 Jan 2014 21:39:59 +0000 (15:39 -0600)
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h

index d41359a882adf96c8e0f68ba419b297aad8874bf..effbbca2e24acad49da7b69327b191ba90b499e6 100644 (file)
@@ -995,7 +995,7 @@ bool TheoryStrings::processNEqc(std::vector< std::vector< Node > > &normal_forms
                                do
                                {
                                        //---------------------do simple stuff first
-                                       if( processSimpleNeq( normal_forms, normal_form_src, curr_exp, i, j, index_i, index_j, false ) ){
+                                       if( processSimpleNEq( normal_forms, normal_form_src, curr_exp, i, j, index_i, index_j, false ) ){
                                                //added a lemma, return
                                                return true;
                                        }
@@ -1140,7 +1140,7 @@ bool TheoryStrings::processReverseNEq( std::vector< std::vector< Node > > &norma
        t_curr_exp.insert( t_curr_exp.begin(), curr_exp.begin(), curr_exp.end() );
        unsigned index_i = 0;
        unsigned index_j = 0;
-       bool ret = processSimpleNeq( normal_forms, normal_form_src, t_curr_exp, i, j, index_i, index_j, true );
+       bool ret = processSimpleNEq( normal_forms, normal_form_src, t_curr_exp, i, j, index_i, index_j, true );
 
        //reverse normal form of i, j
        std::reverse( normal_forms[i].begin(), normal_forms[i].end() );
@@ -1149,7 +1149,7 @@ bool TheoryStrings::processReverseNEq( std::vector< std::vector< Node > > &norma
        return ret;
 }
 
-bool TheoryStrings::processSimpleNeq( std::vector< std::vector< Node > > &normal_forms,
+bool TheoryStrings::processSimpleNEq( std::vector< std::vector< Node > > &normal_forms,
                                                                          std::vector< Node > &normal_form_src, std::vector< Node > &curr_exp,
                                                                          unsigned i, unsigned j, unsigned& index_i, unsigned& index_j, bool isRev ) {
        bool success;
@@ -1349,7 +1349,8 @@ bool TheoryStrings::normalizeEquivalenceClass( Node eqc, std::vector< Node > & v
     }
 }
 
-bool TheoryStrings::normalizeDisequality( Node ni, Node nj ) {
+//return true for lemma, false if we succeed
+bool TheoryStrings::processDeq( Node ni, Node nj ) {
        //Assert( areDisequal( ni, nj ) );
        if( d_normal_forms[ni].size()>1 || d_normal_forms[nj].size()>1 ){
                std::vector< Node > nfi;
@@ -1357,106 +1358,75 @@ bool TheoryStrings::normalizeDisequality( Node ni, Node nj ) {
                std::vector< Node > nfj;
                nfj.insert( nfj.end(), d_normal_forms[nj].begin(), d_normal_forms[nj].end() );
 
+               //int revRet = processReverseDeq( nfi, nfj, ni, nj );
+               //if( revRet!=0 ){
+               //      return revRet==-1;
+               //}
+               
+               nfi.clear();
+               nfi.insert( nfi.end(), d_normal_forms[ni].begin(), d_normal_forms[ni].end() );
+               nfj.clear();
+               nfj.insert( nfj.end(), d_normal_forms[nj].begin(), d_normal_forms[nj].end() );
+
                unsigned index = 0;
                while( index<nfi.size() || index<nfj.size() ){
-                       if( index>=nfi.size() || index>=nfj.size() ){
-                               std::vector< Node > ant;
-                               //we have a conflict : because the lengths are equal, the remainder needs to be empty, which will lead to a conflict
-                               Node lni = getLength( ni );
-                               Node lnj = getLength( nj );
-                               ant.push_back( lni.eqNode( lnj ) );
-                               ant.push_back( getLengthTerm( ni ).eqNode( d_normal_forms_base[ni] ) );
-                               ant.push_back( getLengthTerm( nj ).eqNode( d_normal_forms_base[nj] ) );
-                               ant.insert( ant.end(), d_normal_forms_exp[ni].begin(), d_normal_forms_exp[ni].end() );
-                               ant.insert( ant.end(), d_normal_forms_exp[nj].begin(), d_normal_forms_exp[nj].end() );
-                               std::vector< Node > cc;
-                               std::vector< Node >& nfk = index>=nfi.size() ? nfj : nfi;
-                               for( unsigned index_k=index; index_k<nfk.size(); index_k++ ){
-                                       cc.push_back( nfk[index_k].eqNode( d_emptyString ) );
-                               }
-                               Node conc = cc.size()==1 ? cc[0] : NodeManager::currentNM()->mkNode( kind::AND, cc );
-                               conc = Rewriter::rewrite( conc );
-                               sendLemma(mkExplain( ant ), conc, "Disequality Normalize Empty");
-                               return true;
+                       int ret = processSimpleDeq( nfi, nfj, ni, nj, index, false );
+                       if( ret!=0 ){
+                               return ret==-1;
                        }else{
+                               Assert( index<nfi.size() && index<nfj.size() );
                                Node i = nfi[index];
                                Node j = nfj[index];
                                Trace("strings-solve-debug")  << "...Processing " << i << " " << j << std::endl;
                                if( !areEqual( i, j ) ) {
-                                       if( i.getKind()==kind::CONST_STRING && j.getKind()==kind::CONST_STRING ){
-                                               unsigned int len_short = i.getConst<String>().size() < j.getConst<String>().size() ? i.getConst<String>().size() : j.getConst<String>().size();
-                                               String si = i.getConst<String>().substr(0, len_short);
-                                               String sj = j.getConst<String>().substr(0, len_short);
-                                               if(si == sj) {
-                                                       if( i.getConst<String>().size() < j.getConst<String>().size() ) {
-                                                               Node remainderStr = NodeManager::currentNM()->mkConst( j.getConst<String>().substr(len_short) );
-                                                               Trace("strings-solve-debug-test") << "Break normal form of " << nfj[index] << " into " << nfi[index] << ", " << remainderStr << std::endl;
-                                                               nfj.insert( nfj.begin() + index + 1, remainderStr );
-                                                               nfj[index] = nfi[index];
-                                                       } else {
-                                                               Node remainderStr = NodeManager::currentNM()->mkConst( i.getConst<String>().substr(len_short) );
-                                                               Trace("strings-solve-debug-test") << "Break normal form of " << nfi[index] << " into " << nfj[index] << ", " << remainderStr << std::endl;
-                                                               nfi.insert( nfi.begin() + index + 1, remainderStr );
-                                                               nfi[index] = nfj[index];
-                                                       }
-                                               } else {
-                                                       //conflict
-                                                       return false;
-                                               }
+                                       Assert( i.getKind()!=kind::CONST_STRING || j.getKind()!=kind::CONST_STRING );
+                                       Node li = getLength( i );
+                                       Node lj = getLength( j );
+                                       if( areDisequal(li, lj) ){
+                                               //if( i.getKind()==kind::CONST_STRING || j.getKind()==kind::CONST_STRING ){
+                       
+                                               Trace("strings-solve") << "Non-Simple Case 1 : add lemma " << std::endl;
+                                               //must add lemma
+                                               std::vector< Node > antec;
+                                               std::vector< Node > antec_new_lits;
+                                               antec.insert( antec.end(), d_normal_forms_exp[ni].begin(), d_normal_forms_exp[ni].end() );
+                                               antec.insert( antec.end(), d_normal_forms_exp[nj].begin(), d_normal_forms_exp[nj].end() );
+                                               antec.push_back( ni.eqNode( nj ).negate() );
+                                               antec_new_lits.push_back( li.eqNode( lj ).negate() );
+                                               std::vector< Node > conc;
+                                               Node sk1 = NodeManager::currentNM()->mkSkolem( "x_dsplit_$$", ni.getType(), "created for disequality normalization" );
+                                               Node sk2 = NodeManager::currentNM()->mkSkolem( "y_dsplit_$$", ni.getType(), "created for disequality normalization" );
+                                               Node sk3 = NodeManager::currentNM()->mkSkolem( "z_dsplit_$$", ni.getType(), "created for disequality normalization" );
+                                               //Node nemp = sk1.eqNode(d_emptyString).negate();
+                                               //conc.push_back(nemp);
+                                               //nemp = sk2.eqNode(d_emptyString).negate();
+                                               //conc.push_back(nemp);
+                                               Node nemp = sk3.eqNode(d_emptyString).negate();
+                                               conc.push_back(nemp);
+                                               Node lsk1 = getLength( sk1 );
+                                               conc.push_back( lsk1.eqNode( li ) );
+                                               Node lsk2 = getLength( sk2 );
+                                               conc.push_back( lsk2.eqNode( lj ) );
+                                               conc.push_back( NodeManager::currentNM()->mkNode( kind::OR,
+                                                                                       j.eqNode( mkConcat( sk1, sk3 ) ), i.eqNode( mkConcat( sk2, sk3 ) ) ) );
+                                               
+                                               sendLemma( mkExplain( antec, antec_new_lits ), NodeManager::currentNM()->mkNode( kind::AND, conc ), "D-DISL-Split" );
+                                               return true;
+                                       }else if( areEqual( li, lj ) ){
+                                               Assert( !areDisequal( i, j ) );
+                                               //splitting on demand : try to make them disequal
+                                               Node eq = i.eqNode( j );
+                                               sendSplit( i, j, "D-EQL-Split" );
+                                               eq = Rewriter::rewrite( eq );
+                                               d_pending_req_phase[ eq ] = false;
+                                               return true;
                                        }else{
-                                               Node li = getLength( i );
-                                               Node lj = getLength( j );
-                                               if( areDisequal(li, lj) ){
-                                                       //if( i.getKind()==kind::CONST_STRING || j.getKind()==kind::CONST_STRING ){
-                               
-                                                       Trace("strings-solve") << "Case 2 : add lemma " << std::endl;
-                                                       //must add lemma
-                                                       std::vector< Node > antec;
-                                                       std::vector< Node > antec_new_lits;
-                                                       antec.insert( antec.end(), d_normal_forms_exp[ni].begin(), d_normal_forms_exp[ni].end() );
-                                                       antec.insert( antec.end(), d_normal_forms_exp[nj].begin(), d_normal_forms_exp[nj].end() );
-                                                       antec.push_back( ni.eqNode( nj ).negate() );
-                                                       antec_new_lits.push_back( li.eqNode( lj ).negate() );
-                                                       std::vector< Node > conc;
-                                                       Node sk1 = NodeManager::currentNM()->mkSkolem( "x_dsplit_$$", ni.getType(), "created for disequality normalization" );
-                                                       Node sk2 = NodeManager::currentNM()->mkSkolem( "y_dsplit_$$", ni.getType(), "created for disequality normalization" );
-                                                       Node sk3 = NodeManager::currentNM()->mkSkolem( "z_dsplit_$$", ni.getType(), "created for disequality normalization" );
-                                                       //Node nemp = sk1.eqNode(d_emptyString).negate();
-                                                       //conc.push_back(nemp);
-                                                       //nemp = sk2.eqNode(d_emptyString).negate();
-                                                       //conc.push_back(nemp);
-                                                       Node nemp = sk3.eqNode(d_emptyString).negate();
-                                                       conc.push_back(nemp);
-                                                       Node lsk1 = getLength( sk1 );
-                                                       conc.push_back( lsk1.eqNode( li ) );
-                                                       Node lsk2 = getLength( sk2 );
-                                                       conc.push_back( lsk2.eqNode( lj ) );
-                                                       conc.push_back( NodeManager::currentNM()->mkNode( kind::OR,
-                                                                                               j.eqNode( mkConcat( sk1, sk3 ) ), i.eqNode( mkConcat( sk2, sk3 ) ) ) );
-                                                       
-                                                       sendLemma( mkExplain( antec, antec_new_lits ), NodeManager::currentNM()->mkNode( kind::AND, conc ), "D-DISL-Split" );
-                                                       return true;
-                                               }else if( areEqual( li, lj ) ){
-                                                       if( areDisequal( i, j ) ){
-                                                               Trace("strings-solve") << "Case 1 : found equal length disequal sub strings " << i << " " << j << std::endl;
-                                                               //we are done: D-Remove
-                                                               return false;
-                                                       } else {
-                                                               //splitting on demand : try to make them disequal
-                                                               Node eq = i.eqNode( j );
-                                                               sendSplit( i, j, "D-EQL-Split" );
-                                                               eq = Rewriter::rewrite( eq );
-                                                               d_pending_req_phase[ eq ] = false;
-                                                               return true;
-                                                       }
-                                               }else{
-                                                       //splitting on demand : try to make lengths equal
-                                                       Node eq = li.eqNode( lj );
-                                                       sendSplit( li, lj, "D-UNK-Split" );
-                                                       eq = Rewriter::rewrite( eq );
-                                                       d_pending_req_phase[ eq ] = true;
-                                                       return true;
-                                               }
+                                               //splitting on demand : try to make lengths equal
+                                               Node eq = li.eqNode( lj );
+                                               sendSplit( li, lj, "D-UNK-Split" );
+                                               eq = Rewriter::rewrite( eq );
+                                               d_pending_req_phase[ eq ] = true;
+                                               return true;
                                        }
                                }
                                index++;
@@ -1467,6 +1437,92 @@ bool TheoryStrings::normalizeDisequality( Node ni, Node nj ) {
        return false;
 }
 
+int TheoryStrings::processReverseDeq( std::vector< Node >& nfi, std::vector< Node >& nfj, Node ni, Node nj ) {
+       //reverse normal form of i, j
+       std::reverse( nfi.begin(), nfi.end() );
+       std::reverse( nfj.begin(), nfj.end() );
+
+       unsigned index = 0;
+       int ret = processSimpleDeq( nfi, nfj, ni, nj, index, true );
+
+       //reverse normal form of i, j
+       std::reverse( nfi.begin(), nfi.end() );
+       std::reverse( nfj.begin(), nfj.end() );
+
+       return ret;
+}
+
+int TheoryStrings::processSimpleDeq( std::vector< Node >& nfi, std::vector< Node >& nfj, Node ni, Node nj, unsigned& index, bool isRev ) {
+       while( index<nfi.size() || index<nfj.size() ){
+               if( index>=nfi.size() || index>=nfj.size() ){
+                       std::vector< Node > ant;
+                       //we have a conflict : because the lengths are equal, the remainder needs to be empty, which will lead to a conflict
+                       Node lni = getLength( ni );
+                       Node lnj = getLength( nj );
+                       ant.push_back( lni.eqNode( lnj ) );
+                       ant.push_back( getLengthTerm( ni ).eqNode( d_normal_forms_base[ni] ) );
+                       ant.push_back( getLengthTerm( nj ).eqNode( d_normal_forms_base[nj] ) );
+                       ant.insert( ant.end(), d_normal_forms_exp[ni].begin(), d_normal_forms_exp[ni].end() );
+                       ant.insert( ant.end(), d_normal_forms_exp[nj].begin(), d_normal_forms_exp[nj].end() );
+                       std::vector< Node > cc;
+                       std::vector< Node >& nfk = index>=nfi.size() ? nfj : nfi;
+                       for( unsigned index_k=index; index_k<nfk.size(); index_k++ ){
+                               cc.push_back( nfk[index_k].eqNode( d_emptyString ) );
+                       }
+                       Node conc = cc.size()==1 ? cc[0] : NodeManager::currentNM()->mkNode( kind::AND, cc );
+                       conc = Rewriter::rewrite( conc );
+                       sendLemma(mkExplain( ant ), conc, "Disequality Normalize Empty");
+                       return -1;
+               } else {
+                       Node i = nfi[index];
+                       Node j = nfj[index];
+                       Trace("strings-solve-debug")  << "...Processing " << i << " " << j << std::endl;
+                       if( !areEqual( i, j ) ) {
+                               if( i.getKind()==kind::CONST_STRING && j.getKind()==kind::CONST_STRING ) {
+                                       unsigned int len_short = i.getConst<String>().size() < j.getConst<String>().size() ? i.getConst<String>().size() : j.getConst<String>().size();
+                                       bool isSameFix = isRev ? i.getConst<String>().rstrncmp(j.getConst<String>(), len_short): i.getConst<String>().strncmp(j.getConst<String>(), len_short);
+                                       if( isSameFix ) {
+                                               //same prefix/suffix
+                                               //k is the index of the string that is shorter
+                                               Node nk = i.getConst<String>().size() < j.getConst<String>().size() ? i : j;
+                                               Node nl = i.getConst<String>().size() < j.getConst<String>().size() ? j : i;
+                                               Node remainderStr;
+                                               if(isRev) {
+                                                       int new_len = nl.getConst<String>().size() - len_short;
+                                                       remainderStr = NodeManager::currentNM()->mkConst( nl.getConst<String>().substr(0, new_len) );
+                                                       Trace("strings-solve-debug-test") << "Rev. Break normal form of " << nl << " into " << nk << ", " << remainderStr << std::endl;
+                                               } else {
+                                                       remainderStr = NodeManager::currentNM()->mkConst( j.getConst<String>().substr(len_short) );
+                                                       Trace("strings-solve-debug-test") << "Break normal form of " << nl << " into " << nk << ", " << remainderStr << std::endl;
+                                               }
+                                               if( i.getConst<String>().size() < j.getConst<String>().size() ) {
+                                                       nfj.insert( nfj.begin() + index + 1, remainderStr );
+                                                       nfj[index] = nfi[index];
+                                               } else {
+                                                       nfi.insert( nfi.begin() + index + 1, remainderStr );
+                                                       nfi[index] = nfj[index];
+                                               }
+                                       } else {
+                                               return 1;
+                                       }
+                               } else {
+                                       Node li = getLength( i );
+                                       Node lj = getLength( j );
+                                       if( areEqual( li, lj ) && areDisequal( i, j ) ) {
+                                               Trace("strings-solve") << "Simple Case 2 : found equal length disequal sub strings " << i << " " << j << std::endl;
+                                               //we are done: D-Remove
+                                               return 1;
+                                       }else{
+                                               return 0;
+                                       }
+                               }
+                       }
+                       index++;
+               }
+       }
+       return 0;
+}
+
 void TheoryStrings::addNormalFormPair( Node n1, Node n2 ) {
   if( !isNormalFormPair( n1, n2 ) ){
                //Assert( !isNormalFormPair( n1, n2 ) );
@@ -1532,7 +1588,7 @@ void TheoryStrings::sendInfer( Node eq_exp, Node eq, const char * c ) {
        if( eq==d_false ){
                sendLemma( eq_exp, eq, c );
        }else{
-               Trace("strings-lemma") << "Strings::Infer " << eq << " from " << eq_exp << std::endl;
+               Trace("strings-lemma") << "Strings::Infer " << eq << " from " << eq_exp << " by " << c << std::endl;
                d_pending.push_back( eq );
                d_pending_exp[eq] = eq_exp;
                d_infer.push_back(eq);
@@ -1844,7 +1900,7 @@ bool TheoryStrings::checkNormalForms() {
                                                Trace("strings-solve") << " against ";
                                                printConcat( d_normal_forms[cols[i][k]], "strings-solve" );
                                                Trace("strings-solve")  << "..." << std::endl;
-                                               if( normalizeDisequality( cols[i][j], cols[i][k] ) ){
+                                               if( processDeq( cols[i][j], cols[i][k] ) ){
                                                        break;
                                                }
                                        }
index e7f8751577d7643b82c3dc15cf6aa4dd1fb23ba2..1f69c81be4c083fedf92e2207ee6f1a5fa25a217 100644 (file)
@@ -211,11 +211,13 @@ private:
                                         std::vector< Node > &normal_form_src);
        bool processReverseNEq(std::vector< std::vector< Node > > &normal_forms,
                                                   std::vector< Node > &normal_form_src, std::vector< Node > &curr_exp, unsigned i, unsigned j );
-       bool processSimpleNeq( std::vector< std::vector< Node > > &normal_forms,
+       bool processSimpleNEq( std::vector< std::vector< Node > > &normal_forms,
                                                   std::vector< Node > &normal_form_src, std::vector< Node > &curr_exp, unsigned i, unsigned j,
                                                   unsigned& index_i, unsigned& index_j, bool isRev );
     bool normalizeEquivalenceClass( Node n, std::vector< Node > & visited, std::vector< Node > & nf, std::vector< Node > & nf_exp );
-    bool normalizeDisequality( Node n1, Node n2 );
+    bool processDeq( Node n1, Node n2 );
+       int processReverseDeq( std::vector< Node >& nfi, std::vector< Node >& nfj, Node ni, Node nj );
+       int processSimpleDeq( std::vector< Node >& nfi, std::vector< Node >& nfj, Node ni, Node nj, unsigned& index, bool isRev );
        bool unrollStar( Node atom );
 
        bool checkLengths();