Faster conditional rewriting for and/or beneath quantifiers. Improvements to sort...
authorajreynol <andrew.j.reynolds@gmail.com>
Thu, 10 Mar 2016 23:49:13 +0000 (17:49 -0600)
committerajreynol <andrew.j.reynolds@gmail.com>
Thu, 10 Mar 2016 23:49:13 +0000 (17:49 -0600)
17 files changed:
src/options/quantifiers_options
src/smt/smt_engine.cpp
src/theory/quantifiers/candidate_generator.cpp
src/theory/quantifiers/inst_match_generator.cpp
src/theory/quantifiers/inst_strategy_e_matching.cpp
src/theory/quantifiers/quant_conflict_find.cpp
src/theory/quantifiers/quant_conflict_find.h
src/theory/quantifiers/quant_util.h
src/theory/quantifiers/quantifiers_rewriter.cpp
src/theory/quantifiers/relevant_domain.cpp
src/theory/quantifiers/term_database.cpp
src/theory/quantifiers/term_database.h
src/theory/quantifiers/trigger.cpp
src/theory/quantifiers_engine.cpp
src/theory/quantifiers_engine.h
src/theory/sort_inference.cpp
src/theory/sort_inference.h

index e3f4e94f2373baaf06245d97634cc15cdd711442..1363626c694f267e8e81edda3f41f020ed4ebc53 100644 (file)
@@ -57,6 +57,8 @@ option elimTautQuant --elim-taut-quant bool :default true
  eliminate tautological disjuncts of quantified formulas
 option purifyQuant --purify-quant bool :default false
  purify quantified formulas
+option elimExtArithQuant --elim-ext-arith-quant bool :default true
+ eliminate extended arithmetic symbols in quantified formulas
  
 #### E-matching options
  
@@ -67,6 +69,8 @@ option termDbMode --term-db-mode CVC4::theory::quantifiers::TermDbMode :default
  which ground terms to consider for instantiation
 option registerQuantBodyTerms --register-quant-body-terms bool :default false
  consider ground terms within bodies of quantified formulas for matching
+option inferArithTriggerEq --infer-arith-trigger-eq bool :default false
+ infer equalities for trigger terms based on solving arithmetic equalities
  
 option smartTriggers --smart-triggers bool :default true
  enable smart triggers
@@ -95,6 +99,8 @@ option incrementTriggers --increment-triggers bool :default true
  
 option instWhenMode --inst-when=MODE CVC4::theory::quantifiers::InstWhenMode :default CVC4::theory::quantifiers::INST_WHEN_FULL_LAST_CALL :read-write :include "options/quantifiers_modes.h" :handler stringToInstWhenMode :predicate checkInstWhenMode
  when to apply instantiation
+option instWhenDelayIncrement --inst-when-delay-increment bool :default false
+ delay incrementing counters for inst-when mode to ensure theory combination and standard quantifier effort strategies take turns
  
 option instMaxLevel --inst-max-level=N int :read-write :default -1
  maximum inst level of terms used to instantiate quantified formulas with (-1 == no limit, default)
