hot fix for pre-reg term caching in strings
authorTianyi Liang <tianyi-liang@uiowa.edu>
Mon, 17 Mar 2014 18:17:12 +0000 (13:17 -0500)
committerTianyi Liang <tianyi-liang@uiowa.edu>
Mon, 17 Mar 2014 18:22:08 +0000 (13:22 -0500)
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h

index 448b94fd2ce6054746d6ab5387d95a88851e0309..a4d3145d9d70466b8ab610166342b41777ac47bc 100644 (file)
@@ -43,8 +43,9 @@ TheoryStrings::TheoryStrings(context::Context* c, context::UserContext* u, Outpu
     d_nf_pairs(c),
        d_loop_antec(u),
        d_length_intro_vars(u),
+       d_prereg_cached(u),
+       d_length_nodes(u),
        d_length_inst(u),
-       d_length_nodes(c),
        d_str_pos_ctn(c),
        d_str_neg_ctn(c),
        d_neg_ctn_eqlen(u),
@@ -401,9 +402,10 @@ void TheoryStrings::collectModelInfo( TheoryModel* m, bool fullModel ) {
 /////////////////////////////////////////////////////////////////////////////
 
 void TheoryStrings::preRegisterTerm(TNode n) {
-  Debug("strings-prereg") << "TheoryStrings::preRegisterTerm() " << n << endl;
-  //collectTerms( n );
-  switch (n.getKind()) {
+  if(d_prereg_cached.find(n) == d_prereg_cached.end()) {
+       Debug("strings-prereg") << "TheoryStrings::preRegisterTerm() " << n << endl;
+       //collectTerms( n );
+       switch (n.getKind()) {
        case kind::EQUAL:
                d_equalityEngine.addTriggerEquality(n);
                break;
@@ -456,6 +458,8 @@ void TheoryStrings::preRegisterTerm(TNode n) {
                  d_equalityEngine.addTerm(n);
                }
        }
+       }
+       d_prereg_cached.insert(n);
   }
 }
 
@@ -1788,40 +1792,41 @@ bool TheoryStrings::checkSimple() {
                        //then, add lemma
                        if( n.getKind() == kind::CONST_STRING || n.getKind() == kind::STRING_CONCAT ) {
                                if( d_length_nodes.find(n)==d_length_nodes.end() ) {
-                                       if( d_length_inst.find(n)==d_length_inst.end() ) {
+                                       Trace("strings-debug") << "get n: " << n << endl;
+                                       Node sk;
+                                       //if( d_length_inst.find(n)==d_length_inst.end() ) {
                                                //Node nr = d_equalityEngine.getRepresentative( n );
-                                               //if( d_length_nodes.find(nr)==d_length_nodes.end() ) {
-                                                       d_length_inst.insert(n);
-                                                       Trace("strings-debug") << "get n: " << n << endl;
-                                                       Node sk = NodeManager::currentNM()->mkSkolem( "lsym_$$", n.getType(), "created for length" );
-                                                       d_statistics.d_new_skolems += 1;
-                                                       d_length_intro_vars.insert( sk );
-                                                       Node eq = NodeManager::currentNM()->mkNode( kind::EQUAL, sk, n );
-                                                       eq = Rewriter::rewrite(eq);
-                                                       Trace("strings-lemma") << "Strings::Lemma LENGTH Term : " << eq << std::endl;
-                                                       d_out->lemma(eq);
-                                                       Node skl = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, sk );
-                                                       Node lsum;
-                                                       if( n.getKind() == kind::STRING_CONCAT ) {
-                                                               //add lemma
-                                                               std::vector<Node> node_vec;
-                                                               for( unsigned i=0; i<n.getNumChildren(); i++ ) {
-                                                                       Node lni = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, n[i] );
-                                                                       node_vec.push_back(lni);
-                                                               }
-                                                               lsum = Rewriter::rewrite( NodeManager::currentNM()->mkNode( kind::PLUS, node_vec ) );
-                                                       } else if( n.getKind() == kind::CONST_STRING ) {
-                                                               //add lemma
-                                                               lsum = NodeManager::currentNM()->mkConst( ::CVC4::Rational( n.getConst<String>().size() ) );
-                                                       }
-                                                       Node ceq = NodeManager::currentNM()->mkNode( kind::EQUAL, skl, lsum );
-                                                       ceq = Rewriter::rewrite(ceq);
-                                                       Trace("strings-lemma") << "Strings::Lemma LENGTH : " << ceq << std::endl;
-                                                       d_out->lemma(ceq);
-                                                       addedLemma = true;
-                                               //}
+                                               sk = NodeManager::currentNM()->mkSkolem( "lsym_$$", n.getType(), "created for length" );
+                                               d_statistics.d_new_skolems += 1;
+                                               d_length_intro_vars.insert( sk );
+                                               Node eq = sk.eqNode(n);
+                                               eq = Rewriter::rewrite(eq);
+                                               Trace("strings-lemma") << "Strings::Lemma LENGTH Term : " << eq << std::endl;
+                                               d_out->lemma(eq);
+                                       //} else {
+                                       //      sk = d_length_inst[n];
+                                       //}
+                                       Node skl = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, sk );
+                                       Node lsum;
+                                       if( n.getKind() == kind::STRING_CONCAT ) {
+                                               //add lemma
+                                               std::vector<Node> node_vec;
+                                               for( unsigned i=0; i<n.getNumChildren(); i++ ) {
+                                                       Node lni = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, n[i] );
+                                                       node_vec.push_back(lni);
+                                               }
+                                               lsum = Rewriter::rewrite( NodeManager::currentNM()->mkNode( kind::PLUS, node_vec ) );
+                                       } else if( n.getKind() == kind::CONST_STRING ) {
+                                               //add lemma
+                                               lsum = NodeManager::currentNM()->mkConst( ::CVC4::Rational( n.getConst<String>().size() ) );
                                        }
-                                       d_length_nodes[n] = true;
+                                       Node ceq = NodeManager::currentNM()->mkNode( kind::EQUAL, skl, lsum );
+                                       ceq = Rewriter::rewrite(ceq);
+                                       Trace("strings-lemma") << "Strings::Lemma LENGTH : " << ceq << std::endl;
+                                       d_out->lemma(ceq);
+                                       addedLemma = true;
+
+                                       d_length_nodes.insert(n);
                                }
                        }
                        ++eqc_i;
