Minor refactoring in strings related to length.
authorajreynol <andrew.j.reynolds@gmail.com>
Wed, 21 Oct 2015 08:47:52 +0000 (10:47 +0200)
committerajreynol <andrew.j.reynolds@gmail.com>
Wed, 21 Oct 2015 08:47:52 +0000 (10:47 +0200)
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h

index 9899a7a4816a28888d2f33793089f712bb6da12a..ea48d8805156edbc38c3ff414e57c7c1cca0b9de 100644 (file)
@@ -65,6 +65,7 @@ TheoryStrings::TheoryStrings(context::Context* c, context::UserContext* u, Outpu
   d_nf_pairs(c),
   d_loop_antec(u),
   d_length_intro_vars(u),
+  d_pregistered_terms_cache(u),
   d_registered_terms_cache(u),
   d_preproc(u),
   d_preproc_cache(u),
@@ -172,17 +173,23 @@ bool TheoryStrings::areDisequal( Node a, Node b ){
   }
 }
 
-Node TheoryStrings::getLengthExp( Node t, std::vector< Node >& exp, Node te ) {
+Node TheoryStrings::getLengthExp( Node t, std::vector< Node >& exp, Node te ){
   Assert( areEqual( t, te ) );
-  EqcInfo * ei = getOrMakeEqcInfo( t, false );
-  Node length_term = ei ? ei->d_length_term : Node::null();
-  if( length_term.isNull() ){
-    //typically shouldnt be necessary
-    length_term = t;
+  Node lt = mkLength( te );
+  if( hasTerm( lt ) ){
+    // use own length if it exists, leads to shorter explanation
+    return lt;
+  }else{
+    EqcInfo * ei = getOrMakeEqcInfo( t, false );
+    Node length_term = ei ? ei->d_length_term : Node::null();
+    if( length_term.isNull() ){
+      //typically shouldnt be necessary
+      length_term = t;
+    }
+    Debug("strings") << "TheoryStrings::getLengthTerm " << t << " is " << length_term << std::endl;
+    addToExplanation( length_term, te, exp );
+    return Rewriter::rewrite( NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, length_term ) );
   }
-  Debug("strings") << "TheoryStrings::getLengthTerm " << t << " is " << length_term << std::endl;
-  addToExplanation( length_term, te, exp );
-  return Rewriter::rewrite( NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, length_term ) );
 }
 
 Node TheoryStrings::getLength( Node t, std::vector< Node >& exp ) {
@@ -430,7 +437,8 @@ void TheoryStrings::collectModelInfo( TheoryModel* m, bool fullModel ) {
 
 
 void TheoryStrings::preRegisterTerm(TNode n) {
-  if( d_registered_terms_cache.find(n) == d_registered_terms_cache.end() ) {
+  if( d_pregistered_terms_cache.find(n) == d_pregistered_terms_cache.end() ) {
+    d_pregistered_terms_cache.insert(n);
     //check for logic exceptions
     if( !options::stringExp() ){
       if( n.getKind()==kind::STRING_STRIDOF ||
@@ -456,9 +464,9 @@ void TheoryStrings::preRegisterTerm(TNode n) {
       }
       default: {
         if( n.getType().isString() ) {
-          registerTerm(n);
+          registerTerm( n, 0 );
           // FMF
-          if( n.getKind() == kind::VARIABLE && options::stringFMF() ) {
+          if( n.getKind() == kind::VARIABLE && options::stringFMF() ){
             d_input_vars.insert(n);
           }
         } else if (n.getType().isBoolean()) {
@@ -470,7 +478,6 @@ void TheoryStrings::preRegisterTerm(TNode n) {
         }
       }
     }
-    d_registered_terms_cache.insert(n);
   }
 }
 
@@ -718,6 +725,8 @@ void TheoryStrings::eqNotifyNewClass(TNode t){
     Node r = d_equalityEngine.getRepresentative(t[0]);
     EqcInfo * ei = getOrMakeEqcInfo( r, true );
     ei->d_length_term = t[0];
+    //we care about the length of this string
+    registerTerm( t[0], 1 );
   }
 }
 
@@ -791,11 +800,9 @@ void TheoryStrings::assertPendingFact(Node atom, bool polarity, Node exp) {
   Assert(atom.getKind() != kind::OR, "Infer error: a split.");
   if( atom.getKind()==kind::EQUAL ){
     Trace("strings-pending-debug") << "  Register term" << std::endl;
-    //AJR : is this necessary?
     for( unsigned j=0; j<2; j++ ) {
       if( !d_equalityEngine.hasTerm( atom[j] ) && atom[j].getType().isString() ) {
-        //TODO: check!!!
-        registerTerm( atom[j] );
+        registerTerm( atom[j], 0 );
       }
     }
     Trace("strings-pending-debug") << "  Now assert equality" << std::endl;
@@ -892,6 +899,7 @@ void TheoryStrings::checkInit() {
       if( tn.isString() ){
         d_strings_eqc.push_back( eqc );
       }
+      Node var;
       eq::EqClassIterator eqc_i = eq::EqClassIterator( eqc, &d_equalityEngine );
       while( !eqc_i.isFinished() ) {
         Node n = *eqc_i;
@@ -988,6 +996,15 @@ void TheoryStrings::checkInit() {
               congruent[k]++;
             }
           }
+        }else{
+          if( d_congruent.find( n )==d_congruent.end() ){
+            if( var.isNull() ){
+              var = n;
+            }else{
+              Trace("strings-process-debug") << "  congruent variable : " << n << std::endl;
+              d_congruent.insert( n );
+            }
+          }
         }
         ++eqc_i;
       }
@@ -1494,6 +1511,12 @@ void TheoryStrings::checkFlatForms() {
         }
       }
     }
+    if( !hasProcessed() ){
+      // simple extended func reduction
+      Trace("strings-process") << "Check extended function reduction effort=1..." << std::endl;
+      checkExtfReduction( 1 );
+      Trace("strings-process") << "Done check extended function reduction" << std::endl;
+    }
   }
 }
 
@@ -1508,8 +1531,8 @@ Node TheoryStrings::checkCycles( Node eqc, std::vector< Node >& curr, std::vecto
     while( !eqc_i.isFinished() ) {
       Node n = (*eqc_i);
       if( d_congruent.find( n )==d_congruent.end() ){
-        Trace("strings-cycle") << eqc << " check term : " << n << " in " << eqc << std::endl;
-        if( n.getKind() == kind::STRING_CONCAT ) {
+        if( n.getKind() == kind::STRING_CONCAT ){
+          Trace("strings-cycle") << eqc << " check term : " << n << " in " << eqc << std::endl;
           if( eqc!=d_emptyString_r ){
             d_eqc[eqc].push_back( n );
           }
@@ -1569,10 +1592,19 @@ Node TheoryStrings::checkCycles( Node eqc, std::vector< Node >& curr, std::vecto
 
 
 void TheoryStrings::checkNormalForms(){
-  // simple extended func reduction
-  Trace("strings-process") << "Check extended function reduction effort=1..." << std::endl;
-  checkExtfReduction( 1 );
-  Trace("strings-process") << "Done check extended function reduction" << std::endl;
+  if( !options::stringEagerLen() ){
+    for( unsigned i=0; i<d_strings_eqc.size(); i++ ) {
+      Node eqc = d_strings_eqc[i];
+      eq::EqClassIterator eqc_i = eq::EqClassIterator( eqc, &d_equalityEngine );
+      while( !eqc_i.isFinished() ) {
+        Node n = (*eqc_i);
+        if( d_congruent.find( n )==d_congruent.end() ){
+          registerTerm( n, 2 );
+        }
+        ++eqc_i;
+      }
+    }
+  }
   if( !hasProcessed() ){
     Trace("strings-process") << "Normalize equivalence classes...." << std::endl;
     //calculate normal forms for each equivalence class, possibly adding splitting lemmas
@@ -1627,32 +1659,19 @@ void TheoryStrings::checkNormalForms(){
         Trace("strings-process-debug") << "Done check disequalities, addedFact = " << !d_pending.empty() << " " << !d_lemma_cache.empty() << ", d_conflict = " << d_conflict << std::endl;
       }
     }
+    Trace("strings-solve") << "Finished check normal forms, #lemmas = " << d_lemma_cache.size() << ", conflict = " << d_conflict << std::endl;
   }
-  Trace("strings-solve") << "Finished check normal forms, #lemmas = " << d_lemma_cache.size() << ", conflict = " << d_conflict << std::endl;
 }
 
 //nf_exp is conjunction
 bool TheoryStrings::normalizeEquivalenceClass( Node eqc, std::vector< Node > & nf, std::vector< Node > & nf_exp ) {
   Trace("strings-process-debug") << "Process equivalence class " << eqc << std::endl;
   if( areEqual( eqc, d_emptyString ) ) {
-    eq::EqClassIterator eqc_i = eq::EqClassIterator( eqc, &d_equalityEngine );
-    while( !eqc_i.isFinished() ) {
-      Node n = (*eqc_i);
-      if( d_congruent.find( n )==d_congruent.end() ){
-        if( n.getKind()==kind::STRING_CONCAT ){
-          //std::vector< Node > exp;
-          //exp.push_back( n.eqNode( d_emptyString ) );
-          //Node ant = mkExplain( exp );
-          Node ant = n.eqNode( d_emptyString );
-          for( unsigned i=0; i<n.getNumChildren(); i++ ){
-            if( !areEqual( n[i], d_emptyString ) ){
-              //sendLemma( ant, n[i].eqNode( d_emptyString ), "CYCLE" );
-              sendInfer( ant, n[i].eqNode( d_emptyString ), "CYCLE" );
-            }
-          }
-        }
+    for( unsigned j=0; j<d_eqc[eqc].size(); j++ ){
+      Node n = d_eqc[eqc][j];
+      for( unsigned i=0; i<n.getNumChildren(); i++ ){
+        Assert( areEqual( n[i], d_emptyString ) );
       }
-      ++eqc_i;
     }
     //do nothing
     Trace("strings-process-debug") << "Return process equivalence class " << eqc << " : empty." << std::endl;
@@ -1709,7 +1728,7 @@ bool TheoryStrings::normalizeEquivalenceClass( Node eqc, std::vector< Node > & n
 }
 
 bool TheoryStrings::getNormalForms( Node &eqc, std::vector< Node > & nf,
-                                    std::vector< std::vector< Node > > &normal_forms,  std::vector< std::vector< Node > > &normal_forms_exp, 
+                                    std::vector< std::vector< Node > > &normal_forms,  std::vector< std::vector< Node > > &normal_forms_exp,
                                     std::vector< Node > &normal_form_src) {
   Trace("strings-process-debug") << "Get normal forms " << eqc << std::endl;
   eq::EqClassIterator eqc_i = eq::EqClassIterator( eqc, &d_equalityEngine );
@@ -1756,7 +1775,6 @@ bool TheoryStrings::getNormalForms( Node &eqc, std::vector< Node > & nf,
           }
         }
         //if not equal to self
-        //if( nf_n.size()!=1 || (nf_n.size()>1 && nf_n[0]!=eqc ) ){
         if( nf_n.size()>1 || ( nf_n.size()==1 && nf_n[0].getKind()==kind::CONST_STRING ) ){
           if( nf_n.size()>1 ) {
             for( unsigned i=0; i<nf_n.size(); i++ ){
@@ -1772,7 +1790,7 @@ bool TheoryStrings::getNormalForms( Node &eqc, std::vector< Node > & nf,
           normal_forms_exp.push_back(nf_exp_n);
           normal_form_src.push_back(n);
         }else{
-          //this was redundant: combination of eqc + empty string(s)
+          //this was redundant: combination of self + empty string(s)
           Node nn = nf_n.size()==0 ? d_emptyString : nf_n[0];
           Assert( areEqual( nn, eqc ) );
           //Assert( areEqual( nf_n[0], eqc ) );
@@ -1787,7 +1805,6 @@ bool TheoryStrings::getNormalForms( Node &eqc, std::vector< Node > & nf,
           }
           */
         }
-        //}
       }
     }
     ++eqc_i;
@@ -1861,8 +1878,7 @@ bool TheoryStrings::processNEqc( std::vector< std::vector< Node > > &normal_form
         unsigned index_i = 0;
         unsigned index_j = 0;
         bool success;
-        do
-        {
+        do{
           //simple check
           if( processSimpleNEq( normal_forms, normal_form_src, curr_exp, i, j, index_i, index_j, false ) ){
             //added a lemma, return
@@ -2014,7 +2030,7 @@ bool TheoryStrings::processNEqc( std::vector< std::vector< Node > > &normal_form
   return false;
 }
 
-bool TheoryStrings::processReverseNEq( std::vector< std::vector< Node > > &normal_forms, std::vector< Node > &normal_form_src, 
+bool TheoryStrings::processReverseNEq( std::vector< std::vector< Node > > &normal_forms, std::vector< Node > &normal_form_src,
                                        std::vector< Node > &curr_exp, unsigned i, unsigned j ) {
   //reverse normal form of i, j
   std::reverse( normal_forms[i].begin(), normal_forms[i].end() );
@@ -2155,7 +2171,7 @@ bool TheoryStrings::processSimpleNEq( std::vector< std::vector< Node > > &normal
   return false;
 }
 
-bool TheoryStrings::detectLoop( std::vector< std::vector< Node > > &normal_forms, int i, int j, 
+bool TheoryStrings::detectLoop( std::vector< std::vector< Node > > &normal_forms, int i, int j,
                                 int index_i, int index_j, int &loop_in_i, int &loop_in_j) {
   int has_loop[2] = { -1, -1 };
   if( options::stringLB() != 2 ) {
@@ -2568,61 +2584,69 @@ bool TheoryStrings::isNormalFormPair2( Node n1, Node n2 ) {
   return false;
 }
 
-bool TheoryStrings::registerTerm( Node n ) {
-  if(d_registered_terms_cache.find(n) == d_registered_terms_cache.end()) {
-    d_registered_terms_cache.insert(n);
-    Debug("strings-register") << "TheoryStrings::registerTerm() " << n << endl;
-    if(n.getType().isString()) {
-      //register length information:
-      //  for variables, split on empty vs positive length
-      //  for concat/const, introduce proxy var and state length relation
-      if( n.getKind()!=kind::STRING_CONCAT && n.getKind()!=kind::CONST_STRING ) {
-        if( d_length_intro_vars.find(n)==d_length_intro_vars.end() ) {
-          sendLengthLemma( n );
-          ++(d_statistics.d_splits);
-        }
-      } else {
-        Node sk = mkSkolemS("lsym", 2);
-        StringsProxyVarAttribute spva;
-        sk.setAttribute(spva,true);
-        Node eq = Rewriter::rewrite( sk.eqNode(n) );
-        Trace("strings-lemma") << "Strings::Lemma LENGTH Term : " << eq << std::endl;
-        d_proxy_var[n] = sk;
-        Trace("strings-assert") << "(assert " << eq << ")" << std::endl;
-        d_out->lemma(eq);
-        Node skl = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, sk );
-        Node lsum;
-        if( n.getKind() == kind::STRING_CONCAT ) {
-          std::vector<Node> node_vec;
-          for( unsigned i=0; i<n.getNumChildren(); i++ ) {
-            if( n[i].getAttribute(StringsProxyVarAttribute()) ){
-              Assert( d_proxy_var_to_length.find( n[i] )!=d_proxy_var_to_length.end() );
-              node_vec.push_back( d_proxy_var_to_length[n[i]] );
-            }else{
-              Node lni = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, n[i] );
-              node_vec.push_back(lni);
+void TheoryStrings::registerTerm( Node n, int effort ) {
+  // 0 : upon preregistration or internal assertion
+  // 1 : upon occurrence in length term
+  // 2 : before normal form computation
+  // 3 : called on normal form terms
+  bool do_register = false;
+  if( options::stringEagerLen() ){
+    do_register = effort==0;
+  }else{
+    do_register = effort>0 || n.getKind()!=kind::STRING_CONCAT;
+  }
+  if( do_register ){
+    if(d_registered_terms_cache.find(n) == d_registered_terms_cache.end()) {
+      d_registered_terms_cache.insert(n);
+      Debug("strings-register") << "TheoryStrings::registerTerm() " << n << ", effort = " << effort << std::endl;
+      if(n.getType().isString()) {
+        //register length information:
+        //  for variables, split on empty vs positive length
+        //  for concat/const, introduce proxy var and state length relation
+        if( n.getKind()!=kind::STRING_CONCAT && n.getKind()!=kind::CONST_STRING ) {
+          if( d_length_intro_vars.find(n)==d_length_intro_vars.end() ) {
+            sendLengthLemma( n );
+            ++(d_statistics.d_splits);
+          }
+        } else {
+          Node sk = mkSkolemS("lsym", 2);
+          StringsProxyVarAttribute spva;
+          sk.setAttribute(spva,true);
+          Node eq = Rewriter::rewrite( sk.eqNode(n) );
+          Trace("strings-lemma") << "Strings::Lemma LENGTH Term : " << eq << std::endl;
+          d_proxy_var[n] = sk;
+          Trace("strings-assert") << "(assert " << eq << ")" << std::endl;
+          d_out->lemma(eq);
+          Node skl = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, sk );
+          Node lsum;
+          if( n.getKind() == kind::STRING_CONCAT ) {
+            std::vector<Node> node_vec;
+            for( unsigned i=0; i<n.getNumChildren(); i++ ) {
+              if( n[i].getAttribute(StringsProxyVarAttribute()) ){
+                Assert( d_proxy_var_to_length.find( n[i] )!=d_proxy_var_to_length.end() );
+                node_vec.push_back( d_proxy_var_to_length[n[i]] );
+              }else{
+                Node lni = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, n[i] );
+                node_vec.push_back(lni);
+              }
             }
+            lsum = NodeManager::currentNM()->mkNode( kind::PLUS, node_vec );
+          } else if( n.getKind() == kind::CONST_STRING ) {
+            lsum = NodeManager::currentNM()->mkConst( ::CVC4::Rational( n.getConst<String>().size() ) );
           }
-          lsum = NodeManager::currentNM()->mkNode( kind::PLUS, node_vec );
-        } else if( n.getKind() == kind::CONST_STRING ) {
-          lsum = NodeManager::currentNM()->mkConst( ::CVC4::Rational( n.getConst<String>().size() ) );
-        }
-        lsum = Rewriter::rewrite( lsum );
-        d_proxy_var_to_length[sk] = lsum;
-        if( options::stringEagerLen() || n.getKind()==kind::CONST_STRING ){
+          lsum = Rewriter::rewrite( lsum );
+          d_proxy_var_to_length[sk] = lsum;
           Node ceq = Rewriter::rewrite( skl.eqNode( lsum ) );
           Trace("strings-lemma") << "Strings::Lemma LENGTH : " << ceq << std::endl;
           Trace("strings-lemma-debug") << "  prerewrite : " << skl.eqNode( lsum ) << std::endl;
           Trace("strings-assert") << "(assert " << ceq << ")" << std::endl;
           d_out->lemma(ceq);
         }
+      } else {
+        AlwaysAssert(false, "String Terms only in registerTerm.");
       }
-      return true;
-    } else {
-      AlwaysAssert(false, "String Terms only in registerTerm.");
     }
   }
-  return false;
 }
 
 void TheoryStrings::sendLemma( Node ant, Node conc, const char * c ) {
@@ -2696,7 +2720,6 @@ void TheoryStrings::sendLengthLemma( Node n ){
   Node n_len = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, n);
   if( options::stringSplitEmp() || !options::stringLenGeqZ() ){
     Node n_len_eq_z = n_len.eqNode( d_zero );
-    //registerTerm( d_emptyString );
     Node n_len_eq_z_2 = n.eqNode( d_emptyString );
     n_len_eq_z = Rewriter::rewrite( n_len_eq_z );
     n_len_eq_z_2 = Rewriter::rewrite( n_len_eq_z_2 );
@@ -2963,7 +2986,8 @@ void TheoryStrings::checkLengthsEqc() {
         Trace("strings-process-debug") << "No length term for eqc " << d_strings_eqc[i] << " " << d_eqc_to_len_term[d_strings_eqc[i]] << std::endl;
         if( !options::stringEagerLen() ){
           Node c = mkConcat( d_normal_forms[d_strings_eqc[i]] );
-          registerTerm( c );
+          registerTerm( c, 3 );
+          /*
           if( !c.isConst() ){
             NodeNodeMap::const_iterator it = d_proxy_var.find( c );
             if( it!=d_proxy_var.end() ){
@@ -2974,6 +2998,7 @@ void TheoryStrings::checkLengthsEqc() {
               sendLemma( d_true, ceq, "LEN-NORM-I" );
             }
           }
+          */
         }
       }
       //} else {
index 125e1c1eb36d447e791b4272e8636bf1d836d659..40358649b30160a743f7ec5a31ef7049b317dcd7 100644 (file)
@@ -165,6 +165,7 @@ private:
   NodeSet d_loop_antec;
   NodeSet d_length_intro_vars;
   // preReg cache
+  NodeSet d_pregistered_terms_cache;
   NodeSet d_registered_terms_cache;
   // preprocess cache
   StringsPreprocess d_preproc;
@@ -332,7 +333,7 @@ protected:
   void addToExplanation( Node lit, std::vector< Node >& exp );
 
   //register term
-  bool registerTerm( Node n );
+  void registerTerm( Node n, int effort );
   //send lemma
   void sendLemma( Node ant, Node conc, const char * c );
   void sendInfer( Node eq_exp, Node eq, const char * c );