index 201585070e6f0545c7da3af6427dbf2ef1c131d8..93623408ecb0d8f6a701f7410657fa14cc1b0469 100644 (file)
@@ -1809,23 +1809,22 @@ void SmtEngine::setDefaults() {
     if( !options::rewriteDivk.wasSetByUser()) {
       options::rewriteDivk.set( true );
     }
-  }
-  if( options::cbqi() && d_logic.isPure(THEORY_ARITH) ){
-    options::cbqiAll.set( false );
-    if( !options::quantConflictFind.wasSetByUser() ){
-      options::quantConflictFind.set( false );
-    }
-    if( !options::instNoEntail.wasSetByUser() ){
-      options::instNoEntail.set( false );
-    }
-    if( !options::instWhenMode.wasSetByUser() && options::cbqiModel() ){
-      //only instantiation should happen at last call when model is avaiable
-      if( !options::instWhenMode.wasSetByUser() ){
-        options::instWhenMode.set( quantifiers::INST_WHEN_LAST_CALL );
+    if( d_logic.isPure(THEORY_ARITH) ){
+      options::cbqiAll.set( false );
+      if( !options::quantConflictFind.wasSetByUser() ){
+        options::quantConflictFind.set( false );
+      }
+      if( !options::instNoEntail.wasSetByUser() ){
+        options::instNoEntail.set( false );
+      }
+      if( !options::instWhenMode.wasSetByUser() && options::cbqiModel() ){
+        //only instantiation should happen at last call when model is avaiable
+        if( !options::instWhenMode.wasSetByUser() ){
+          options::instWhenMode.set( quantifiers::INST_WHEN_LAST_CALL );
+        }
       }
     }
   }
-
   //implied options...
   if( options::qcfMode.wasSetByUser() || options::qcfTConstraint() ){
     options::quantConflictFind.set( true );
index 0cdb22be45c5c9359d0c88a7dc0b2c91fa0139ba..680be77daa27f6459c102b4f8b3d7480d2a9d7ee 100644 (file)
@@ -90,7 +90,7 @@ void CandidateGeneratorQE::reset( Node eqc ){
 bool CandidateGeneratorQE::isLegalOpCandidate( Node n ) {
   if( n.hasOperator() ){
     if( isLegalCandidate( n ) ){
-      return d_qe->getTermDatabase()->getOperator( n )==d_op;
+      return d_qe->getTermDatabase()->getMatchOperator( n )==d_op;
     }
   }
   return false;
index 89c2d4868369debae7c2b5e0434ffd5f69cf1d94..41c62192f55418867824b7e959e9dcb71624eb7f 100644 (file)
@@ -97,11 +97,11 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector<
     }
     d_match_pattern_type = d_match_pattern.getType();
     Trace("inst-match-gen") << "Pattern is " << d_pattern << ", match pattern is " << d_match_pattern << std::endl;
-    d_match_pattern_op = qe->getTermDatabase()->getOperator( d_match_pattern );
+    d_match_pattern_op = qe->getTermDatabase()->getMatchOperator( d_match_pattern );
 
     //now, collect children of d_match_pattern
     //int childMatchPolicy = MATCH_GEN_DEFAULT;
-    for( int i=0; i<(int)d_match_pattern.getNumChildren(); i++ ){
+    for( unsigned i=0; i<d_match_pattern.getNumChildren(); i++ ){
       Node qa = quantifiers::TermDb::getInstConstAttr(d_match_pattern[i]);
       if( !qa.isNull() ){
         InstMatchGenerator * cimg = Trigger::getInstMatchGenerator( q, d_match_pattern[i] );
@@ -129,7 +129,7 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector<
     //create candidate generator
     if( d_match_pattern.getKind()==INST_CONSTANT ){
       if( d_pattern.getKind()==APPLY_SELECTOR_TOTAL ){
-        Expr selectorExpr = qe->getTermDatabase()->getOperator( d_pattern ).toExpr();
+        Expr selectorExpr = qe->getTermDatabase()->getMatchOperator( d_pattern ).toExpr();
         size_t selectorIndex = Datatype::cindexOf(selectorExpr);
         const Datatype& dt = Datatype::datatypeOf(selectorExpr);
         const DatatypeConstructor& c = dt[selectorIndex];
@@ -197,7 +197,7 @@ bool InstMatchGenerator::getMatch( Node f, Node t, InstMatch& m, QuantifiersEngi
     Assert( !Trigger::isAtomicTrigger( d_match_pattern ) || t.getOperator()==d_match_pattern.getOperator() );
     //first, check if ground arguments are not equal, or a match is in conflict
     Trace("matching-debug2") << "Setting immediate matches..." << std::endl;
-    for( int i=0; i<(int)d_match_pattern.getNumChildren(); i++ ){
+    for( unsigned i=0; i<d_match_pattern.getNumChildren(); i++ ){
       if( d_children_types[i]==0 ){
         Trace("matching-debug2") << "Setting " << d_var_num[i] << " to " << t[i] << "..." << std::endl;
         bool addToPrev = m.get( d_var_num[i] ).isNull();
@@ -683,7 +683,7 @@ int InstMatchGeneratorMulti::addTerm( Node q, Node t, QuantifiersEngine* qe ){
   Assert( options::eagerInstQuant() );
   int addedLemmas = 0;
   for( int i=0; i<(int)d_children.size(); i++ ){
-    Node t_op = qe->getTermDatabase()->getOperator( t );
+    Node t_op = qe->getTermDatabase()->getMatchOperator( t );
     if( ((InstMatchGenerator*)d_children[i])->d_match_pattern_op==t_op ){
       InstMatch m( q );
       //if it produces a match, then process it with the rest
@@ -709,7 +709,7 @@ InstMatchGeneratorSimple::InstMatchGeneratorSimple( Node q, Node pat ) : d_f( q
 }
 
 void InstMatchGeneratorSimple::resetInstantiationRound( QuantifiersEngine* qe ) {
-  d_op = qe->getTermDatabase()->getOperator( d_match_pattern );
+  d_op = qe->getTermDatabase()->getMatchOperator( d_match_pattern );
 }
 
 int InstMatchGeneratorSimple::addInstantiations( Node q, InstMatch& baseMatch, QuantifiersEngine* qe ){
@@ -763,9 +763,10 @@ void InstMatchGeneratorSimple::addInstantiations( InstMatch& m, QuantifiersEngin
 }
 
 int InstMatchGeneratorSimple::addTerm( Node q, Node t, QuantifiersEngine* qe ){
+  //for eager instantiation only
   Assert( options::eagerInstQuant() );
   InstMatch m( q );
-  for( int i=0; i<(int)t.getNumChildren(); i++ ){
+  for( unsigned i=0; i<t.getNumChildren(); i++ ){
     if( d_match_pattern[i].getKind()==INST_CONSTANT ){
       m.setValue(d_var_num[i], t[i]);
     }else if( !qe->getEqualityQuery()->areEqual( d_match_pattern[i], t[i] ) ){
index 299eb51fdbf88da679ce513c76c2a295faf8bb70..621327c0b8ca13f36cea458b590b06842bd6d746 100644 (file)
@@ -257,9 +257,8 @@ void InstStrategyAutoGenTriggers::generateTriggers( Node f ){
       Node bd = d_quantEngine->getTermDatabase()->getInstConstantBody( f );
       Trigger::collectPatTerms( d_quantEngine, f, bd, patTermsF, d_tr_strategy, d_user_no_gen[f], true );
       Trace("auto-gen-trigger-debug") << "Collected pat terms for " << bd << ", no-patterns : " << d_user_no_gen[f].size() << std::endl;
-      Trace("auto-gen-trigger-debug") << "   ";
       for( int i=0; i<(int)patTermsF.size(); i++ ){
-        Trace("auto-gen-trigger-debug") << patTermsF[i] << " ";
+        Trace("auto-gen-trigger-debug") << "   " << patTermsF[i] << std::endl;
       }
       Trace("auto-gen-trigger-debug") << std::endl;
       if( ntrivTriggers ){
index 779c0c44e7ae67aa7634e46e0bd4406551394b92..93cd4be91c6cf620adcc1ca07824b20738e058e3 100644 (file)
@@ -843,7 +843,7 @@ MatchGen::MatchGen( QuantInfo * qi, Node n, bool isVar )
       d_qni_size++;
       d_type_not = false;
       d_n = n;
-      //Node f = getOperator( n );
+      //Node f = getMatchOperator( n );
       for( unsigned j=0; j<d_n.getNumChildren(); j++ ){
         Node nn = d_n[j];
         Trace("qcf-qregister-debug") << "  " << d_qni_size;
@@ -1106,7 +1106,7 @@ void MatchGen::reset( QuantConflictFind * p, bool tgt, QuantInfo * qi ) {
     }
   }else if( d_type==typ_var ){
     Assert( isHandledUfTerm( d_n ) );
-    Node f = getOperator( p, d_n );
+    Node f = getMatchOperator( p, d_n );
     Debug("qcf-match-debug") << "       reset: Var will match operators of " << f << std::endl;
     TermArgTrie * qni = p->getTermDatabase()->getTermArgTrie( Node::null(), f );
     if( qni!=NULL ){
@@ -1339,7 +1339,7 @@ bool MatchGen::getNextMatch( QuantConflictFind * p, QuantInfo * qi ) {
       /*
       if( d_type==typ_var && p->d_effort==QuantConflictFind::effort_mc && !d_matched_basis ){
         d_matched_basis = true;
-        Node f = getOperator( d_n );
+        Node f = getMatchOperator( d_n );
         TNode mbo = p->getTermDatabase()->getModelBasisOpTerm( f );
         if( qi->setMatch( p, d_qni_var_num[0], mbo ) ){
           success = true;
@@ -1702,9 +1702,9 @@ bool MatchGen::isHandledUfTerm( TNode n ) {
   return inst::Trigger::isAtomicTriggerKind( n.getKind() );
 }
 
-Node MatchGen::getOperator( QuantConflictFind * p, Node n ) {
+Node MatchGen::getMatchOperator( QuantConflictFind * p, Node n ) {
   if( isHandledUfTerm( n ) ){
-    return p->getTermDatabase()->getOperator( n );
+    return p->getTermDatabase()->getMatchOperator( n );
   }else{
     return Node::null();
   }
@@ -1896,7 +1896,7 @@ void QuantConflictFind::assertNode( Node q ) {
 
 Node QuantConflictFind::evaluateTerm( Node n ) {
   if( MatchGen::isHandledUfTerm( n ) ){
-    Node f = MatchGen::getOperator( this, n );
+    Node f = MatchGen::getMatchOperator( this, n );
     Node nn;
     if( getEqualityEngine()->hasTerm( n ) ){
       nn = getTermDatabase()->existsTerm( f, n );
index 11299b532d1a4b8eb78999cdd82b9f42aac543b6..4bcc59bde501065233590c5bf240f5c716724d77 100644 (file)
@@ -98,7 +98,7 @@ public:
   // is this term treated as UF application?
   static bool isHandledBoolConnective( TNode n );
   static bool isHandledUfTerm( TNode n );
-  static Node getOperator( QuantConflictFind * p, Node n );
+  static Node getMatchOperator( QuantConflictFind * p, Node n );
   //can this node be handled by the algorithm
   static bool isHandled( TNode n );
 };
index 566a099235534edfe8e2a6ebf3e2d356ec2a689b..b4cf54dfd9e395993973b59f64592c00b4ceb7f1 100644 (file)
@@ -99,8 +99,6 @@ class EqualityQuery {
 public:
   EqualityQuery(){}
   virtual ~EqualityQuery(){};
-  /** reset */
-  virtual void reset() = 0;
   /** contains term */
   virtual bool hasTerm( Node a ) = 0;
   /** get the representative of the equivalence class of a */
index ff55c5c9b23dcdf563a5e3b0d32e529da9a72a16..c10ba944b21af874d817e3da872bed9c98b1cce4 100644 (file)
@@ -554,6 +554,16 @@ void setEntailedCond( Node n, bool pol, std::map< Node, bool >& currCond, std::v
   }
 }
 
+void removeEntailedCond( std::map< Node, bool >& currCond, std::vector< Node >& new_cond, std::map< Node, Node >& cache ) {
+  if( !new_cond.empty() ){
+    for( unsigned j=0; j<new_cond.size(); j++ ){
+      currCond.erase( new_cond[j] );
+    }
+    new_cond.clear();
+    cache.clear();
+  }
+}
+
 Node QuantifiersRewriter::computeProcessTerms( Node body, std::vector< Node >& new_vars, std::vector< Node >& new_conds, Node q, QAttributes& qa ){
   std::map< Node, bool > curr_cond;
   std::map< Node, Node > cache;
@@ -582,96 +592,119 @@ Node QuantifiersRewriter::computeProcessTerms2( Node body, bool hasPol, bool pol
     ret = iti->second;
     Trace("quantifiers-rewrite-term-debug2") << "Return (cached) " << ret << " for " << body << std::endl;
   }else{
-    bool firstTimeCD = true;
+    //only do context dependent processing up to depth 8
+    bool doCD = nCurrCond<8;
     bool changed = false;
     std::vector< Node > children;
-    for( size_t i=0; i<body.getNumChildren(); i++ ){
-      std::vector< Node > new_cond;
+    //set entailed conditions based on OR/AND
+    std::map< int, std::vector< Node > > new_cond_children;
+    if( doCD && ( body.getKind()==OR || body.getKind()==AND ) ){
+      nCurrCond = nCurrCond + 1;
       bool conflict = false;
-      //only do context dependent processing up to depth 8
-      if( nCurrCond<8 ){
-        if( firstTimeCD ){
-          firstTimeCD = false;
-          nCurrCond = nCurrCond + 1;
-        }
-        if( Trace.isOn("quantifiers-rewrite-term-debug") ){
-          //if( ( body.getKind()==ITE && i>0 ) || ( hasPol && ( ( body.getKind()==OR && pol ) || (body.getKind()==AND && !pol ) ) ) ){
-          if( ( body.getKind()==ITE && i>0 ) || body.getKind()==OR || body.getKind()==AND ){
-            Trace("quantifiers-rewrite-term-debug") << "---rewrite " << body[i] << " under conditions:----" << std::endl;
+      bool use_pol = body.getKind()==AND;
+      for( unsigned j=0; j<body.getNumChildren(); j++ ){
+        setEntailedCond( body[j], use_pol, currCond, new_cond_children[j], conflict );
+      }
+      if( conflict ){
+        Trace("quantifiers-rewrite-term-debug") << "-------conflict, return " << !use_pol << std::endl;
+        ret = NodeManager::currentNM()->mkConst( !use_pol );
+      }
+    }
+    if( ret.isNull() ){
+      for( size_t i=0; i<body.getNumChildren(); i++ ){
+      
+        //set/update entailed conditions
+        std::vector< Node > new_cond;
+        bool conflict = false;
+        if( doCD ){
+          if( Trace.isOn("quantifiers-rewrite-term-debug") ){
+            if( ( body.getKind()==ITE && i>0 ) || body.getKind()==OR || body.getKind()==AND ){
+              Trace("quantifiers-rewrite-term-debug") << "---rewrite " << body[i] << " under conditions:----" << std::endl;
+            }
           }
-        }
-        if( body.getKind()==ITE && i>0 ){
-          setEntailedCond( children[0], i==1, currCond, new_cond, conflict );
-          //should not conflict (entailment check failed) 
-          Assert( !conflict );
-        }
-        //if( hasPol && ( ( body.getKind()==OR && pol ) || ( body.getKind()==AND && !pol ) ) ){
-        //  bool use_pol = !pol;
-        if( body.getKind()==OR || body.getKind()==AND ){
-          bool use_pol = body.getKind()==AND;
-          for( unsigned j=0; j<body.getNumChildren(); j++ ){
-            if( j<i ){
-              setEntailedCond( children[j], use_pol, currCond, new_cond, conflict );
-            }else if( j>i ){
-              setEntailedCond( body[j], use_pol, currCond, new_cond, conflict );
+          if( body.getKind()==ITE && i>0 ){
+            if( i==1 ){
+              nCurrCond = nCurrCond + 1;
+            }
+            setEntailedCond( children[0], i==1, currCond, new_cond, conflict );
+            //should not conflict (entailment check failed) 
+            Assert( !conflict );
+          }
+          if( body.getKind()==OR || body.getKind()==AND ){
+            bool use_pol = body.getKind()==AND;
+            //remove the current condition
+            removeEntailedCond( currCond, new_cond_children[i], cache );
+            if( i>0 ){
+              //add the previous condition
+              setEntailedCond( children[i-1], use_pol, currCond, new_cond_children[i-1], conflict );
+            }
+            if( conflict ){
+              Trace("quantifiers-rewrite-term-debug") << "-------conflict, return " << !use_pol << std::endl;
+              ret = NodeManager::currentNM()->mkConst( !use_pol );
             }
           }
-          if( conflict ){
-            Trace("quantifiers-rewrite-term-debug") << "-------conflict, return " << !use_pol << std::endl;
-            ret = NodeManager::currentNM()->mkConst( !use_pol );
+          if( !new_cond.empty() ){
+            cache.clear();
           }
-        }
-        if( !new_cond.empty() ){
-          cache.clear();
-        }
-        if( Trace.isOn("quantifiers-rewrite-term-debug") ){
-          //if( ( body.getKind()==ITE && i>0 ) || ( hasPol && ( ( body.getKind()==OR && pol ) || (body.getKind()==AND && !pol ) ) ) ){
-          if( ( body.getKind()==ITE && i>0 ) || body.getKind()==OR || body.getKind()==AND ){      
-            Trace("quantifiers-rewrite-term-debug") << "-------" << std::endl;
+          if( Trace.isOn("quantifiers-rewrite-term-debug") ){
+            if( ( body.getKind()==ITE && i>0 ) || body.getKind()==OR || body.getKind()==AND ){      
+              Trace("quantifiers-rewrite-term-debug") << "-------" << std::endl;
+            }
           }
         }
-      }
-      if( !conflict ){
-        bool newHasPol;
-        bool newPol;
-        QuantPhaseReq::getPolarity( body, i, hasPol, pol, newHasPol, newPol );
-        Node nn = computeProcessTerms2( body[i], newHasPol, newPol, currCond, nCurrCond, cache, icache, new_vars, new_conds );
-        if( body.getKind()==ITE && i==0 ){
-          int res = getEntailedCond( nn, currCond );
-          Trace("quantifiers-rewrite-term-debug") << "Condition for " << body << " is " << nn << ", entailment check=" << res << std::endl;
-          if( res==1 ){
-            ret = computeProcessTerms2( body[1], hasPol, pol, currCond, nCurrCond, cache, icache, new_vars, new_conds );
-          }else if( res==-1 ){
-            ret = computeProcessTerms2( body[2], hasPol, pol, currCond, nCurrCond, cache, icache, new_vars, new_conds );
+        
+        //do the recursive call on children
+        if( !conflict ){
+          bool newHasPol;
+          bool newPol;
+          QuantPhaseReq::getPolarity( body, i, hasPol, pol, newHasPol, newPol );
+          Node nn = computeProcessTerms2( body[i], newHasPol, newPol, currCond, nCurrCond, cache, icache, new_vars, new_conds );
+          if( body.getKind()==ITE && i==0 ){
+            int res = getEntailedCond( nn, currCond );
+            Trace("quantifiers-rewrite-term-debug") << "Condition for " << body << " is " << nn << ", entailment check=" << res << std::endl;
+            if( res==1 ){
+              ret = computeProcessTerms2( body[1], hasPol, pol, currCond, nCurrCond, cache, icache, new_vars, new_conds );
+            }else if( res==-1 ){
+              ret = computeProcessTerms2( body[2], hasPol, pol, currCond, nCurrCond, cache, icache, new_vars, new_conds );
+            }
           }
+          children.push_back( nn );
+          changed = changed || nn!=body[i];
+        }
+        
+        //clean up entailed conditions
+        removeEntailedCond( currCond, new_cond, cache );
+        
+        if( !ret.isNull() ){
+          break;
         }
-        children.push_back( nn );
-        changed = changed || nn!=body[i];
       }
-      if( !new_cond.empty() ){
-        for( unsigned j=0; j<new_cond.size(); j++ ){
-          currCond.erase( new_cond[j] );
+      
+      //make return value
+      if( ret.isNull() ){
+        if( changed ){
+          if( body.getMetaKind() == kind::metakind::PARAMETERIZED ){
+            children.insert( children.begin(), body.getOperator() );
+          }
+          ret = NodeManager::currentNM()->mkNode( body.getKind(), children );
+        }else{
+          ret = body;
         }
-        cache.clear();
-      }
-      if( !ret.isNull() ){
-        break;
       }
     }
-    if( ret.isNull() ){
-      if( changed ){
-        if( body.getMetaKind() == kind::metakind::PARAMETERIZED ){
-          children.insert( children.begin(), body.getOperator() );
-        }
-        ret = NodeManager::currentNM()->mkNode( body.getKind(), children );
-      }else{
-        ret = body;
+    
+    //clean up entailed conditions
+    if( body.getKind()==OR || body.getKind()==AND ){
+      for( unsigned j=0; j<body.getNumChildren(); j++ ){
+        removeEntailedCond( currCond, new_cond_children[j], cache );
       }
     }
+    
     Trace("quantifiers-rewrite-term-debug2") << "Returning " << ret << " for " << body << std::endl;
     cache[body] = ret;
   }
 
+  //do context-independent rewriting
   iti = icache.find( ret );
   if( iti!=icache.end() ){
     return iti->second;
@@ -701,46 +734,60 @@ Node QuantifiersRewriter::computeProcessTerms2( Node body, bool hasPol, bool pol
           }
         }
       }
-    }else if( ret.getKind()==INTS_DIVISION_TOTAL || ret.getKind()==INTS_MODULUS_TOTAL ){
-      Node num = ret[0];
-      Node den = ret[1];
-      if(den.isConst()) {
-        const Rational& rat = den.getConst<Rational>();
-        Assert(!num.isConst());
-        if(rat != 0) {
-          Node intVar = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType());
-          new_vars.push_back( intVar );
-          Node cond;
-          if(rat > 0) {
-            cond = NodeManager::currentNM()->mkNode(kind::AND,
-                     NodeManager::currentNM()->mkNode(kind::LEQ, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar), num),
-                     NodeManager::currentNM()->mkNode(kind::LT, num,
-                       NodeManager::currentNM()->mkNode(kind::MULT, den, NodeManager::currentNM()->mkNode(kind::PLUS, intVar, NodeManager::currentNM()->mkConst(Rational(1))))));
-          } else {
-            cond = NodeManager::currentNM()->mkNode(kind::AND,
-                     NodeManager::currentNM()->mkNode(kind::LEQ, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar), num),
-                     NodeManager::currentNM()->mkNode(kind::LT, num,
-                       NodeManager::currentNM()->mkNode(kind::MULT, den, NodeManager::currentNM()->mkNode(kind::PLUS, intVar, NodeManager::currentNM()->mkConst(Rational(-1))))));
-          }
-          new_conds.push_back( cond.negate() );
-          if( ret.getKind()==INTS_DIVISION_TOTAL ){
-            ret = intVar;
-          }else{
-            ret = NodeManager::currentNM()->mkNode(kind::MINUS, num, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar));
+    /* ITE lifting
+    if( ret.getKind()==ITE ){
+      TypeNode ite_t = ret[1].getType();
+      if( !ite_t.isBoolean() ){
+        ite_t = TypeNode::leastCommonTypeNode( ite_t, ret[2].getType() );
+        Node ite_v = NodeManager::currentNM()->mkBoundVar(ite_t);
+        new_vars.push_back( ite_v );
+        Node cond = NodeManager::currentNM()->mkNode(kind::ITE, ret[0], ite_v.eqNode( ret[1] ), ite_v.eqNode( ret[2] ) );
+        new_conds.push_back( cond.negate() );
+        ret = ite_v;
+      }
+      */
+    }else if( options::elimExtArithQuant() ){
+      if( ret.getKind()==INTS_DIVISION_TOTAL || ret.getKind()==INTS_MODULUS_TOTAL ){
+        Node num = ret[0];
+        Node den = ret[1];
+        if(den.isConst()) {
+          const Rational& rat = den.getConst<Rational>();
+          Assert(!num.isConst());
+          if(rat != 0) {
+            Node intVar = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType());
+            new_vars.push_back( intVar );
+            Node cond;
+            if(rat > 0) {
+              cond = NodeManager::currentNM()->mkNode(kind::AND,
+                       NodeManager::currentNM()->mkNode(kind::LEQ, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar), num),
+                       NodeManager::currentNM()->mkNode(kind::LT, num,
+                         NodeManager::currentNM()->mkNode(kind::MULT, den, NodeManager::currentNM()->mkNode(kind::PLUS, intVar, NodeManager::currentNM()->mkConst(Rational(1))))));
+            } else {
+              cond = NodeManager::currentNM()->mkNode(kind::AND,
+                       NodeManager::currentNM()->mkNode(kind::LEQ, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar), num),
+                       NodeManager::currentNM()->mkNode(kind::LT, num,
+                         NodeManager::currentNM()->mkNode(kind::MULT, den, NodeManager::currentNM()->mkNode(kind::PLUS, intVar, NodeManager::currentNM()->mkConst(Rational(-1))))));
+            }
+            new_conds.push_back( cond.negate() );
+            if( ret.getKind()==INTS_DIVISION_TOTAL ){
+              ret = intVar;
+            }else{
+              ret = NodeManager::currentNM()->mkNode(kind::MINUS, num, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar));
+            }
           }
         }
-      }
-    }else if( ret.getKind()==TO_INTEGER || ret.getKind()==IS_INTEGER ){
-      Node intVar = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType());
-      new_vars.push_back( intVar );
-      new_conds.push_back(NodeManager::currentNM()->mkNode(kind::AND,
-                            NodeManager::currentNM()->mkNode(kind::LT,
-                              NodeManager::currentNM()->mkNode(kind::MINUS, ret[0], NodeManager::currentNM()->mkConst(Rational(1))), intVar),
-                            NodeManager::currentNM()->mkNode(kind::LEQ, intVar, ret[0])).negate());
-      if( ret.getKind()==TO_INTEGER ){
-        ret = intVar;
-      }else{
-        ret = ret[0].eqNode( intVar );
+      }else if( ret.getKind()==TO_INTEGER || ret.getKind()==IS_INTEGER ){
+        Node intVar = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType());
+        new_vars.push_back( intVar );
+        new_conds.push_back(NodeManager::currentNM()->mkNode(kind::AND,
+                              NodeManager::currentNM()->mkNode(kind::LT,
+                                NodeManager::currentNM()->mkNode(kind::MINUS, ret[0], NodeManager::currentNM()->mkConst(Rational(1))), intVar),
+                              NodeManager::currentNM()->mkNode(kind::LEQ, intVar, ret[0])).negate());
+        if( ret.getKind()==TO_INTEGER ){
+          ret = intVar;
+        }else{
+          ret = ret[0].eqNode( intVar );
+        }
       }
     }
     icache[prev] = ret;
index 88793358ef7e33f79bdcb05446c2afcfa241dfc8..ce7a38fa57603a731583ccf89ee10217c23f6d8a 100644 (file)
@@ -140,7 +140,7 @@ void RelevantDomain::compute(){
 }
 
 void RelevantDomain::computeRelevantDomain( Node q, Node n, bool hasPol, bool pol ) {
-  Node op = d_qe->getTermDatabase()->getOperator( n );
+  Node op = d_qe->getTermDatabase()->getMatchOperator( n );
   for( unsigned i=0; i<n.getNumChildren(); i++ ){
     if( !op.isNull() ){
       RDomain * rf = getRDomain( op, i );
index be0f60654719cd8259a3f2e56113c4fb6fd49ed0..4c58aa88654398f5d3437dc42b563dad57069f63 100644 (file)
@@ -111,7 +111,7 @@ Node TermDb::getGroundTerm( Node f, unsigned i ) {
   return d_op_map[f][i];
 }
 
-Node TermDb::getOperator( Node n ) {
+Node TermDb::getMatchOperator( Node n ) {
   //return n.getOperator();
   Kind k = n.getKind();
   if( k==SELECT || k==STORE || k==UNION || k==INTERSECTION || k==SUBSET || k==SETMINUS || k==MEMBER || k==SINGLETON ){
@@ -148,10 +148,10 @@ void TermDb::addTerm( Node n, std::set< Node >& added, bool withinQuant, bool wi
         //if this is an atomic trigger, consider adding it
         if( inst::Trigger::isAtomicTrigger( n ) ){
           Trace("term-db") << "register term in db " << n << std::endl;
-          Node op = getOperator( n );
+          Node op = getMatchOperator( n );
           d_op_map[op].push_back( n );
           added.insert( n );
-
+          
           if( options::eagerInstQuant() ){
             for( unsigned i=0; i<n.getNumChildren(); i++ ){
               if( !n.hasAttribute(InstLevelAttribute()) && n.getAttribute(InstLevelAttribute())==0 ){
@@ -223,7 +223,7 @@ TNode TermDb::evaluateTerm( TNode n, std::map< TNode, TNode >& subs, bool subsRe
     }
   }else{
     if( n.hasOperator() ){
-      TNode f = getOperator( n );
+      TNode f = getMatchOperator( n );
       if( !f.isNull() ){
         std::vector< TNode > args;
         for( unsigned i=0; i<n.getNumChildren(); i++ ){
@@ -257,7 +257,7 @@ TNode TermDb::evaluateTerm( TNode n ) {
     return ee->getRepresentative( n );
   }else if( n.getKind()!=BOUND_VARIABLE ){
     if( n.hasOperator() ){
-      TNode f = getOperator( n );
+      TNode f = getMatchOperator( n );
       if( !f.isNull() ){
         std::vector< TNode > args;
         for( unsigned i=0; i<n.getNumChildren(); i++ ){
@@ -503,7 +503,7 @@ void TermDb::reset( Theory::Effort effort ){
           if( Trace.isOn("term-db-debug") ){
             Trace("term-db-debug") << "Adding term " << n << " with arg reps : ";
             for( unsigned i=0; i<d_arg_reps[n].size(); i++ ){
-              Trace("term-db-debug") << d_arg_reps[n] << " ";
+              Trace("term-db-debug") << d_arg_reps[n][i] << " ";
             }
             Trace("term-db-debug") << std::endl;
             if( ee->hasTerm( n ) ){
index b1cfcf2aef3af8375df51a2494aa63b16d14eb56..fcacbd6868402be636ae0573c698e11fe6e91391 100644 (file)
@@ -188,7 +188,6 @@ public:
   /** map from type nodes to terms of that type */
   std::map< TypeNode, std::vector< Node > > d_type_map;
 
-
   /** count number of non-redundant ground terms per operator */
   std::map< Node, int > d_op_nonred_count;
   /**mapping from UF terms to representatives of their arguments */
@@ -212,8 +211,8 @@ public:
   void presolve();
   /** reset (calculate which terms are active) */
   void reset( Theory::Effort effort );
-  /** get operator*/
-  Node getOperator( Node n );
+  /** get match operator */
+  Node getMatchOperator( Node n );
   /** get term arg index */
   TermArgTrie * getTermArgTrie( Node f );
   TermArgTrie * getTermArgTrie( Node eqc, Node f );
index 9aee18317f58b1307486ffdac79526a924d9e5ed..0628b7fbc7e65d9d189327371d956ab330318ba3 100644 (file)
@@ -73,7 +73,7 @@ d_quantEngine( qe ), d_f( f ){
   //Notice() << "Trigger : " << (*this) << "  for " << f << std::endl;
   if( options::eagerInstQuant() ){
     for( int i=0; i<(int)d_nodes.size(); i++ ){
-      Node op = qe->getTermDatabase()->getOperator( d_nodes[i] );
+      Node op = qe->getTermDatabase()->getMatchOperator( d_nodes[i] );
       qe->getTermDatabase()->registerTrigger( this, op );
     }
   }
index 7cda713a10a65d8d40c6410bcfdfc2f69ca15849..5d19d603c62ed9810407b0ef0b27d55b92186e8f 100644 (file)
@@ -138,8 +138,10 @@ QuantifiersEngine::QuantifiersEngine(context::Context* c, context::UserContext*
   d_builder = NULL;
 
   d_total_inst_count_debug = 0;
-  d_ierCounter = 0;
-  d_ierCounter_lc = 0;
+  //allow theory combination to go first, once initially, when instWhenDelayIncrement = true
+  d_ierCounter = options::instWhenDelayIncrement() ? 1 : 0;
+  d_ierCounter_lc = options::instWhenDelayIncrement() ? 1 : 0;
+  d_ierCounterLastLc = -1;
   //if any strategy called only on last call, use phase 3
   d_inst_when_phase = options::cbqi() ? 3 : 2;
 }
@@ -338,10 +340,12 @@ void QuantifiersEngine::check( Theory::Effort e ){
     Trace("quant-engine-debug") << "Master equality engine not consistent, return." << std::endl;
     return;
   }
-  if( e==Theory::EFFORT_FULL ){
-    d_ierCounter++;
-  }else if( e==Theory::EFFORT_LAST_CALL ){
-    d_ierCounter_lc++;
+  if( !options::instWhenDelayIncrement() ){
+    if( e==Theory::EFFORT_FULL ){
+      d_ierCounter++;
+    }else if( e==Theory::EFFORT_LAST_CALL ){
+      d_ierCounter_lc++;
+    }
   }
   bool needsCheck = !d_lemmas_waiting.empty();
   unsigned needsModelE = QEFFORT_NONE;
@@ -392,14 +396,6 @@ void QuantifiersEngine::check( Theory::Effort e ){
       Trace("quant-engine-debug") << "  Needs model effort : " << needsModelE << std::endl;
       Trace("quant-engine-debug") << "Resetting all modules..." << std::endl;
     }
-    if( Trace.isOn("quant-engine-ee") ){
-      Trace("quant-engine-ee") << "Equality engine : " << std::endl;
-      debugPrintEqualityEngine( "quant-engine-ee" );
-    }
-    if( Trace.isOn("quant-engine-assert") ){
-      Trace("quant-engine-assert") << "Assertions : " << std::endl;
-      getTheoryEngine()->printAssertions("quant-engine-assert");
-    }
 
     //reset relevant information
 
@@ -410,12 +406,22 @@ void QuantifiersEngine::check( Theory::Effort e ){
     }
 
     Trace("quant-engine-debug2") << "Reset term db..." << std::endl;
+    d_eq_query->reset( e );
     d_term_db->reset( e );
-    d_eq_query->reset();
     if( d_rel_dom ){
       d_rel_dom->reset();
     }
     d_model->reset_round();
+    
+    if( Trace.isOn("quant-engine-ee") ){
+      Trace("quant-engine-ee") << "Equality engine : " << std::endl;
+      debugPrintEqualityEngine( "quant-engine-ee" );
+    }
+    if( Trace.isOn("quant-engine-assert") ){
+      Trace("quant-engine-assert") << "Assertions : " << std::endl;
+      getTheoryEngine()->printAssertions("quant-engine-assert");
+    }
+    
     for( unsigned i=0; i<d_modules.size(); i++ ){
       Trace("quant-engine-debug2") << "Reset " << d_modules[i]->identify().c_str() << std::endl;
       d_modules[i]->reset_round( e );
@@ -462,23 +468,39 @@ void QuantifiersEngine::check( Theory::Effort e ){
       //if we have added one, stop
       if( d_hasAddedLemma ){
         break;
-      }else if( e==Theory::EFFORT_LAST_CALL && quant_e==QEFFORT_MODEL ){
-        //if we have a chance not to set incomplete
-        if( !setIncomplete ){
-          setIncomplete = false;
-          //check if we should set the incomplete flag
-          for( unsigned i=0; i<qm.size(); i++ ){
-            if( !qm[i]->checkComplete() ){
-              Trace("quant-engine-debug") << "Set incomplete because " << qm[i]->identify().c_str() << " was incomplete." << std::endl;
-              setIncomplete = true;
+      }else{
+        if( quant_e==QEFFORT_CONFLICT ){
+          if( options::instWhenDelayIncrement() ){
+            if( e==Theory::EFFORT_FULL ){
+              //increment if a last call happened, or if we already were in phase
+              if( d_ierCounterLastLc!=d_ierCounter_lc || d_ierCounter%d_inst_when_phase==0 ){
+                d_ierCounter++;
+                d_ierCounterLastLc = d_ierCounter_lc;
+              }
+            }else if( e==Theory::EFFORT_LAST_CALL ){
+              d_ierCounter_lc++;
+            }
+          }
+        }else if( quant_e==QEFFORT_MODEL ){
+          if( e==Theory::EFFORT_LAST_CALL ){
+            //if we have a chance not to set incomplete
+            if( !setIncomplete ){
+              setIncomplete = false;
+              //check if we should set the incomplete flag
+              for( unsigned i=0; i<qm.size(); i++ ){
+                if( !qm[i]->checkComplete() ){
+                  Trace("quant-engine-debug") << "Set incomplete because " << qm[i]->identify().c_str() << " was incomplete." << std::endl;
+                  setIncomplete = true;
+                  break;
+                }
+              }
+            }
+            //if setIncomplete = false, we will answer SAT, otherwise we will run at quant_e QEFFORT_LAST_CALL
+            if( !setIncomplete ){
               break;
             }
           }
         }
-        //if setIncomplete = false, we will answer SAT, otherwise we will run at quant_e QEFFORT_LAST_CALL
-        if( !setIncomplete ){
-          break;
-        }
       }
     }
     Trace("quant-engine-debug") << "Done check modules that needed check." << std::endl;
@@ -1022,6 +1044,7 @@ bool QuantifiersEngine::addSplitEquality( Node n1, Node n2, bool reqPhase, bool
 }
 
 bool QuantifiersEngine::getInstWhenNeedsCheck( Theory::Effort e ) {
+  Trace("quant-engine-debug2") << "Get inst when needs check, counts=" << d_ierCounter << ", " << d_ierCounter_lc << std::endl;
   //determine if we should perform check, based on instWhenMode
   bool performCheck = false;
   if( options::instWhenMode()==quantifiers::INST_WHEN_FULL ){
@@ -1225,9 +1248,64 @@ void QuantifiersEngine::debugPrintEqualityEngine( const char * c ) {
   }
 }
 
-void EqualityQueryQuantifiersEngine::reset(){
+void EqualityQueryQuantifiersEngine::reset( Theory::Effort e ){
   d_int_rep.clear();
   d_reset_count++;
+  processInferences( e );
+}
+
+void EqualityQueryQuantifiersEngine::processInferences( Theory::Effort e ) {
+  if( options::inferArithTriggerEq() ){
+    std::vector< Node > infer;
+    std::vector< Node > infer_exp;
+    eq::EqualityEngine* ee = getEngine();
+    eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( ee );
+    while( !eqcs_i.isFinished() ){
+      TNode r = (*eqcs_i);
+      TypeNode tr = r.getType();
+      if( tr.isReal() ){
+        std::vector< Node > eqc;
+        eq::EqClassIterator eqc_i = eq::EqClassIterator( r, ee );
+        while( !eqc_i.isFinished() ){
+          TNode n = (*eqc_i);
+          //accumulate equivalence class
+          eqc.push_back( n );
+          ++eqc_i;
+        }
+        for( unsigned i=0; i<eqc.size(); i++ ){
+          Node n = eqc[i];
+          if( n.getKind()==PLUS ){
+            std::map< Node, Node > msum;
+            QuantArith::getMonomialSum( n, msum );
+            for( std::map< Node, Node >::iterator it = msum.begin(); it != msum.end(); ++it ){
+              //if the term is a trigger
+              Node t = it->first;
+              if( inst::Trigger::isAtomicTrigger( t ) ){
+                for( unsigned j=0; j<eqc.size(); j++ ){
+                  if( i!=j ){
+                    Node eq = n.eqNode( eqc[j] );
+                    Node v = QuantArith::solveEqualityFor( eq, t );
+                    Trace("quant-engine-ee-proc-debug") << "processInferences : Can infer : " << t << " == " << v << std::endl;
+                    if( ee->hasTerm( v ) && ee->getRepresentative( v )!=r ){
+                      Trace("quant-engine-ee-proc") << "processInferences : Infer : " << t << " == " << v << " from " << n << " == " << eqc[j] << std::endl;
+                      infer.push_back( t.eqNode( v ) );
+                      infer_exp.push_back( n.eqNode( eqc[j] ) );
+                    }
+                  }
+                }
+              }
+            }
+          } 
+        }
+      }
+      ++eqcs_i;
+    }
+    for( unsigned i=0; i<infer.size(); i++ ){
+      Trace("quant-engine-ee-proc-debug") << "Asserting equality " << infer[i] << std::endl;
+      ee->assertEquality( infer[i], true, infer_exp[i] );
+    }
+    Assert( ee->consistent() );
+  }
 }
 
 bool EqualityQueryQuantifiersEngine::hasTerm( Node a ){
index 92296ebac8bc46221491b79c4c6d002295d8bc54..0c43223d87b0f53c4ac25cd4c62b2f82982771c6 100644 (file)
@@ -206,6 +206,7 @@ private:
   /** inst round counters */
   int d_ierCounter;
   int d_ierCounter_lc;
+  int d_ierCounterLastLc;
   int d_inst_when_phase;
   /** has presolve been called */
   context::CDO< bool > d_presolve;
@@ -399,6 +400,8 @@ private:
   /** reset count */
   int d_reset_count;
 
+  /** processInferences : will merge equivalence classes in master equality engine, if possible */
+  void processInferences( Theory::Effort e );
   /** node contains */
   Node getInstance( Node n, const std::vector< Node >& eqc, std::hash_map<TNode, Node, TNodeHashFunction>& cache );
   /** get score */
@@ -407,7 +410,7 @@ public:
   EqualityQueryQuantifiersEngine( QuantifiersEngine* qe ) : d_qe( qe ), d_reset_count( 0 ){}
   ~EqualityQueryQuantifiersEngine(){}
   /** reset */
-  void reset();
+  void reset( Theory::Effort e );
   /** general queries about equality */
   bool hasTerm( Node a );
   Node getRepresentative( Node a );
index 060584fcfed6cf50dc1544795c5581f89d5f4beb..4b29148a7727374b5e61105b8d1f6de32704b617 100644 (file)
@@ -118,10 +118,11 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doSortInfere
   if( doSortInference ){
     Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
     //process all assertions
+    std::map< Node, int > visited;
     for( unsigned i=0; i<assertions.size(); i++ ){
       Trace("sort-inference-debug") << "Process " << assertions[i] << std::endl;
       std::map< Node, Node > var_bound;
-      process( assertions[i], var_bound );
+      process( assertions[i], var_bound, visited );
     }
     Trace("sort-inference-proc") << "...done" << std::endl;
     for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
@@ -155,10 +156,11 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doSortInfere
       bool rewritten = false;
       //determine monotonicity of sorts
       Trace("sort-inference-proc") << "Calculating monotonicty for subsorts..." << std::endl;
+      std::map< Node, std::map< int, bool > > visited;
       for( unsigned i=0; i<assertions.size(); i++ ){
         Trace("sort-inference-debug") << "Process monotonicity for " << assertions[i] << std::endl;
         std::map< Node, Node > var_bound;
-        processMonotonic( assertions[i], true, true, var_bound );
+        processMonotonic( assertions[i], true, true, var_bound, visited );
       }
       Trace("sort-inference-proc") << "...done" << std::endl;
 
@@ -176,13 +178,16 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doSortInfere
 
       //simplify all assertions by introducing new symbols wherever necessary
       Trace("sort-inference-proc") << "Perform simplification..." << std::endl;
+      std::map< Node, std::map< TypeNode, Node > > visited2;
       for( unsigned i=0; i<assertions.size(); i++ ){
         Node prev = assertions[i];
         std::map< Node, Node > var_bound;
-        Trace("sort-inference-debug") << "Rewrite " << assertions[i] << std::endl;
-        Node curr = simplify( assertions[i], var_bound );
+        Trace("sort-inference-debug") << "Simplify " << assertions[i] << std::endl;
+        TypeNode tnn;
+        Node curr = simplifyNode( assertions[i], var_bound, tnn, visited2 );
         Trace("sort-inference-debug") << "Done." << std::endl;
         if( curr!=assertions[i] ){
+          Trace("sort-inference-debug") << "Rewrite " << curr << std::endl;
           curr = theory::Rewriter::rewrite( curr );
           rewritten = true;
           Trace("sort-inference-rewrite") << assertions << std::endl;
@@ -196,9 +201,17 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doSortInfere
       for( std::map< TypeNode, std::map< Node, Node > >::iterator it = d_const_map.begin(); it != d_const_map.end(); ++it ){
         std::vector< Node > consts;
         for( std::map< Node, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
+          Assert( it2->first.isConst() );
           consts.push_back( it2->second );
         }
-        //TODO: add lemma enforcing introduced constants to be distinct
+        //add lemma enforcing introduced constants to be distinct
+        if( consts.size()>1 ){
+          Node distinct_const = NodeManager::currentNM()->mkNode( kind::DISTINCT, consts );
+          Trace("sort-inference-rewrite") << "Add the constant distinctness lemma: " << std::endl;
+          Trace("sort-inference-rewrite") << "  " << distinct_const << std::endl;
+          assertions.push_back( distinct_const );
+          rewritten = true;
+        }
       }
 
       //enforce constraints based on monotonicity
@@ -242,43 +255,15 @@ void SortInference::simplify( std::vector< Node >& assertions, bool doSortInfere
       reset();
       Trace("sort-inference-debug") << "Finished sort inference, rewritten = " << rewritten << std::endl;
     }
-    /*
-    else if( !options::ufssSymBreak() ){
-      //just add the unit lemmas between constants
-      std::map< TypeNode, std::map< int, Node > > constants;
-      for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
-        int rt = d_type_union_find.getRepresentative( it->second );
-        if( d_op_arg_types[ it->first ].empty() ){
-          TypeNode tn = it->first.getType();
-          if( constants[ tn ].find( rt )==constants[ tn ].end() ){
-            constants[ tn ][ rt ] = it->first;
-          }
-        }
-      }
-      //add unit lemmas for each constant
-      for( std::map< TypeNode, std::map< int, Node > >::iterator it = constants.begin(); it != constants.end(); ++it ){
-        Node first_const;
-        for( std::map< int, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
-          if( first_const.isNull() ){
-            first_const = it2->second;
-          }else{
-            Node eq = first_const.eqNode( it2->second );
-            //eq = Rewriter::rewrite( eq );
-            Trace("sort-inference-lemma") << "Sort inference lemma : " << eq << std::endl;
-            assertions.push_back( eq );
-          }
-        }
-      }
-    }
-    */
     initialSortCount = sortCount;
   }
   if( doMonotonicyInference ){
+    std::map< Node, std::map< int, bool > > visited;
     Trace("sort-inference-proc") << "Calculating monotonicty for types..." << std::endl;
     for( unsigned i=0; i<assertions.size(); i++ ){
       Trace("sort-inference-debug") << "Process type monotonicity for " << assertions[i] << std::endl;
       std::map< Node, Node > var_bound;
-      processMonotonic( assertions[i], true, true, var_bound, true );
+      processMonotonic( assertions[i], true, true, var_bound, visited, true );
     }
     Trace("sort-inference-proc") << "...done" << std::endl;
   }
@@ -338,174 +323,185 @@ int SortInference::getIdForType( TypeNode tn ){
   }
 }
 
-int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
-  //add to variable bindings
-  if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-    if( d_var_types.find( n )!=d_var_types.end() ){
-      return getIdForType( n.getType() );
-    }else{
-      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
-        //apply sort inference to quantified variables
-        d_var_types[n][ n[0][i] ] = sortCount;
-        sortCount++;
+int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ){
+  std::map< Node, int >::iterator itv = visited.find( n );
+  if( itv!=visited.end() ){
+    return itv->second;
+  }else{
+    //add to variable bindings
+    bool use_new_visited = false;
+    std::map< Node, int > new_visited;
+    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+      if( d_var_types.find( n )!=d_var_types.end() ){
+        return getIdForType( n.getType() );
+      }else{
+        for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+          //apply sort inference to quantified variables
+          d_var_types[n][ n[0][i] ] = sortCount;
+          sortCount++;
 
-        //type of the quantified variable must be the same
-        var_bound[ n[0][i] ] = n;
+          //type of the quantified variable must be the same
+          var_bound[ n[0][i] ] = n;
+        }
       }
+      use_new_visited = true;
     }
-  }
 
-  //process children
-  std::vector< Node > children;
-  std::vector< int > child_types;
-  for( size_t i=0; i<n.getNumChildren(); i++ ){
-    bool processChild = true;
-    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-      processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
-    }
-    if( processChild ){
-      children.push_back( n[i] );
-      child_types.push_back( process( n[i], var_bound ) );
+    //process children
+    std::vector< Node > children;
+    std::vector< int > child_types;
+    for( size_t i=0; i<n.getNumChildren(); i++ ){
+      bool processChild = true;
+      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
+      }
+      if( processChild ){
+        children.push_back( n[i] );
+        child_types.push_back( process( n[i], var_bound, use_new_visited ? new_visited : visited ) );
+      }
     }
-  }
 
-  //remove from variable bindings
-  if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-    //erase from variable bound
-    for( size_t i=0; i<n[0].getNumChildren(); i++ ){
-      var_bound.erase( n[0][i] );
-    }
-  }
-  Trace("sort-inference-debug") << "...Process " << n << std::endl;
-
-  int retType;
-  if( n.getKind()==kind::EQUAL ){
-    Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
-    //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
-    if( n[0].getType()!=n[1].getType() ){
-      //for now, assume the original types
-      for( unsigned i=0; i<2; i++ ){
-        int ct = getIdForType( n[i].getType() );
-        setEqual( child_types[i], ct );
+    //remove from variable bindings
+    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+      //erase from variable bound
+      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+        var_bound.erase( n[0][i] );
       }
-    }else{
-      //we only require that the left and right hand side must be equal
-      setEqual( child_types[0], child_types[1] );
     }
-    //int eqType = getIdForType( n[0].getType() );
-    //setEqual( child_types[0], eqType );
-    //setEqual( child_types[1], eqType );
-    retType = getIdForType( n.getType() );
-  }else if( n.getKind()==kind::APPLY_UF ){
-    Node op = n.getOperator();
-    TypeNode tn_op = op.getType();
-    if( d_op_return_types.find( op )==d_op_return_types.end() ){
-      if( n.getType().isBoolean() ){
-        //use booleans
-        d_op_return_types[op] = getIdForType( n.getType() );
+    Trace("sort-inference-debug") << "...Process " << n << std::endl;
+
+    int retType;
+    if( n.getKind()==kind::EQUAL ){
+      Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
+      //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
+      if( n[0].getType()!=n[1].getType() ){
+        //for now, assume the original types
+        for( unsigned i=0; i<2; i++ ){
+          int ct = getIdForType( n[i].getType() );
+          setEqual( child_types[i], ct );
+        }
       }else{
-        //assign arbitrary sort for return type
-        d_op_return_types[op] = sortCount;
-        sortCount++;
+        //we only require that the left and right hand side must be equal
+        setEqual( child_types[0], child_types[1] );
       }
-      //d_type_eq_class[sortCount].push_back( op );
-      //assign arbitrary sort for argument types
-      for( size_t i=0; i<n.getNumChildren(); i++ ){
-        d_op_arg_types[op].push_back( sortCount );
-        sortCount++;
-      }
-    }
-    for( size_t i=0; i<n.getNumChildren(); i++ ){
-      //the argument of the operator must match the return type of the subterm
-      if( n[i].getType()!=tn_op[i] ){
-        //if type mismatch, assume original types
-        Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
-        Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
-        int ct1 = getIdForType( n[i].getType() );
-        setEqual( child_types[i], ct1 );
-        int ct2 = getIdForType( tn_op[i] );
-        setEqual( d_op_arg_types[op][i], ct2 );
-      }else{
-        setEqual( child_types[i], d_op_arg_types[op][i] );
+      d_equality_types[n] = child_types[0];
+      retType = getIdForType( n.getType() );
+    }else if( n.getKind()==kind::APPLY_UF ){
+      Node op = n.getOperator();
+      TypeNode tn_op = op.getType();
+      if( d_op_return_types.find( op )==d_op_return_types.end() ){
+        if( n.getType().isBoolean() ){
+          //use booleans
+          d_op_return_types[op] = getIdForType( n.getType() );
+        }else{
+          //assign arbitrary sort for return type
+          d_op_return_types[op] = sortCount;
+          sortCount++;
+        }
+        //d_type_eq_class[sortCount].push_back( op );
+        //assign arbitrary sort for argument types
+        for( size_t i=0; i<n.getNumChildren(); i++ ){
+          d_op_arg_types[op].push_back( sortCount );
+          sortCount++;
+        }
       }
-    }
-    //return type is the return type
-    retType = d_op_return_types[op];
-  }else{
-    std::map< Node, Node >::iterator it = var_bound.find( n );
-    if( it!=var_bound.end() ){
-      Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
-      //the return type was specified while binding
-      retType = d_var_types[it->second][n];
-    }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
-      Trace("sort-inference-debug") << n << " is a variable." << std::endl;
-      if( d_op_return_types.find( n )==d_op_return_types.end() ){
-        //assign arbitrary sort
-        d_op_return_types[n] = sortCount;
-        sortCount++;
-        //d_type_eq_class[sortCount].push_back( n );
+      for( size_t i=0; i<n.getNumChildren(); i++ ){
+        //the argument of the operator must match the return type of the subterm
+        if( n[i].getType()!=tn_op[i] ){
+          //if type mismatch, assume original types
+          Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
+          Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
+          int ct1 = getIdForType( n[i].getType() );
+          setEqual( child_types[i], ct1 );
+          int ct2 = getIdForType( tn_op[i] );
+          setEqual( d_op_arg_types[op][i], ct2 );
+        }else{
+          setEqual( child_types[i], d_op_arg_types[op][i] );
+        }
       }
-      retType = d_op_return_types[n];
-    //}else if( n.isConst() ){
-    //  Trace("sort-inference-debug") << n << " is a constant." << std::endl;
-      //can be any type we want
-    //  retType = sortCount;
-    //  sortCount++;
+      //return type is the return type
+      retType = d_op_return_types[op];
     }else{
-      Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
-      //it is an interpretted term
-      for( size_t i=0; i<children.size(); i++ ){
-        Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
-        //must enforce the actual type of the operator on the children
-        int ct = getIdForType( children[i].getType() );
-        setEqual( child_types[i], ct );
+      std::map< Node, Node >::iterator it = var_bound.find( n );
+      if( it!=var_bound.end() ){
+        Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
+        //the return type was specified while binding
+        retType = d_var_types[it->second][n];
+      }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
+        Trace("sort-inference-debug") << n << " is a variable." << std::endl;
+        if( d_op_return_types.find( n )==d_op_return_types.end() ){
+          //assign arbitrary sort
+          d_op_return_types[n] = sortCount;
+          sortCount++;
+          //d_type_eq_class[sortCount].push_back( n );
+        }
+        retType = d_op_return_types[n];
+      }else if( n.isConst() ){
+        Trace("sort-inference-debug") << n << " is a constant." << std::endl;
+        //can be any type we want
+        retType = sortCount;
+        sortCount++;
+      }else{
+        Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
+        //it is an interpreted term
+        for( size_t i=0; i<children.size(); i++ ){
+          Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
+          //must enforce the actual type of the operator on the children
+          int ct = getIdForType( children[i].getType() );
+          setEqual( child_types[i], ct );
+        }
+        //return type must be the actual return type
+        retType = getIdForType( n.getType() );
       }
-      //return type must be the actual return type
-      retType = getIdForType( n.getType() );
     }
+    Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
+    printSort("sort-inference-debug", retType );
+    Trace("sort-inference-debug") << std::endl;
+    visited[n] = retType;
+    return retType;
   }
-  Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
-  printSort("sort-inference-debug", retType );
-  Trace("sort-inference-debug") << std::endl;
-  return retType;
 }
 
-void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, bool typeMode ) {
-  Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
-  if( n.getKind()==kind::FORALL ){
-    //only consider variables universally if it is possible this quantified formula is asserted positively
-    if( !hasPol || pol ){
-      for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
-        var_bound[n[0][i]] = n;
+void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode ) {
+  int pindex = hasPol ? ( pol ? 1 : -1 ) : 0;
+  if( visited[n].find( pindex )==visited[n].end() ){
+    visited[n][pindex] = true;
+    Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
+    if( n.getKind()==kind::FORALL ){
+      //only consider variables universally if it is possible this quantified formula is asserted positively
+      if( !hasPol || pol ){
+        for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
+          var_bound[n[0][i]] = n;
+        }
       }
-    }
-    processMonotonic( n[1], pol, hasPol, var_bound, typeMode );
-    if( !hasPol || pol ){
-      for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
-        var_bound.erase( n[0][i] );
+      processMonotonic( n[1], pol, hasPol, var_bound, visited, typeMode );
+      if( !hasPol || pol ){
+        for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
+          var_bound.erase( n[0][i] );
+        }
       }
-    }
-    return;
-  }else if( n.getKind()==kind::EQUAL ){
-    if( !hasPol || pol ){
-      for( unsigned i=0; i<2; i++ ){
-        if( var_bound.find( n[i] )!=var_bound.end() ){
-          if( !typeMode ){
-            int sid = getSortId( var_bound[n[i]], n[i] );
-            d_non_monotonic_sorts[sid] = true;
-          }else{
-            d_non_monotonic_sorts_orig[n[i].getType()] = true;
+      return;
+    }else if( n.getKind()==kind::EQUAL ){
+      if( !hasPol || pol ){
+        for( unsigned i=0; i<2; i++ ){
+          if( var_bound.find( n[i] )!=var_bound.end() ){
+            if( !typeMode ){
+              int sid = getSortId( var_bound[n[i]], n[i] );
+              d_non_monotonic_sorts[sid] = true;
+            }else{
+              d_non_monotonic_sorts_orig[n[i].getType()] = true;
+            }
+            break;
           }
-          break;
         }
       }
     }
-  }
-  for( unsigned i=0; i<n.getNumChildren(); i++ ){
-    bool npol;
-    bool nhasPol;
-    theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
-    processMonotonic( n[i], npol, nhasPol, var_bound, typeMode );
+    for( unsigned i=0; i<n.getNumChildren(); i++ ){
+      bool npol;
+      bool nhasPol;
+      theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
+      processMonotonic( n[i], npol, nhasPol, var_bound, visited, typeMode );
+    }
   }
 }
 
@@ -544,7 +540,7 @@ TypeNode SortInference::getTypeForId( int t ){
 }
 
 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
-  if( tn==old.getType() ){
+  if( tn.isNull() || tn==old.getType() ){
     return old;
   }else if( old.isConst() ){
     //must make constant of type tn
@@ -565,128 +561,139 @@ Node SortInference::getNewSymbol( Node old, TypeNode tn ){
   }
 }
 
-Node SortInference::simplify( Node n, std::map< Node, Node >& var_bound ){
-  Trace("sort-inference-debug2") << "Simplify " << n << std::endl;
-  std::vector< Node > children;
-  if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-    //recreate based on types of variables
-    std::vector< Node > new_children;
-    for( size_t i=0; i<n[0].getNumChildren(); i++ ){
-      TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
-      Node v = getNewSymbol( n[0][i], tn );
-      Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
-      new_children.push_back( v );
-      var_bound[ n[0][i] ] = v;
+Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited ){
+  std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
+  if( itv!=visited[n].end() ){
+    return itv->second;
+  }else{
+    Trace("sort-inference-debug2") << "Simplify " << n << ", type context=" << tnn << std::endl;
+    std::vector< Node > children;
+    std::map< Node, std::map< TypeNode, Node > > new_visited;
+    bool use_new_visited = false;
+    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+      //recreate based on types of variables
+      std::vector< Node > new_children;
+      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+        TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
+        Node v = getNewSymbol( n[0][i], tn );
+        Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
+        new_children.push_back( v );
+        var_bound[ n[0][i] ] = v;
+      }
+      children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
+      use_new_visited = true;
     }
-    children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
-  }
 
-  //process children
-  if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
-    children.push_back( n.getOperator() );
-  }
-  bool childChanged = false;
-  for( size_t i=0; i<n.getNumChildren(); i++ ){
-    bool processChild = true;
-    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-      processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
+    //process children
+    if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
+      children.push_back( n.getOperator() );
     }
-    if( processChild ){
-      Node nc = simplify( n[i], var_bound );
-      Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
-      children.push_back( nc );
-      childChanged = childChanged || nc!=n[i];
+    Node op;
+    if( n.hasOperator() ){
+      op = n.getOperator();
     }
-  }
-
-  //remove from variable bindings
-  if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-    //erase from variable bound
-    for( size_t i=0; i<n[0].getNumChildren(); i++ ){
-      Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
-      var_bound.erase( n[0][i] );
+    bool childChanged = false;
+    TypeNode tnnc;
+    for( size_t i=0; i<n.getNumChildren(); i++ ){
+      bool processChild = true;
+      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
+      }
+      if( processChild ){
+        if( n.getKind()==kind::APPLY_UF ){
+          Assert( d_op_arg_types.find( op )!=d_op_arg_types.end() );
+          tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
+          Assert( !tnnc.isNull() );
+        }else if( n.getKind()==kind::EQUAL && i==0 ){
+          Assert( d_equality_types.find( n )!=d_equality_types.end() );
+          tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
+          Assert( !tnnc.isNull() );
+        }
+        Node nc = simplifyNode( n[i], var_bound, tnnc, use_new_visited ? new_visited : visited );
+        Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
+        children.push_back( nc );
+        childChanged = childChanged || nc!=n[i];
+      }
     }
-    return NodeManager::currentNM()->mkNode( n.getKind(), children );
-  }else if( n.getKind()==kind::EQUAL ){
-    TypeNode tn1 = children[0].getType();
-    TypeNode tn2 = children[1].getType();
-    if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
-      if( children[0].isConst() ){
-        children[0] = getNewSymbol( children[0], children[1].getType() );
-      }else if( children[1].isConst() ){
-        children[1] = getNewSymbol( children[1], children[0].getType() );
-      }else{
+
+    //remove from variable bindings
+    Node ret;
+    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+      //erase from variable bound
+      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+        Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
+        var_bound.erase( n[0][i] );
+      }
+      ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
+    }else if( n.getKind()==kind::EQUAL ){
+      TypeNode tn1 = children[0].getType();
+      TypeNode tn2 = children[1].getType();
+      if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
         Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
         Trace("sort-inference-warn") << "  Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
         Assert( false );
       }
-    }
-    return NodeManager::currentNM()->mkNode( kind::EQUAL, children );
-  }else if( n.getKind()==kind::APPLY_UF ){
-    Node op = n.getOperator();
-    if( d_symbol_map.find( op )==d_symbol_map.end() ){
-      //make the new operator if necessary
-      bool opChanged = false;
-      std::vector< TypeNode > argTypes;
-      for( size_t i=0; i<n.getNumChildren(); i++ ){
-        TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
-        argTypes.push_back( tn );
-        if( tn!=n[i].getType() ){
+      ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
+    }else if( n.getKind()==kind::APPLY_UF ){
+      if( d_symbol_map.find( op )==d_symbol_map.end() ){
+        //make the new operator if necessary
+        bool opChanged = false;
+        std::vector< TypeNode > argTypes;
+        for( size_t i=0; i<n.getNumChildren(); i++ ){
+          TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
+          argTypes.push_back( tn );
+          if( tn!=n[i].getType() ){
+            opChanged = true;
+          }
+        }
+        TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
+        if( retType!=n.getType() ){
           opChanged = true;
         }
-      }
-      TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
-      if( retType!=n.getType() ){
-        opChanged = true;
-      }
-      if( opChanged ){
-        std::stringstream ss;
-        ss << "io_" << op;
-        TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
-        d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
-        Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
-        d_model_replace_f[op] = d_symbol_map[op];
-      }else{
-        d_symbol_map[op] = op;
-      }
-    }
-    children[0] = d_symbol_map[op];
-    //make sure all children have been taken care of
-    for( size_t i=0; i<n.getNumChildren(); i++ ){
-      TypeNode tn = children[i+1].getType();
-      TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
-      if( tn!=tna ){
-        if( n[i].isConst() ){
-          children[i+1] = getNewSymbol( n[i], tna );
+        if( opChanged ){
+          std::stringstream ss;
+          ss << "io_" << op;
+          TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
+          d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
+          Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
+          d_model_replace_f[op] = d_symbol_map[op];
         }else{
+          d_symbol_map[op] = op;
+        }
+      }
+      children[0] = d_symbol_map[op];
+      //make sure all children have been taken care of
+      for( size_t i=0; i<n.getNumChildren(); i++ ){
+        TypeNode tn = children[i+1].getType();
+        TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
+        if( tn!=tna ){
           Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
           Assert( false );
         }
       }
-    }
-    return NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
-  }else{
-    std::map< Node, Node >::iterator it = var_bound.find( n );
-    if( it!=var_bound.end() ){
-      return it->second;
-    }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
-      if( d_symbol_map.find( n )==d_symbol_map.end() ){
-        TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
-        d_symbol_map[n] = getNewSymbol( n, tn );
-      }
-      return d_symbol_map[n];
-    }else if( n.isConst() ){
-      //just return n, we will fix at higher scope
-      return n;
+      ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
     }else{
-      if( childChanged ){
-        return NodeManager::currentNM()->mkNode( n.getKind(), children );
+      std::map< Node, Node >::iterator it = var_bound.find( n );
+      if( it!=var_bound.end() ){
+        ret = it->second;
+      }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
+        if( d_symbol_map.find( n )==d_symbol_map.end() ){
+          TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
+          d_symbol_map[n] = getNewSymbol( n, tn );
+        }
+        ret = d_symbol_map[n];
+      }else if( n.isConst() ){
+        //type is determined by context
+        ret = getNewSymbol( n, tnn );
+      }else if( childChanged ){
+        ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
       }else{
-        return n;
+        ret = n;
       }
     }
+    visited[n][tnn] = ret;
+    return ret;
   }
-
 }
 
 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
@@ -728,7 +735,8 @@ void SortInference::setSkolemVar( Node f, Node v, Node sk ){
   if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
     //calculate the sort for variables if not done so already
     std::map< Node, Node > var_bound;
-    process( f, var_bound );
+    std::map< Node, int > visited;
+    process( f, var_bound, visited );
   }
   d_op_return_types[sk] = getSortId( f, v );
   Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
index f926776de4b19bc167cbb9966dbb39ba4c77af0f..163a3c53edcdae4ba59bf119161388265b324d0d 100644 (file)
@@ -61,6 +61,7 @@ private:
   //for apply uf operators
   std::map< Node, int > d_op_return_types;
   std::map< Node, std::vector< int > > d_op_arg_types;
+  std::map< Node, int > d_equality_types;
   //for bound variables
   std::map< Node, std::map< Node, int > > d_var_types;
   //get representative
@@ -68,10 +69,10 @@ private:
   int getIdForType( TypeNode tn );
   void printSort( const char* c, int t );
   //process
-  int process( Node n, std::map< Node, Node >& var_bound );
+  int process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited );
 //for monotonicity inference
 private:
-  void processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, bool typeMode = false );
+  void processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode = false );
 
 //for rewriting
 private:
@@ -84,7 +85,7 @@ private:
   TypeNode getTypeForId( int t );
   Node getNewSymbol( Node old, TypeNode tn );
   //simplify
-  Node simplify( Node n, std::map< Node, Node >& var_bound );
+  Node simplifyNode( Node n, std::map< Node, Node >& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited );
   //make injection
   Node mkInjection( TypeNode tn1, TypeNode tn2 );
   //reset