@@ -2734,9 +2739,13 @@ void TheoryStrings::assertNode( Node lit ) {
 }
 
 Node TheoryStrings::mkSplitEq( const char * c, const char * info, Node lhs, Node rhs, bool lgtZero ) {
-       Node sk = NodeManager::currentNM()->mkSkolem( c, lhs.getType(), info );
+       Node sk = NodeManager::currentNM()->mkSkolem( c, NodeManager::currentNM()->stringType(), info );
        d_statistics.d_new_skolems += 1;
-       Node eq = lhs.eqNode( mkConcat( rhs, sk ) );
+       Node cc = mkConcat( rhs, sk );
+       //if(rhs.isConst()) {
+       //      d_length_inst[cc] = lhs;
+       //}
+       Node eq = lhs.eqNode( cc );
        eq = Rewriter::rewrite( eq );
        if( lgtZero ) {
                Node sk_gt_zero = NodeManager::currentNM()->mkNode( kind::EQUAL, sk, d_emptyString).negate();
index c8a3748930d5357faaedc01ce258d0eb16060e02..902b902b6f3204d25c9408ece80d0d8a3e906b79 100644 (file)
@@ -157,6 +157,8 @@ private:
        // loop ant
        NodeSet d_loop_antec;
        NodeSet d_length_intro_vars;
+       // preReg cache
+       NodeSet d_prereg_cached;
 
        /////////////////////////////////////////////////////////////////////////////
        // MODEL GENERATION
@@ -194,8 +196,8 @@ private:
        std::map< Node, EqcInfo* > d_eqc_info;
        EqcInfo * getOrMakeEqcInfo( Node eqc, bool doMake = true );
        //maintain which concat terms have the length lemma instantiated
-       NodeSet d_length_inst;
-       NodeBoolMap d_length_nodes;
+       NodeSet d_length_nodes;
+       NodeNodeMap d_length_inst;
 private:
        void mergeCstVec(std::vector< Node > &vec_strings);
     bool getNormalForms(Node &eqc, std::vector< Node > & visited, std::vector< Node > & nf,