Make cbqi robust to term ITE removal. Separate vts infinities for int/real.
authorajreynol <andrew.j.reynolds@gmail.com>
Wed, 19 Aug 2015 19:45:37 +0000 (21:45 +0200)
committerajreynol <andrew.j.reynolds@gmail.com>
Wed, 19 Aug 2015 19:45:44 +0000 (21:45 +0200)
src/theory/quantifiers/ce_guided_single_inv.cpp
src/theory/quantifiers/inst_strategy_cbqi.cpp
src/theory/quantifiers/inst_strategy_cbqi.h
src/theory/quantifiers/instantiation_engine.cpp
src/theory/quantifiers/options
src/theory/quantifiers/term_database.cpp
src/theory/quantifiers/term_database.h
src/theory/theory_engine.h

index 17d85eb9b127ce2313e5e4983a355bf2c495248d..a1aa9dad8249dcd174c5f2ced085f112ed52bdf5 100644 (file)
@@ -719,10 +719,7 @@ bool CegConjectureSingleInv::addInstantiation( std::vector< Node >& subs ){
   }else{
     Trace("cegqi-engine") << siss.str() << std::endl;
     Node lem = d_single_inv[1].substitute( d_single_inv_var.begin(), d_single_inv_var.end(), subs.begin(), subs.end() );
-    Node delta = d_qe->getTermDatabase()->getVtsDelta( false, false );
-    Node inf = d_qe->getTermDatabase()->getVtsInfinity( false, false );
-    if( ( !delta.isNull() && TermDb::containsTerm( lem, delta ) ) || 
-        ( !inf.isNull() && TermDb::containsTerm( lem, inf ) ) ){
+    if( d_qe->getTermDatabase()->containsVtsTerm( lem ) ){
       Trace("cegqi-engine-debug") << "Rewrite based on vts symbols..." << std::endl;
       lem = d_qe->getTermDatabase()->rewriteVtsSymbols( lem );
     }
index dcbb79a35a1adcde653e97ad0def2479ac117d26..9d49f3d724b198f205283609fd91ab2adc1c4c43 100644 (file)
@@ -34,7 +34,6 @@ using namespace CVC4::theory::datatypes;
 //#define MBP_STRICT_ASSERTIONS
 
 
-
 CegInstantiator::CegInstantiator( QuantifiersEngine * qe, CegqiOutput * out, bool use_vts_delta, bool use_vts_inf ) :
 d_qe( qe ), d_out( out ), d_use_vts_delta( use_vts_delta ), d_use_vts_inf( use_vts_inf ){
   d_zero = NodeManager::currentNM()->mkConst( Rational( 0 ) );
@@ -70,23 +69,25 @@ bool CegInstantiator::addInstantiation( std::vector< Node >& subs, std::vector<
   if( i==d_vars.size() ){
     return addInstantiationCoeff( subs, vars, coeff, has_coeff, 0 );
   }else{
-    eq::EqualityEngine* ee = d_qe->getMasterEqualityEngine();
     std::map< Node, std::map< Node, bool > > subs_proc;
     //Node v = d_single_inv_map_to_prog[d_single_inv[0][i]];
     Node pv = d_vars[i];
     TypeNode pvtn = pv.getType();
+    Trace("cbqi-inst-debug") << "[Find instantiation for " << pv << "]" << std::endl;
 
     //if in effort=2, we must choose at least one model value
     if( (i+1)<d_vars.size() || effort!=2 ){
+      
       //[1] easy case : pv is in the equivalence class as another term not containing pv
-      if( ee->hasTerm( pv ) ){
-        //std::vector< Node > eqc_sk;
-        Node pvr = ee->getRepresentative( pv );
-        eq::EqClassIterator eqc_i = eq::EqClassIterator( pvr, ee );
-        while( !eqc_i.isFinished() ){
-          Node n = *eqc_i;
+      Trace("cbqi-inst-debug") << "[1] try based on equivalence class." << std::endl;
+      std::map< Node, Node >::iterator itr = d_curr_rep.find( pv );
+      if( itr!=d_curr_rep.end() ){
+        std::map< Node, std::vector< Node > >::iterator it_eqc = d_curr_eqc.find( itr->second );
+        Assert( it_eqc!=d_curr_eqc.end() );
+        for( unsigned k=0; k<it_eqc->second.size(); k++ ){
+          Node n = it_eqc->second[k];
           if( n!=pv ){
-            Trace("cbqi-inst-debug") << "[1] " << i << "...try based on equal term " << n << std::endl;
+            Trace("cbqi-inst-debug") << "..[1] " << i << "...try based on equal term " << n << std::endl;
             //compute d_subs_fv, which program variables are contained in n
             computeProgVars( n );
             //must be an eligible term
@@ -113,117 +114,110 @@ bool CegInstantiator::addInstantiation( std::vector< Node >& subs, std::vector<
                 }
               }
             }
-            //record this as skolem
-            //if( n.getKind()==SKOLEM ){
-            //  eqc_sk.push_back( n );
-            //}
           }
-          ++eqc_i;
         }
       }
 
       //[2] : we can solve an equality for pv
       ///iterate over equivalence classes to find cases where we can solve for the variable
-      Node vts_inf = d_qe->getTermDatabase()->getVtsInfinity( false, false );
       if( pvtn.isInteger() || pvtn.isReal() ){
-        eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( ee );
-        while( !eqcs_i.isFinished() ){
-          Node r = *eqcs_i;
-          TypeNode rtn = r.getType();
-          if( rtn.isInteger() || rtn.isReal() ){
-            std::vector< Node > lhs;
-            std::vector< bool > lhs_v;
-            std::vector< Node > lhs_coeff;
-            eq::EqClassIterator eqc_i = eq::EqClassIterator( r, ee );
-            while( !eqc_i.isFinished() ){
-              Node n = *eqc_i;
-              computeProgVars( n );
-              //must be an eligible term
-              if( d_inelig.find( n )==d_inelig.end() ){
-                Node ns;
-                Node pv_coeff;
-                if( !d_prog_var[n].empty() ){
-                  ns = applySubstitution( n, subs, vars, coeff, has_coeff, pv_coeff );
-                  if( !ns.isNull() ){
-                    computeProgVars( ns );
-                  }
-                }else{
-                  ns = n;
-                }
+        Trace("cbqi-inst-debug") << "[2] try based on solving in arithmetic equivalence class." << std::endl;
+        for( unsigned k=0; k<d_curr_arith_eqc.size(); k++ ){
+          Node r = d_curr_arith_eqc[k];
+          std::vector< Node > lhs;
+          std::vector< bool > lhs_v;
+          std::vector< Node > lhs_coeff;
+          std::map< Node, std::vector< Node > >::iterator it_eqc = d_curr_eqc.find( r );
+          Assert( it_eqc!=d_curr_eqc.end() );
+          for( unsigned kk=0; kk<it_eqc->second.size(); kk++ ){
+            Node n = it_eqc->second[kk];
+            Trace("cbqi-inst-debug2") << "...look at term " << n << std::endl;
+            //compute the variables in n
+            computeProgVars( n );
+            //must be an eligible term
+            if( d_inelig.find( n )==d_inelig.end() ){
+              Node ns;
+              Node pv_coeff;
+              if( !d_prog_var[n].empty() ){
+                ns = applySubstitution( n, subs, vars, coeff, has_coeff, pv_coeff );
                 if( !ns.isNull() ){
-                  bool hasVar = d_prog_var[ns].find( pv )!=d_prog_var[ns].end();
-                  for( unsigned j=0; j<lhs.size(); j++ ){
-                    //if this term or the another has pv in it, try to solve for it
-                    if( hasVar || lhs_v[j] ){
-                      Trace("cbqi-inst-debug") << "[2] " << i << "...try based on equality " << lhs[j] << " = " << ns << std::endl;
-                      Node eq_lhs = lhs[j];
-                      Node eq_rhs = ns;
-                      //make the same coefficient
-                      if( pv_coeff!=lhs_coeff[j] ){
-                        if( !pv_coeff.isNull() ){
-                          Trace("cbqi-inst-debug") << "...mult lhs by " << pv_coeff << std::endl;
-                          eq_lhs = NodeManager::currentNM()->mkNode( MULT, pv_coeff, eq_lhs );
-                          eq_lhs = Rewriter::rewrite( eq_lhs );
-                        }
-                        if( !lhs_coeff[j].isNull() ){
-                          Trace("cbqi-inst-debug") << "...mult rhs by " << lhs_coeff[j] << std::endl;
-                          eq_rhs = NodeManager::currentNM()->mkNode( MULT, lhs_coeff[j], eq_rhs );
-                          eq_rhs = Rewriter::rewrite( eq_rhs );
-                        }
+                  computeProgVars( ns );
+                }
+              }else{
+                ns = n;
+              }
+              if( !ns.isNull() ){
+                bool hasVar = d_prog_var[ns].find( pv )!=d_prog_var[ns].end();
+                for( unsigned j=0; j<lhs.size(); j++ ){
+                  //if this term or the another has pv in it, try to solve for it
+                  if( hasVar || lhs_v[j] ){
+                    Trace("cbqi-inst-debug") << "..[2] " << i << "...try based on equality " << lhs[j] << " = " << ns << std::endl;
+                    Node eq_lhs = lhs[j];
+                    Node eq_rhs = ns;
+                    //make the same coefficient
+                    if( pv_coeff!=lhs_coeff[j] ){
+                      if( !pv_coeff.isNull() ){
+                        Trace("cbqi-inst-debug") << "...mult lhs by " << pv_coeff << std::endl;
+                        eq_lhs = NodeManager::currentNM()->mkNode( MULT, pv_coeff, eq_lhs );
+                        eq_lhs = Rewriter::rewrite( eq_lhs );
                       }
-                      Node eq = eq_lhs.eqNode( eq_rhs );
-                      eq = Rewriter::rewrite( eq );
-                      //cannot contain infinity
-                      if( vts_inf.isNull() || !TermDb::containsTerm( eq, vts_inf ) ){
-                        Trace("cbqi-inst-debug") << "...equality is " << eq << std::endl;
-                        std::map< Node, Node > msum;
-                        if( QuantArith::getMonomialSumLit( eq, msum ) ){
-                          if( Trace.isOn("cbqi-inst-debug") ){
-                            Trace("cbqi-inst-debug") << "...got monomial sum: " << std::endl;
-                            QuantArith::debugPrintMonomialSum( msum, "cbqi-inst-debug" );
-                            Trace("cbqi-inst-debug") << "isolate for " << pv << "..." << std::endl;
-                          }
-                          Node veq;
-                          if( QuantArith::isolate( pv, msum, veq, EQUAL, true )!=0 ){
-                            Trace("cbqi-inst-debug") << "...isolated equality " << veq << "." << std::endl;
-                            Node veq_c;
-                            if( veq[0]!=pv ){
-                              Node veq_v;
-                              if( QuantArith::getMonomial( veq[0], veq_c, veq_v ) ){
-                                Assert( veq_v==pv );
-                              }
-                            }
-                            Node val = veq[1];
-                            //eliminate coefficient if real
-                            if( !pvtn.isInteger() && !veq_c.isNull() ){
-                              val = NodeManager::currentNM()->mkNode( MULT, val, NodeManager::currentNM()->mkConst( Rational(1) / veq_c.getConst<Rational>() ) );
-                              val = Rewriter::rewrite( val );
-                              veq_c = Node::null();
+                      if( !lhs_coeff[j].isNull() ){
+                        Trace("cbqi-inst-debug") << "...mult rhs by " << lhs_coeff[j] << std::endl;
+                        eq_rhs = NodeManager::currentNM()->mkNode( MULT, lhs_coeff[j], eq_rhs );
+                        eq_rhs = Rewriter::rewrite( eq_rhs );
+                      }
+                    }
+                    Node eq = eq_lhs.eqNode( eq_rhs );
+                    eq = Rewriter::rewrite( eq );
+                    //cannot contain infinity
+                    if( !d_qe->getTermDatabase()->containsVtsInfinity( eq ) ){
+                      Trace("cbqi-inst-debug") << "...equality is " << eq << std::endl;
+                      std::map< Node, Node > msum;
+                      if( QuantArith::getMonomialSumLit( eq, msum ) ){
+                        if( Trace.isOn("cbqi-inst-debug") ){
+                          Trace("cbqi-inst-debug") << "...got monomial sum: " << std::endl;
+                          QuantArith::debugPrintMonomialSum( msum, "cbqi-inst-debug" );
+                          Trace("cbqi-inst-debug") << "isolate for " << pv << "..." << std::endl;
+                        }
+                        Node veq;
+                        if( QuantArith::isolate( pv, msum, veq, EQUAL, true )!=0 ){
+                          Trace("cbqi-inst-debug") << "...isolated equality " << veq << "." << std::endl;
+                          Node veq_c;
+                          if( veq[0]!=pv ){
+                            Node veq_v;
+                            if( QuantArith::getMonomial( veq[0], veq_c, veq_v ) ){
+                              Assert( veq_v==pv );
                             }
-                            if( subs_proc[val].find( veq_c )==subs_proc[val].end() ){
-                              subs_proc[val][veq_c] = true;
-                              if( addInstantiationInc( val, pv, veq_c, subs, vars, coeff, has_coeff, theta, i, effort ) ){
-                                return true;
-                              }
+                          }
+                          Node val = veq[1];
+                          //eliminate coefficient if real
+                          if( !pvtn.isInteger() && !veq_c.isNull() ){
+                            val = NodeManager::currentNM()->mkNode( MULT, val, NodeManager::currentNM()->mkConst( Rational(1) / veq_c.getConst<Rational>() ) );
+                            val = Rewriter::rewrite( val );
+                            veq_c = Node::null();
+                          }
+                          if( subs_proc[val].find( veq_c )==subs_proc[val].end() ){
+                            subs_proc[val][veq_c] = true;
+                            if( addInstantiationInc( val, pv, veq_c, subs, vars, coeff, has_coeff, theta, i, effort ) ){
+                              return true;
                             }
                           }
                         }
                       }
                     }
                   }
-                  lhs.push_back( ns );
-                  lhs_v.push_back( hasVar );
-                  lhs_coeff.push_back( pv_coeff );
                 }
+                lhs.push_back( ns );
+                lhs_v.push_back( hasVar );
+                lhs_coeff.push_back( pv_coeff );
               }
-              ++eqc_i;
             }
           }
-          ++eqcs_i;
         }
       }
 
       //[3] directly look at assertions
+      Trace("cbqi-inst-debug") << "[3] try based on assertions." << std::endl;
       std::vector< Node > mbp_bounds[2];
       std::vector< Node > mbp_coeff[2];
       std::vector< bool > mbp_strict[2];
@@ -231,13 +225,11 @@ bool CegInstantiator::addInstantiation( std::vector< Node >& subs, std::vector<
       unsigned rmax = Theory::theoryOf( pv )==Theory::theoryOf( pv.getType() ) ? 1 : 2;
       for( unsigned r=0; r<rmax; r++ ){
         TheoryId tid = r==0 ? Theory::theoryOf( pv ) : Theory::theoryOf( pv.getType() );
-        Theory* theory = d_qe->getTheoryEngine()->theoryOf( tid );
-        Trace("cbqi-inst-debug2") << "Theory of " << pv << " (r=" << r << ") is " << tid << std::endl;
-        if (theory && d_qe->getTheoryEngine()->isTheoryEnabled(tid)) {
-          Trace("cbqi-inst-debug2") << "Look at assertions of " << tid << std::endl;
-          context::CDList<Assertion>::const_iterator it = theory->facts_begin(), it_end = theory->facts_end();
-          for (unsigned j = 0; it != it_end; ++ it, ++j) {
-            Node lit = (*it).assertion;
+        Trace("cbqi-inst-debug2") << "  look at assertions of " << tid << std::endl;
+        std::map< TheoryId, std::vector< Node > >::iterator ita = d_curr_asserts.find( tid );
+        if( ita!=d_curr_asserts.end() ){
+          for (unsigned j = 0; j<ita->second.size(); j++) {
+            Node lit = ita->second[j];
             Trace("cbqi-inst-debug2") << "  look at " << lit << std::endl;
             Node atom = lit.getKind()==NOT ? lit[0] : lit;
             bool pol = lit.getKind()!=NOT;
@@ -273,9 +265,9 @@ bool CegInstantiator::addInstantiation( std::vector< Node >& subs, std::vector<
                 if( !atom_lhs.isNull() && d_prog_var[atom_lhs].find( pv )!=d_prog_var[atom_lhs].end() ){
                   Node satom = NodeManager::currentNM()->mkNode( atom.getKind(), atom_lhs, atom_rhs );
                   //cannot contain infinity
-                  if( vts_inf.isNull() || !TermDb::containsTerm( atom_lhs, vts_inf ) ){
-                    Trace("cbqi-inst-debug") << "[3] From assertion : " << atom << ", pol = " << pol << std::endl;
-                    Trace("cbqi-inst-debug") << "       substituted : " << satom << ", pol = " << pol << std::endl;
+                  if( !d_qe->getTermDatabase()->containsVtsInfinity( atom_lhs ) ){
+                    Trace("cbqi-inst-debug") << "..[3] From assertion : " << atom << ", pol = " << pol << std::endl;
+                    Trace("cbqi-inst-debug") << "         substituted : " << satom << ", pol = " << pol << std::endl;
                     std::map< Node, Node > msum;
                     if( QuantArith::getMonomialSumLit( satom, msum ) ){
                       if( Trace.isOn("cbqi-inst-debug") ){
@@ -385,11 +377,10 @@ bool CegInstantiator::addInstantiation( std::vector< Node >& subs, std::vector<
           for( unsigned r=0; r<2; r++ ){
             int rr = upper_first ? (1-r) : r;
             if( mbp_bounds[rr].empty() ){
-              /*
               if( d_use_vts_inf ){
                 Trace("cbqi-bound") << "No " << ( rr==0 ? "lower" : "upper" ) << " bounds for " << pv << " (type=" << pvtn << ")" << std::endl;
                 //no bounds, we do +- infinity
-                Node val = d_qe->getTermDatabase()->getVtsInfinity();
+                Node val = d_qe->getTermDatabase()->getVtsInfinity( pvtn );
                 if( rr==0 ){
                   val = NodeManager::currentNM()->mkNode( UMINUS, val );
                   val = Rewriter::rewrite( val );
@@ -398,7 +389,6 @@ bool CegInstantiator::addInstantiation( std::vector< Node >& subs, std::vector<
                   return true;
                 }
               }
-              */
             }else{
               Trace("cbqi-bound") << ( rr==0 ? "Lower" : "Upper" ) << " bounds for " << pv << " (type=" << pvtn << ") : " << std::endl;
               int best = -1;
@@ -469,12 +459,14 @@ bool CegInstantiator::addInstantiation( std::vector< Node >& subs, std::vector<
       int new_effort = pvtn.isBoolean() ? effort : 1;
 #ifdef MBP_STRICT_ASSERTIONS
       //we only resort to values in the case of booleans
-      Assert( pvtn.isBoolean() );
+      Assert( !options::cbqiUseInf() || pvtn.isBoolean() );
 #endif
-      return addInstantiationInc( mv, pv, pv_coeff_m, subs, vars, coeff, has_coeff, theta, i, new_effort );
-    }else{
-      return false;
+      if( addInstantiationInc( mv, pv, pv_coeff_m, subs, vars, coeff, has_coeff, theta, i, new_effort ) ){
+        return true;
+      }
     }
+    Trace("cbqi-inst-debug") << "[No instantiation found for " << pv << "]" << std::endl;
+    return false;
   }
 }
 
@@ -637,22 +629,6 @@ bool CegInstantiator::addInstantiationCoeff( std::vector< Node >& subs, std::vec
         if( addInstantiationCoeff( subs, vars, coeff, has_coeff, j+1 ) ){
           return true;
         }
-          //equalities are both upper and lower bounds
-          /*
-          if( subs_typ[index]==0 && !veq_c.isNull() ){
-            subs[index] = NodeManager::currentNM()->mkNode( PLUS, subs[index],
-              NodeManager::currentNM()->mkNode( ITE,
-                NodeManager::currentNM()->mkNode( EQUAL,
-                  NodeManager::currentNM()->mkNode( INTS_MODULUS, veq[1], veq_c ),
-                  NodeManager::currentNM()->mkConst( Rational( 0 ) ) ),
-                NodeManager::currentNM()->mkConst( Rational( 0 ) ),
-                NodeManager::currentNM()->mkConst( Rational( 1 ) ) )
-            );
-            if( addInstantiationCoeff( subs, vars, coeff, has_coeff, j+1 ) ){
-              return true;
-            }
-          }
-          */
       }
     }
     subs[index] = prev;
@@ -809,6 +785,7 @@ bool CegInstantiator::check() {
     Trace("cegqi-engine") << "  CEGQI instantiator : wait until all ground theories are finished." << std::endl;
     return false;
   }
+  processAssertions();
   for( unsigned r=0; r<2; r++ ){
     std::vector< Node > subs;
     std::vector< Node > vars;
@@ -824,6 +801,225 @@ bool CegInstantiator::check() {
   return false;
 }
 
+void setAuxRep( std::map< Node, Node >& aux_uf, std::map< Node, Node >& aux_subs, Node n1, Node n2 ){
+  Assert( aux_uf.find( n1 )==aux_uf.end() );
+  Assert( aux_uf.find( n2 )==aux_uf.end() );
+  //only merge if not in substitution
+  if( aux_subs.find( n1 )==aux_subs.end() ){
+    aux_uf[n1] = n2;
+  }else if( aux_subs.find( n2 )==aux_subs.end() ){
+    aux_uf[n2] = n1;
+  }
+}
+
+Node getAuxRep( std::map< Node, Node >& aux_uf, Node n ){
+  std::map< Node, Node >::iterator it = aux_uf.find( n );
+  if( it!=aux_uf.end() ){
+    Node r = getAuxRep( aux_uf, it->second );
+    aux_uf[n] = r;
+    return r;
+  }else{
+    return n;
+  }
+}
+
+void CegInstantiator::processAssertions() {
+  Trace("cbqi-proc") << "--- Process assertions, #var = " << d_vars.size() << ", #aux-var = " << d_aux_vars.size() << std::endl;
+  d_curr_asserts.clear();
+  d_curr_eqc.clear();
+  d_curr_rep.clear();
+  d_curr_arith_eqc.clear();
+  
+  eq::EqualityEngine* ee = d_qe->getMasterEqualityEngine();
+  //to eliminate identified illegal terms
+  std::map< Node, Node > aux_uf;
+  std::map< Node, Node > aux_subs;
+  std::map< Node, bool > aux_subs_inelig;
+  
+  //for each variable
+  bool has_arith_var = false;
+  for( unsigned i=0; i<d_vars.size(); i++ ){
+    Node pv = d_vars[i];
+    TypeNode pvtn = pv.getType();
+    //collect current assertions
+    unsigned rmax = Theory::theoryOf( pv )==Theory::theoryOf( pv.getType() ) ? 1 : 2;
+    for( unsigned r=0; r<rmax; r++ ){
+      TheoryId tid = r==0 ? Theory::theoryOf( pv ) : Theory::theoryOf( pv.getType() );
+      Theory* theory = d_qe->getTheoryEngine()->theoryOf( tid );
+      Trace("cbqi-proc-debug") << "...theory of " << pv << " (r=" << r << ") is " << tid << std::endl;
+      if( d_curr_asserts.find( tid )==d_curr_asserts.end() ){
+        if (theory && d_qe->getTheoryEngine()->isTheoryEnabled(tid)) {
+          Trace("cbqi-proc") << "Collect assertions from " << tid << std::endl;
+          d_curr_asserts[tid].clear();
+          //collect all assertions from theory
+          for( context::CDList<Assertion>::const_iterator it = theory->facts_begin(); it != theory->facts_end(); ++ it) {
+            Node lit = (*it).assertion;
+            d_curr_asserts[tid].push_back( lit );
+            Trace("cbqi-proc-debug") << "...add : " << lit << std::endl;
+            if( lit.getKind()==EQUAL ){
+              //check if it is an auxiliary variable (for instance, from ITE removal).  If so, solve for it.
+              for( unsigned k=0; k<2; k++ ){
+                Node s = lit[k];
+                if( std::find( d_aux_vars.begin(), d_aux_vars.end(), s )!=d_aux_vars.end() ){
+                  Node sr = getAuxRep( aux_uf, s );
+                  if( std::find( d_aux_vars.begin(), d_aux_vars.end(), lit[1-k] )!=d_aux_vars.end() ){
+                    Node ssr = getAuxRep( aux_uf, lit[1-k] );
+                    //merge in the union find
+                    if( sr!=ssr ){
+                      Trace("cbqi-proc") << "...merge : " << sr << " = " << ssr << std::endl;
+                      setAuxRep( aux_uf, aux_subs, sr, ssr );
+                    }
+                  //if we don't have yet a substitution yet or the substitution is ineligible
+                  }else if( aux_subs.find( sr )==aux_subs.end() || aux_subs_inelig[sr] ){
+                    computeProgVars( lit[1-k] );
+                    bool isInelig = d_inelig.find( lit[1-k] )!=d_inelig.end();
+                    //equality for auxiliary variable : will add to substitution
+                    Trace("cbqi-proc") << "...add to substitution : " << sr << " -> " << lit[1-k] << std::endl;
+                    aux_subs[sr] = lit[1-k];
+                    aux_subs_inelig[sr] = isInelig;
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+    //collect information about eqc
+    if( ee->hasTerm( pv ) ){
+      Node pvr = ee->getRepresentative( pv );
+      d_curr_rep[pv] = pvr;
+      if( d_curr_eqc.find( pvr )==d_curr_eqc.end() ){
+        Trace("cbqi-proc") << "Collect equivalence class " << pvr << std::endl;
+        eq::EqClassIterator eqc_i = eq::EqClassIterator( pvr, ee );
+        while( !eqc_i.isFinished() ){
+          d_curr_eqc[pvr].push_back( *eqc_i );
+          ++eqc_i;
+        }
+      }
+    }
+    //has arith var
+    if( pvtn.isInteger() || pvtn.isReal() ){
+      has_arith_var = true;
+    }
+  }
+  //must process all arithmetic eqc if any arithmetic variable
+  if( has_arith_var ){
+    Trace("cbqi-proc-debug") << "...collect arithmetic equivalence classes" << std::endl;
+    eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( ee );
+    while( !eqcs_i.isFinished() ){
+      Node r = *eqcs_i;
+      TypeNode rtn = r.getType();
+      if( rtn.isInteger() || rtn.isReal() ){  
+        Trace("cbqi-proc-debug") << "...arith eqc: " << r << std::endl;
+        d_curr_arith_eqc.push_back( r );
+        if( d_curr_eqc.find( r )==d_curr_eqc.end() ){
+          Trace("cbqi-proc") << "Collect equivalence class " << r << std::endl;
+          eq::EqClassIterator eqc_i = eq::EqClassIterator( r, ee );
+          while( !eqc_i.isFinished() ){
+            d_curr_eqc[r].push_back( *eqc_i );
+            ++eqc_i;
+          }
+        }
+      }
+      ++eqcs_i;
+    }
+  }
+  //construct substitution from union find
+  std::vector< Node > subs_lhs;
+  std::vector< Node > subs_rhs;
+  for( unsigned i=0; i<d_aux_vars.size(); i++ ){
+    Node l = d_aux_vars[i];
+    Node r = getAuxRep( aux_uf, l );
+    std::map< Node, Node >::iterator it = aux_subs.find( r );
+    if( it!=aux_subs.end() ){
+      addToAuxVarSubstitution( subs_lhs, subs_rhs, l, it->second );
+    }else{
+#ifdef MBP_STRICT_ASSERTIONS
+      Assert( false );
+#endif
+      Trace("cbqi-proc") << "....no substitution found for auxiliary variable " << l << "!!!" << std::endl;
+    }
+  }  
+
+  
+  //apply substitutions to everything, if necessary
+  if( !subs_lhs.empty() ){
+    Trace("cbqi-proc") << "Applying substitution : " << std::endl;
+    for( unsigned i=0; i<subs_lhs.size(); i++ ){
+      Trace("cbqi-proc") << "  " << subs_lhs[i] << " -> " << subs_rhs[i] << std::endl;
+    }
+    
+    for( std::map< TheoryId, std::vector< Node > >::iterator it = d_curr_asserts.begin(); it != d_curr_asserts.end(); ++it ){
+      for( unsigned i=0; i<it->second.size(); i++ ){
+        Node lit = it->second[i];
+        lit = lit.substitute( subs_lhs.begin(), subs_lhs.end(), subs_rhs.begin(), subs_rhs.end() );
+        lit = Rewriter::rewrite( lit );
+        it->second[i] = lit;
+      }
+    }
+    for( std::map< Node, std::vector< Node > >::iterator it = d_curr_eqc.begin(); it != d_curr_eqc.end(); ++it ){
+      for( unsigned i=0; i<it->second.size(); i++ ){
+        Node n = it->second[i];
+        n = n.substitute( subs_lhs.begin(), subs_lhs.end(), subs_rhs.begin(), subs_rhs.end() );
+        n = Rewriter::rewrite( n  );
+        it->second[i] = n;
+      }
+    }
+  }
+
+  //remove unecessary assertions
+  for( std::map< TheoryId, std::vector< Node > >::iterator it = d_curr_asserts.begin(); it != d_curr_asserts.end(); ++it ){
+    std::vector< Node > akeep;
+    for( unsigned i=0; i<it->second.size(); i++ ){
+      Node n = it->second[i];
+      //compute the variables in assertion
+      computeProgVars( n );
+      //must be an eligible term
+      if( d_inelig.find( n )==d_inelig.end() ){
+        //must contain at least one variable
+        if( !d_prog_var[n].empty() ){
+          akeep.push_back( n );
+        }else{
+          Trace("cbqi-proc") << "...remove literal from " << it->first << " : " << n << " since it contains no relevant variables." << std::endl;
+        }
+      }else{
+        Trace("cbqi-proc") << "...remove literal from " << it->first << " : " << n << " since it contains ineligible terms." << std::endl;
+      }
+    }
+    it->second.clear();
+    it->second.insert( it->second.end(), akeep.begin(), akeep.end() );
+  }
+  
+  //remove duplicate terms from eqc
+  for( std::map< Node, std::vector< Node > >::iterator it = d_curr_eqc.begin(); it != d_curr_eqc.end(); ++it ){
+    std::vector< Node > new_eqc;
+    for( unsigned i=0; i<it->second.size(); i++ ){
+      if( std::find( new_eqc.begin(), new_eqc.end(), it->second[i] )==new_eqc.end() ){
+        new_eqc.push_back( it->second[i] );
+      }
+    }
+    it->second.clear();
+    it->second.insert( it->second.end(), new_eqc.begin(), new_eqc.end() );
+  }
+}
+
+void CegInstantiator::addToAuxVarSubstitution( std::vector< Node >& subs_lhs, std::vector< Node >& subs_rhs, Node l, Node r ) {
+  r = r.substitute( subs_lhs.begin(), subs_lhs.end(), subs_rhs.begin(), subs_rhs.end() );
+  
+  std::vector< Node > cl;
+  cl.push_back( l );
+  std::vector< Node > cr;
+  cr.push_back( r );
+  for( unsigned i=0; i<subs_lhs.size(); i++ ){
+    Node nr = subs_rhs[i].substitute( cl.begin(), cl.end(), cr.begin(), cr.end() );
+    nr = Rewriter::rewrite( nr );
+    subs_rhs[i] = nr;
+  }
+  
+  subs_lhs.push_back( l );
+  subs_rhs.push_back( r );
+}
 
 //old implementation
 
@@ -1176,10 +1372,14 @@ int InstStrategyCegqi::process( Node f, Theory::Effort effort, int e ) {
     CegInstantiator * cinst;
     std::map< Node, CegInstantiator * >::iterator it = d_cinst.find( f );
     if( it==d_cinst.end() ){
-      cinst = new CegInstantiator( d_quantEngine, d_out );
+      cinst = new CegInstantiator( d_quantEngine, d_out, true, options::cbqiUseInf() );
       for( int i=0; i<d_quantEngine->getTermDatabase()->getNumInstantiationConstants( f ); i++ ){
         cinst->d_vars.push_back( d_quantEngine->getTermDatabase()->getInstantiationConstant( f, i ) );
       }
+      std::map< Node, std::vector< Node > >::iterator itav = d_aux_variables.find( f );
+      if( itav!=d_aux_variables.end() ){
+        cinst->d_aux_vars.insert( cinst->d_aux_vars.begin(), itav->second.begin(), itav->second.end() );
+      }
       d_cinst[f] = cinst;
     }else{
       cinst = it->second;
@@ -1192,23 +1392,22 @@ int InstStrategyCegqi::process( Node f, Theory::Effort effort, int e ) {
   }else if( e==2 ){
     //minimize the free delta heuristically on demand
     if( d_check_vts_lemma_lc ){
+      d_check_vts_lemma_lc = false;
+      d_small_const = NodeManager::currentNM()->mkNode( MULT, d_small_const, d_small_const );
+      d_small_const = Rewriter::rewrite( d_small_const );
+      //heuristic for now, until we know how to do nested quantification
       Node delta = d_quantEngine->getTermDatabase()->getVtsDelta( true, false );
-      Node inf = d_quantEngine->getTermDatabase()->getVtsInfinity( true, false );
-      if( !delta.isNull() || !inf.isNull() ){
-        d_check_vts_lemma_lc = false;
-        d_small_const = NodeManager::currentNM()->mkNode( MULT, d_small_const, d_small_const );
-        d_small_const = Rewriter::rewrite( d_small_const );
-        //heuristic for now, until we know how to do nested quantification
-        if( !delta.isNull() ){
-          Trace("cegqi") << "Delta lemma for " << d_small_const << std::endl;
-          Node delta_lem_ub = NodeManager::currentNM()->mkNode( LT, delta, d_small_const );
-          d_quantEngine->getOutputChannel().lemma( delta_lem_ub );
-        }
-        if( !inf.isNull() ){
-          Trace("cegqi") << "Infinity lemma for " << d_small_const << std::endl;
-          Node inf_lem_lb = NodeManager::currentNM()->mkNode( GT, inf, NodeManager::currentNM()->mkConst( Rational(1)/d_small_const.getConst<Rational>() ) );
-          d_quantEngine->getOutputChannel().lemma( inf_lem_lb );
-        }
+      if( !delta.isNull() ){
+        Trace("cegqi") << "Delta lemma for " << d_small_const << std::endl;
+        Node delta_lem_ub = NodeManager::currentNM()->mkNode( LT, delta, d_small_const );
+        d_quantEngine->getOutputChannel().lemma( delta_lem_ub );
+      }
+      std::vector< Node > inf;
+      d_quantEngine->getTermDatabase()->getVtsTerms( inf, true, false, false );
+      for( unsigned i=0; i<inf.size(); i++ ){
+        Trace("cegqi") << "Infinity lemma for " << inf[i] << " " << d_small_const << std::endl;
+        Node inf_lem_lb = NodeManager::currentNM()->mkNode( GT, inf[i], NodeManager::currentNM()->mkConst( Rational(1)/d_small_const.getConst<Rational>() ) );
+        d_quantEngine->getOutputChannel().lemma( inf_lem_lb );
       }
     }
   }
@@ -1217,31 +1416,8 @@ int InstStrategyCegqi::process( Node f, Theory::Effort effort, int e ) {
 
 bool InstStrategyCegqi::addInstantiation( std::vector< Node >& subs ) {
   Assert( !d_curr_quant.isNull() );
-  /*
-  std::stringstream siss;
-  if( Trace.isOn("inst-cegqi") || Trace.isOn("inst-cegqi-debug") ){
-    for( unsigned j=0; j<d_single_inv_sk.size(); j++ ){
-      Node v = d_single_inv_map_to_prog[d_single_inv[0][j]];
-      siss << "    * " << v;
-      siss << " (" << d_single_inv_sk[j] << ")";
-      siss << " -> " << ( subs_typ[j]==9 ? "M:" : "") << subs[j] << std::endl;
-    }
-  }
-  */
-  //check if we need virtual term substitution (if used delta)
-  bool used_vts = false;
-  Node delta = d_quantEngine->getTermDatabase()->getVtsDelta( false, false );
-  Node inf = d_quantEngine->getTermDatabase()->getVtsInfinity( false, false );
-  if( !delta.isNull() || !inf.isNull() ){
-    for( unsigned i=0; i<subs.size(); i++ ){
-      if( !delta.isNull() && TermDb::containsTerm( subs[i], delta ) ){
-        used_vts = true;
-      }
-      if( !inf.isNull() && TermDb::containsTerm( subs[i], inf ) ){
-        used_vts = true;
-      }
-    }
-  }
+  //check if we need virtual term substitution (if used delta or infinity)
+  bool used_vts = d_quantEngine->getTermDatabase()->containsVtsTerm( subs, false );
   return d_quantEngine->addInstantiation( d_curr_quant, subs, false, false, false, used_vts );
 }
 
@@ -1258,7 +1434,16 @@ bool InstStrategyCegqi::isEligibleForInstantiation( Node n ) {
   }
 }
 
-
+void InstStrategyCegqi::setAuxiliaryVariables( Node q, std::vector< Node >& vars ) {
+  std::map< Node, CegInstantiator * >::iterator it = d_cinst.find( q );
+  if( it!=d_cinst.end() ){
+    Assert( it->second->d_aux_vars.empty() );
+    it->second->d_aux_vars.insert( it->second->d_aux_vars.end(), vars.begin(), vars.end() );
+  }else{
+    Assert( d_aux_variables.find( q )==d_aux_variables.end() );
+    d_aux_variables[q].insert( d_aux_variables[q].end(), vars.begin(), vars.end() );
+  }
+}
 
 
 
index 4f5049cd80b31a34fe316ee5bc61933a5168a6c7..64303e1f3e70cf32e7baa457efab944bd93f4686 100644 (file)
@@ -59,6 +59,11 @@ private:
   //program variable contains cache
   std::map< Node, std::map< Node, bool > > d_prog_var;
   std::map< Node, bool > d_inelig;
+  //current assertions
+  std::map< TheoryId, std::vector< Node > > d_curr_asserts;
+  std::map< Node, std::vector< Node > > d_curr_eqc;
+  std::map< Node, Node > d_curr_rep;
+  std::vector< Node > d_curr_arith_eqc;
 private:
   //for adding instantiations during check
   void computeProgVars( Node n );
@@ -75,10 +80,14 @@ private:
   Node applySubstitution( Node n, std::vector< Node >& subs, std::vector< Node >& vars,
                           std::vector< Node >& coeff, std::vector< Node >& has_coeff, Node& pv_coeff, bool try_coeff = true );
   Node getModelBasedProjectionValue( Node t, bool strict, bool isLower, Node c, Node me, Node mt, Node theta );
+  void processAssertions();
+  void addToAuxVarSubstitution( std::vector< Node >& subs_lhs, std::vector< Node >& subs_rhs, Node l, Node r );
 public:
   CegInstantiator( QuantifiersEngine * qe, CegqiOutput * out, bool use_vts_delta = true, bool use_vts_inf = true );
   //the CE variables
   std::vector< Node > d_vars;
+  //auxiliary variables
+  std::vector< Node > d_aux_vars;
   //check : add instantiations based on valuation of d_vars
   bool check();
 };
@@ -144,6 +153,7 @@ class InstStrategyCegqi : public InstStrategy {
 private:
   CegqiOutputInstStrategy * d_out;
   std::map< Node, CegInstantiator * > d_cinst;
+  std::map< Node, std::vector< Node > > d_aux_variables;
   Node d_small_const;
   Node d_curr_quant;
   bool d_check_vts_lemma_lc;
@@ -159,6 +169,9 @@ public:
   bool addLemma( Node lem );
   /** identify */
   std::string identify() const { return std::string("Cegqi"); }
+  
+  //set auxiliary variables
+  void setAuxiliaryVariables( Node q, std::vector< Node >& vars );
 };
 
 }
index b686ddb3bb05f7ae242e4b75e0c3ddfd5a0bd624..b12c822efed8a0cac4f8ea79491d0b3168b70682 100644 (file)
@@ -21,6 +21,7 @@
 #include "theory/quantifiers/inst_strategy_e_matching.h"
 #include "theory/quantifiers/inst_strategy_cbqi.h"
 #include "theory/quantifiers/trigger.h"
+#include "util/ite_removal.h"
 
 using namespace std;
 using namespace CVC4;
@@ -101,7 +102,26 @@ bool InstantiationEngine::doInstantiationRound( Theory::Effort effort ){
           //add counterexample lemma
           lem = Rewriter::rewrite( lem );
           Trace("cbqi") << "Counterexample lemma : " << lem << std::endl;
-          d_quantEngine->addLemma( lem, false );
+          
+          //must explicitly remove ITEs so that we record dependencies
+          IteSkolemMap iteSkolemMap;
+          std::vector< Node > lems;
+          lems.push_back( lem );
+          d_quantEngine->getTheoryEngine()->getIteRemover()->run(lems, iteSkolemMap);
+          std::vector< Node > aux_vars;
+          for( unsigned i=0; i<lems.size(); i++ ){
+            Trace("cbqi-debug") << "Counterexample lemma (processed) " << i << " : " << lems[i] << std::endl;
+            d_quantEngine->addLemma( lems[i], false );
+          }
+          for(IteSkolemMap::iterator i = iteSkolemMap.begin(); i != iteSkolemMap.end(); ++i) {
+            Trace("cbqi-debug") << "  Auxiliary var (from ITE) : " << i->first << std::endl;
+            aux_vars.push_back( i->first );
+          }
+          //record the auxiliary variables in the inst strategy
+          if( d_i_cegqi ){
+            d_i_cegqi->setAuxiliaryVariables( f, aux_vars );
+          }
+          
           addedLemma = true;
         }
       }
index 48a9fdea2100e9607739627327d758ce8ffc123a..b573ebd6710ca6a3e76d6a2dc2a09a26bdc77905 100644 (file)
@@ -246,6 +246,8 @@ option cbqiSat --cbqi-sat bool :read-write :default true
  answer sat when quantifiers are asserted with counterexample-based quantifier instantiation
 option cbqiModel --cbqi-model bool :read-write :default true
  guide instantiations by model values for counterexample-based quantifier instantiation
+option cbqiUseInf --cbqi-use-inf bool :read-write :default false
+ use infinity for vts in counterexample-based quantifier instantiation
  
 ### local theory extensions options 
 
index eefa4577041224a600ac82b3d9bbaacd37d8363f..db0328f4fef8cac82b93233047cef04dcff1f5ba 100644 (file)
@@ -1294,6 +1294,20 @@ Node TermDb::getCanonicalTerm( TNode n, bool apply_torder ){
   return getCanonicalTerm( n, var_count, subs, apply_torder );
 }
 
+void TermDb::getVtsTerms( std::vector< Node >& t, bool isFree, bool create, bool inc_delta ) {
+  if( inc_delta ){
+    Node delta = getVtsDelta( isFree, create );
+    if( !delta.isNull() ){
+      t.push_back( delta );
+    }
+  }
+  for( unsigned r=0; r<2; r++ ){
+    Node inf = getVtsInfinityIndex( r, isFree, create );
+    if( !inf.isNull() ){
+      t.push_back( inf );
+    }
+  }
+}
 
 Node TermDb::getVtsDelta( bool isFree, bool create ) {
   if( create ){
@@ -1309,30 +1323,62 @@ Node TermDb::getVtsDelta( bool isFree, bool create ) {
   return isFree ? d_vts_delta_free : d_vts_delta;
 }
 
-Node TermDb::getVtsInfinity( bool isFree, bool create ) {
+Node TermDb::getVtsInfinity( TypeNode tn, bool isFree, bool create ) {
   if( create ){
-    if( d_vts_inf_free.isNull() ){
-      d_vts_inf_free = NodeManager::currentNM()->mkSkolem( "inf", NodeManager::currentNM()->integerType(), "free infinity for virtual term substitution" );
+    if( d_vts_inf_free[tn].isNull() ){
+      d_vts_inf_free[tn] = NodeManager::currentNM()->mkSkolem( "inf", tn, "free infinity for virtual term substitution" );
     }
-    if( d_vts_inf.isNull() ){
-      d_vts_inf = NodeManager::currentNM()->mkSkolem( "inf", NodeManager::currentNM()->integerType(), "infinity for virtual term substitution" );
+    if( d_vts_inf[tn].isNull() ){
+      d_vts_inf[tn] = NodeManager::currentNM()->mkSkolem( "inf", tn, "infinity for virtual term substitution" );
     }
   }
-  return isFree ? d_vts_inf_free : d_vts_inf;
+  return isFree ? d_vts_inf_free[tn] : d_vts_inf[tn];
+}
+
+Node TermDb::getVtsInfinityIndex( int i, bool isFree, bool create ) {
+  if( i==0 ){
+    return getVtsInfinity( NodeManager::currentNM()->realType(), isFree, create );
+  }else if( i==1 ){
+    return getVtsInfinity( NodeManager::currentNM()->integerType(), isFree, create );
+  }else{
+    Assert( false );
+    return Node::null();
+  }
 }
 
 Node TermDb::rewriteVtsSymbols( Node n ) {
   if( ( n.getKind()==EQUAL || n.getKind()==GEQ ) ){
     Trace("quant-vts-debug") << "VTS : process " << n << std::endl;
-    bool rew_inf = false;
+    Node rew_vts_inf;
     bool rew_delta = false;
     //rewriting infinity always takes precedence over rewriting delta
-    if( !d_vts_inf.isNull() && containsTerm( n, d_vts_inf ) ){
-      rew_inf = true;
-    }else if( !d_vts_delta.isNull() && containsTerm( n, d_vts_delta ) ){
-      rew_delta = true;
+    for( unsigned r=0; r<2; r++ ){
+      Node inf = getVtsInfinityIndex( r, false, false );
+      if( !inf.isNull() && containsTerm( n, inf ) ){
+        if( rew_vts_inf.isNull() ){
+          rew_vts_inf = inf;
+        }else{
+          //for mixed int/real with multiple infinities
+          Trace("quant-vts-debug") << "Multiple infinities...equate " << inf << " = " << rew_vts_inf << std::endl;
+          std::vector< Node > subs_lhs;
+          subs_lhs.push_back( inf );
+          std::vector< Node > subs_rhs;
+          subs_lhs.push_back( rew_vts_inf );
+          n = n.substitute( subs_lhs.begin(), subs_lhs.end(), subs_rhs.begin(), subs_rhs.end() );
+          n = Rewriter::rewrite( n );
+          //may have cancelled
+          if( !containsTerm( n, rew_vts_inf ) ){
+            rew_vts_inf = Node::null();
+          }
+        }
+      }
+    }
+    if( rew_vts_inf.isNull() ){
+      if( !d_vts_delta.isNull() && containsTerm( n, d_vts_delta ) ){
+        rew_delta = true;
+      }
     }
-    if( rew_inf || rew_delta ){
+    if( !rew_vts_inf.isNull()  || rew_delta ){
       if( n.getKind()==EQUAL ){
         return d_false;
       }else{
@@ -1342,7 +1388,7 @@ Node TermDb::rewriteVtsSymbols( Node n ) {
             Trace("quant-vts-debug") << "VTS got monomial sum : " << std::endl;
             QuantArith::debugPrintMonomialSum( msum, "quant-vts-debug" );
           }
-          Node vts_sym = rew_inf ? d_vts_inf : d_vts_delta;
+          Node vts_sym = !rew_vts_inf.isNull() ? rew_vts_inf : d_vts_delta;
           Assert( !vts_sym.isNull() );
           Node iso_n;
           int res = QuantArith::isolate( vts_sym, msum, iso_n, n.getKind(), true );
@@ -1360,13 +1406,13 @@ Node TermDb::rewriteVtsSymbols( Node n ) {
             }
             Node nlit;
             if( res==1 ){
-              if( rew_inf ){
+              if( !rew_vts_inf.isNull() ){
                 nlit = d_true;
               }else{
                 nlit = NodeManager::currentNM()->mkNode( GEQ, d_zero, slv );
               }
             }else{
-              if( rew_inf ){
+              if( !rew_vts_inf.isNull() ){
                 nlit = d_false;
               }else{
                 nlit = NodeManager::currentNM()->mkNode( GT, slv, d_zero );
@@ -1382,16 +1428,13 @@ Node TermDb::rewriteVtsSymbols( Node n ) {
   }else if( n.getKind()==FORALL ){
     //cannot traverse beneath quantifiers
     std::vector< Node > vars;
+    getVtsTerms( vars, false );
     std::vector< Node > vars_free;
-    if( !d_vts_inf.isNull() ){
-      vars.push_back( d_vts_inf );
-      vars_free.push_back( d_vts_inf_free );
-    }
-    if( !d_vts_delta.isNull() ){
-      vars.push_back( d_vts_delta );
-      vars_free.push_back( d_vts_delta_free );
+    getVtsTerms( vars_free, true );
+    Assert( vars.size()==vars_free.size() );
+    if( !vars.empty() ){
+      n = n.substitute( vars.begin(), vars.end(), vars_free.begin(), vars_free.end() );
     }
-    n = n.substitute( vars.begin(), vars.end(), vars_free.begin(), vars_free.end() );
     return n;
   }else{
     bool childChanged = false;
@@ -1414,19 +1457,77 @@ Node TermDb::rewriteVtsSymbols( Node n ) {
   }
 }
 
-bool TermDb::containsTerm( Node n, Node t ) {
+bool TermDb::containsVtsTerm( Node n, bool isFree ) {
+  std::vector< Node > t;
+  getVtsTerms( t, isFree, false );
+  return containsTerms( n, t );
+}
+
+bool TermDb::containsVtsTerm( std::vector< Node >& n, bool isFree ) {
+  std::vector< Node > t;
+  getVtsTerms( t, isFree, false );
+  if( !t.empty() ){
+    for( unsigned i=0; i<n.size(); i++ ){
+      if( containsTerms( n[i], t ) ){
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool TermDb::containsVtsInfinity( Node n, bool isFree ) {
+  std::vector< Node > t;
+  getVtsTerms( t, isFree, false, false );
+  return containsTerms( n, t );
+}
+
+bool TermDb::containsTerm2( Node n, Node t, std::map< Node, bool >& visited ) {
   if( n==t ){
     return true;
   }else{
-    for( unsigned i=0; i<n.getNumChildren(); i++ ){
-      if( containsTerm( n[i], t ) ){
-        return true;
+    if( visited.find( n )==visited.end() ){
+      visited[n] = true;
+      for( unsigned i=0; i<n.getNumChildren(); i++ ){
+        if( containsTerm2( n[i], t, visited ) ){
+          return true;
+        }
       }
     }
     return false;
   }
 }
 
+bool TermDb::containsTerms2( Node n, std::vector< Node >& t, std::map< Node, bool >& visited ) {
+  if( std::find( t.begin(), t.end(), n )!=t.end() ){
+    return true;
+  }else{
+    if( visited.find( n )==visited.end() ){
+      visited[n] = true;
+      for( unsigned i=0; i<n.getNumChildren(); i++ ){
+        if( containsTerms2( n[i], t, visited ) ){
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+}
+
+bool TermDb::containsTerm( Node n, Node t ) {
+  std::map< Node, bool > visited;
+  return containsTerm2( n, t, visited );
+}
+
+bool TermDb::containsTerms( Node n, std::vector< Node >& t ) {
+  if( t.empty() ){
+    return false;
+  }else{
+    std::map< Node, bool > visited;
+    return containsTerms2( n, t, visited );
+  }
+}
+
 Node TermDb::simpleNegate( Node n ){
   if( n.getKind()==OR || n.getKind()==AND ){
     std::vector< Node > children;
index 416761ce886fcbc3f6ddf9756042cddb01ad1903..529207390972dc08e2515620c51cbc4ae83033ad 100644 (file)
@@ -348,21 +348,37 @@ public:
 //for virtual term substitution
 private:
   Node d_vts_delta;
-  Node d_vts_inf;
+  std::map< TypeNode, Node > d_vts_inf;
   Node d_vts_delta_free;
-  Node d_vts_inf_free;
+  std::map< TypeNode, Node > d_vts_inf_free;
+  /** get vts infinity index */
+  Node getVtsInfinityIndex( int i, bool isFree = false, bool create = true  );
 public:
   /** get vts delta */
   Node getVtsDelta( bool isFree = false, bool create = true );
   /** get vts infinity */
-  Node getVtsInfinity( bool isFree = false, bool create = true );
+  Node getVtsInfinity( TypeNode tn, bool isFree = false, bool create = true );
+  /** get all vts terms */
+  void getVtsTerms( std::vector< Node >& t, bool isFree = false, bool create = true, bool inc_delta = true );
   /** rewrite delta */
   Node rewriteVtsSymbols( Node n );
-
+  /** simple check for contains term */
+  bool containsVtsTerm( Node n, bool isFree = false );
+  /** simple check for contains term */
+  bool containsVtsTerm( std::vector< Node >& n, bool isFree = false );
+  /** simple check for contains term */
+  bool containsVtsInfinity( Node n, bool isFree = false );
+  
+private:
+  //helper for contains term
+  static bool containsTerm2( Node n, Node t, std::map< Node, bool >& visited );
+  static bool containsTerms2( Node n, std::vector< Node >& t, std::map< Node, bool >& visited );
 //general utilities
 public:
   /** simple check for contains term */
   static bool containsTerm( Node n, Node t );
+  /** simple check for contains term */
+  static bool containsTerms( Node n, std::vector< Node >& t );
   /** simple negate */
   static Node simpleNegate( Node n );
   /** is assoc */
index 0c1a7c0811bdcb29260f2b58cac6d5b84b685b1b..96a99763d0ab4f0a30aecaeee750d6da8a076898 100644 (file)
@@ -830,6 +830,8 @@ public:
   SharedTermsDatabase* getSharedTermsDatabase() { return &d_sharedTerms; }
 
   theory::eq::EqualityEngine* getMasterEqualityEngine() { return d_masterEqualityEngine; }
+  
+  RemoveITE* getIteRemover() { return &d_iteRemover; }
 
   SortInference* getSortInference() { return &d_sortInfer; }
 private: