Introduce quantifiers inference manager (#5821)
[cvc5.git] / src / theory / sort_inference.cpp
index 060584fcfed6cf50dc1544795c5581f89d5f4beb..9fa216e67214c81f56a2e202bc1a4b1dda8e5607 100644 (file)
@@ -1,13 +1,13 @@
 /*********************                                                        */
 /*! \file sort_inference.cpp
  ** \verbatim
- ** Original author: Andrew Reynolds
- ** Major contributors: Morgan Deters
- ** Minor contributors (to current version): Kshitij Bansal
+ ** Top contributors (to current version):
+ **   Andrew Reynolds, Paul Meng, Mathias Preiner
  ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2014  New York University and The University of Iowa
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
  **
  ** \brief Sort inference module
  **
 #include "options/quantifiers_options.h"
 #include "options/smt_options.h"
 #include "options/uf_options.h"
-#include "proof/proof_manager.h"
 #include "theory/rewriter.h"
 #include "theory/quantifiers/quant_util.h"
 
 using namespace CVC4;
+using namespace CVC4::kind;
 using namespace std;
 
 namespace CVC4 {
@@ -102,7 +102,7 @@ void SortInference::reset() {
   d_non_monotonic_sorts.clear();
   d_type_sub_sorts.clear();
   //reset info
-  sortCount = 1;
+  d_sortCount = 1;
   d_type_union_find.clear();
   d_type_types.clear();
   d_id_for_types.clear();
@@ -114,174 +114,195 @@ void SortInference::reset() {
   d_const_map.clear();
 }
 
-void SortInference::simplify( std::vector< Node >& assertions, bool doSortInference, bool doMonotonicyInference ){
-  if( doSortInference ){
-    Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
-    //process all assertions
-    for( unsigned i=0; i<assertions.size(); i++ ){
-      Trace("sort-inference-debug") << "Process " << assertions[i] << std::endl;
-      std::map< Node, Node > var_bound;
-      process( assertions[i], var_bound );
-    }
-    Trace("sort-inference-proc") << "...done" << std::endl;
-    for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
-      Trace("sort-inference") << it->first << " : ";
-      TypeNode retTn = it->first.getType();
-      if( !d_op_arg_types[ it->first ].empty() ){
-        Trace("sort-inference") << "( ";
-        for( size_t i=0; i<d_op_arg_types[ it->first ].size(); i++ ){
-          recordSubsort( retTn[i], d_op_arg_types[ it->first ][i] );
-          printSort( "sort-inference", d_op_arg_types[ it->first ][i] );
-          Trace("sort-inference") << " ";
-        }
-        Trace("sort-inference") << ") -> ";
-        retTn = retTn[(int)retTn.getNumChildren()-1];
+void SortInference::initialize(const std::vector<Node>& assertions)
+{
+  Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
+  // process all assertions
+  std::map<Node, int> visited;
+  NodeManager * nm = NodeManager::currentNM();
+  int btId = getIdForType( nm->booleanType() );
+  for (const Node& a : assertions)
+  {
+    Trace("sort-inference-debug") << "Process " << a << std::endl;
+    std::map<Node, Node> var_bound;
+    int pid = process(a, var_bound, visited);
+    // the type of the topmost term must be Boolean
+    setEqual( pid, btId );
+  }
+  Trace("sort-inference-proc") << "...done" << std::endl;
+  for (const std::pair<const Node, int>& rt : d_op_return_types)
+  {
+    Trace("sort-inference") << rt.first << " : ";
+    TypeNode retTn = rt.first.getType();
+    if (!d_op_arg_types[rt.first].empty())
+    {
+      Trace("sort-inference") << "( ";
+      for (size_t i = 0; i < d_op_arg_types[rt.first].size(); i++)
+      {
+        recordSubsort(retTn[i], d_op_arg_types[rt.first][i]);
+        printSort("sort-inference", d_op_arg_types[rt.first][i]);
+        Trace("sort-inference") << " ";
       }
-      recordSubsort( retTn, it->second );
-      printSort( "sort-inference", it->second );
-      Trace("sort-inference") << std::endl;
+      Trace("sort-inference") << ") -> ";
+      retTn = retTn[(int)retTn.getNumChildren() - 1];
     }
-    for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){
-      Trace("sort-inference") << "Quantified formula : " << it->first << " : " << std::endl;
-      for( unsigned i=0; i<it->first[0].getNumChildren(); i++ ){
-        recordSubsort( it->first[0][i].getType(), it->second[it->first[0][i]] );
-        printSort( "sort-inference", it->second[it->first[0][i]] );
-        Trace("sort-inference") << std::endl;
-      }
+    recordSubsort(retTn, rt.second);
+    printSort("sort-inference", rt.second);
+    Trace("sort-inference") << std::endl;
+  }
+  for (std::pair<const Node, std::map<Node, int> >& vt : d_var_types)
+  {
+    Trace("sort-inference")
+        << "Quantified formula : " << vt.first << " : " << std::endl;
+    for (const Node& v : vt.first[0])
+    {
+      recordSubsort(v.getType(), vt.second[v]);
+      printSort("sort-inference", vt.second[v]);
       Trace("sort-inference") << std::endl;
     }
+    Trace("sort-inference") << std::endl;
+  }
 
-    if( !options::ufssSymBreak() ){
-      bool rewritten = false;
-      //determine monotonicity of sorts
-      Trace("sort-inference-proc") << "Calculating monotonicty for subsorts..." << std::endl;
-      for( unsigned i=0; i<assertions.size(); i++ ){
-        Trace("sort-inference-debug") << "Process monotonicity for " << assertions[i] << std::endl;
-        std::map< Node, Node > var_bound;
-        processMonotonic( assertions[i], true, true, var_bound );
-      }
-      Trace("sort-inference-proc") << "...done" << std::endl;
-
-      Trace("sort-inference") << "We have " << d_sub_sorts.size() << " sub-sorts : " << std::endl;
-      for( unsigned i=0; i<d_sub_sorts.size(); i++ ){
-        printSort( "sort-inference", d_sub_sorts[i] );
-        if( d_type_types.find( d_sub_sorts[i] )!=d_type_types.end() ){
-          Trace("sort-inference") << " is interpreted." << std::endl;
-        }else if( d_non_monotonic_sorts.find( d_sub_sorts[i] )==d_non_monotonic_sorts.end() ){
-          Trace("sort-inference") << " is monotonic." << std::endl;
-        }else{
-          Trace("sort-inference") << " is not monotonic." << std::endl;
-        }
-      }
+  // determine monotonicity of sorts
+  Trace("sort-inference-proc")
+      << "Calculating monotonicty for subsorts..." << std::endl;
+  std::map<Node, std::map<int, bool> > visitedm;
+  for (const Node& a : assertions)
+  {
+    Trace("sort-inference-debug")
+        << "Process monotonicity for " << a << std::endl;
+    std::map<Node, Node> var_bound;
+    processMonotonic(a, true, true, var_bound, visitedm);
+  }
+  Trace("sort-inference-proc") << "...done" << std::endl;
+
+  Trace("sort-inference") << "We have " << d_sub_sorts.size()
+                          << " sub-sorts : " << std::endl;
+  for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++)
+  {
+    printSort("sort-inference", d_sub_sorts[i]);
+    if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end())
+    {
+      Trace("sort-inference") << " is interpreted." << std::endl;
+    }
+    else if (d_non_monotonic_sorts.find(d_sub_sorts[i])
+             == d_non_monotonic_sorts.end())
+    {
+      Trace("sort-inference") << " is monotonic." << std::endl;
+    }
+    else
+    {
+      Trace("sort-inference") << " is not monotonic." << std::endl;
+    }
+  }
+}
 
-      //simplify all assertions by introducing new symbols wherever necessary
-      Trace("sort-inference-proc") << "Perform simplification..." << std::endl;
-      for( unsigned i=0; i<assertions.size(); i++ ){
-        Node prev = assertions[i];
-        std::map< Node, Node > var_bound;
-        Trace("sort-inference-debug") << "Rewrite " << assertions[i] << std::endl;
-        Node curr = simplify( assertions[i], var_bound );
-        Trace("sort-inference-debug") << "Done." << std::endl;
-        if( curr!=assertions[i] ){
-          curr = theory::Rewriter::rewrite( curr );
-          rewritten = true;
-          Trace("sort-inference-rewrite") << assertions << std::endl;
-          Trace("sort-inference-rewrite") << " --> " << curr << std::endl;
-          PROOF( ProofManager::currentPM()->addDependence(curr, assertions[i]); );
-          assertions[i] = curr;
-        }
-      }
-      Trace("sort-inference-proc") << "...done" << std::endl;
-      //now, ensure constants are distinct
-      for( std::map< TypeNode, std::map< Node, Node > >::iterator it = d_const_map.begin(); it != d_const_map.end(); ++it ){
-        std::vector< Node > consts;
-        for( std::map< Node, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
-          consts.push_back( it2->second );
-        }
-        //TODO: add lemma enforcing introduced constants to be distinct
-      }
+Node SortInference::simplify(Node n,
+                             std::map<Node, Node>& model_replace_f,
+                             std::map<Node, std::map<TypeNode, Node> >& visited)
+{
+  Trace("sort-inference-debug") << "Simplify " << n << std::endl;
+  std::map<Node, Node> var_bound;
+  TypeNode tnn;
+  Node ret = simplifyNode(n, var_bound, tnn, model_replace_f, visited);
+  ret = theory::Rewriter::rewrite(ret);
+  return ret;
+}
 
-      //enforce constraints based on monotonicity
-      Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
-      for( std::map< TypeNode, std::vector< int > >::iterator it = d_type_sub_sorts.begin(); it != d_type_sub_sorts.end(); ++it ){
-        int nmonSort = -1;
-        for( unsigned i=0; i<it->second.size(); i++ ){
-          if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
-            nmonSort = it->second[i];
-            break;
-          }
-        }
-        if( nmonSort!=-1 ){
-          std::vector< Node > injections;
-          TypeNode base_tn = getOrCreateTypeForId( nmonSort, it->first );
-          for( unsigned i=0; i<it->second.size(); i++ ){
-            if( it->second[i]!=nmonSort ){
-              TypeNode new_tn = getOrCreateTypeForId( it->second[i], it->first );
-              //make injection to nmonSort
-              Node a1 = mkInjection( new_tn, base_tn );
-              injections.push_back( a1 );
-              if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
-                //also must make injection from nmonSort to this
-                Node a2 = mkInjection( base_tn, new_tn );
-                injections.push_back( a2 );
-              }
-            }
-          }
-          Trace("sort-inference-rewrite") << "Add the following injections for " << it->first << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
-          for( unsigned j=0; j<injections.size(); j++ ){
-            Trace("sort-inference-rewrite") << "   " << injections[j] << std::endl;
-          }
-          assertions.insert( assertions.end(), injections.begin(), injections.end() );
-          if( !injections.empty() ){
-            rewritten = true;
-          }
-        }
+void SortInference::getNewAssertions(std::vector<Node>& new_asserts)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  // now, ensure constants are distinct
+  for (const std::pair<const TypeNode, std::map<Node, Node> >& cm : d_const_map)
+  {
+    std::vector<Node> consts;
+    for (const std::pair<const Node, Node>& c : cm.second)
+    {
+      Assert(c.first.isConst());
+      consts.push_back(c.second);
+    }
+    // add lemma enforcing introduced constants to be distinct
+    if (consts.size() > 1)
+    {
+      Node distinct_const = nm->mkNode(kind::DISTINCT, consts);
+      Trace("sort-inference-rewrite")
+          << "Add the constant distinctness lemma: " << std::endl;
+      Trace("sort-inference-rewrite") << "  " << distinct_const << std::endl;
+      new_asserts.push_back(distinct_const);
+    }
+  }
+
+  // enforce constraints based on monotonicity
+  Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
+
+  for (const std::pair<const TypeNode, std::vector<int> >& tss :
+       d_type_sub_sorts)
+  {
+    int nmonSort = -1;
+    unsigned nsorts = tss.second.size();
+    for (unsigned i = 0; i < nsorts; i++)
+    {
+      if (d_non_monotonic_sorts.find(tss.second[i])
+          != d_non_monotonic_sorts.end())
+      {
+        nmonSort = tss.second[i];
+        break;
       }
-      Trace("sort-inference-proc") << "...done" << std::endl;
-      //no sub-sort information is stored
-      reset();
-      Trace("sort-inference-debug") << "Finished sort inference, rewritten = " << rewritten << std::endl;
     }
-    /*
-    else if( !options::ufssSymBreak() ){
-      //just add the unit lemmas between constants
-      std::map< TypeNode, std::map< int, Node > > constants;
-      for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
-        int rt = d_type_union_find.getRepresentative( it->second );
-        if( d_op_arg_types[ it->first ].empty() ){
-          TypeNode tn = it->first.getType();
-          if( constants[ tn ].find( rt )==constants[ tn ].end() ){
-            constants[ tn ][ rt ] = it->first;
+    if (nmonSort != -1)
+    {
+      std::vector<Node> injections;
+      TypeNode base_tn = getOrCreateTypeForId(nmonSort, tss.first);
+      for (unsigned i = 0; i < nsorts; i++)
+      {
+        if (tss.second[i] != nmonSort)
+        {
+          TypeNode new_tn = getOrCreateTypeForId(tss.second[i], tss.first);
+          // make injection to nmonSort
+          Node a1 = mkInjection(new_tn, base_tn);
+          injections.push_back(a1);
+          if (d_non_monotonic_sorts.find(tss.second[i])
+              != d_non_monotonic_sorts.end())
+          {
+            // also must make injection from nmonSort to this
+            Node a2 = mkInjection(base_tn, new_tn);
+            injections.push_back(a2);
           }
         }
       }
-      //add unit lemmas for each constant
-      for( std::map< TypeNode, std::map< int, Node > >::iterator it = constants.begin(); it != constants.end(); ++it ){
-        Node first_const;
-        for( std::map< int, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
-          if( first_const.isNull() ){
-            first_const = it2->second;
-          }else{
-            Node eq = first_const.eqNode( it2->second );
-            //eq = Rewriter::rewrite( eq );
-            Trace("sort-inference-lemma") << "Sort inference lemma : " << eq << std::endl;
-            assertions.push_back( eq );
-          }
+      if (Trace.isOn("sort-inference-rewrite"))
+      {
+        Trace("sort-inference-rewrite")
+            << "Add the following injections for " << tss.first
+            << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
+        for (const Node& i : injections)
+        {
+          Trace("sort-inference-rewrite") << "   " << i << std::endl;
         }
       }
+      new_asserts.insert(
+          new_asserts.end(), injections.begin(), injections.end());
     }
-    */
-    initialSortCount = sortCount;
   }
-  if( doMonotonicyInference ){
-    Trace("sort-inference-proc") << "Calculating monotonicty for types..." << std::endl;
-    for( unsigned i=0; i<assertions.size(); i++ ){
-      Trace("sort-inference-debug") << "Process type monotonicity for " << assertions[i] << std::endl;
-      std::map< Node, Node > var_bound;
-      processMonotonic( assertions[i], true, true, var_bound, true );
-    }
-    Trace("sort-inference-proc") << "...done" << std::endl;
+  Trace("sort-inference-proc") << "...done" << std::endl;
+  // no sub-sort information is stored
+  reset();
+  Trace("sort-inference-debug") << "Finished sort inference" << std::endl;
+}
+
+void SortInference::computeMonotonicity(const std::vector<Node>& assertions)
+{
+  std::map<Node, std::map<int, bool> > visitedmt;
+  Trace("sort-inference-proc")
+      << "Calculating monotonicty for types..." << std::endl;
+  for (const Node& a : assertions)
+  {
+    Trace("sort-inference-debug")
+        << "Process type monotonicity for " << a << std::endl;
+    std::map<Node, Node> var_bound;
+    processMonotonic(a, true, true, var_bound, visitedmt, true);
   }
+  Trace("sort-inference-proc") << "...done" << std::endl;
 }
 
 void SortInference::setEqual( int t1, int t2 ){
@@ -328,184 +349,209 @@ int SortInference::getIdForType( TypeNode tn ){
   //register the return type
   std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
   if( it==d_id_for_types.end() ){
-    int sc = sortCount;
-    d_type_types[ sortCount ] = tn;
-    d_id_for_types[ tn ] = sortCount;
-    sortCount++;
+    int sc = d_sortCount;
+    d_type_types[d_sortCount] = tn;
+    d_id_for_types[tn] = d_sortCount;
+    d_sortCount++;
     return sc;
   }else{
     return it->second;
   }
 }
 
-int SortInference::process( Node n, std::map< Node, Node >& var_bound ){
-  //add to variable bindings
-  if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-    if( d_var_types.find( n )!=d_var_types.end() ){
-      return getIdForType( n.getType() );
-    }else{
-      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited ){
+  std::map< Node, int >::iterator itv = visited.find( n );
+  if( itv!=visited.end() ){
+    return itv->second;
+  }else{
+    //add to variable bindings
+    bool use_new_visited = false;
+    std::map< Node, int > new_visited;
+    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+      if( d_var_types.find( n )!=d_var_types.end() ){
+        return getIdForType( n.getType() );
+      }else{
         //apply sort inference to quantified variables
-        d_var_types[n][ n[0][i] ] = sortCount;
-        sortCount++;
-
-        //type of the quantified variable must be the same
-        var_bound[ n[0][i] ] = n;
+        for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+          TypeNode nitn = n[0][i].getType();
+          if( !nitn.isSort() )
+          {
+            // If the variable is of an interpreted sort, we assume the
+            // the sort of the variable will stay the same sort.
+            d_var_types[n][n[0][i]] = getIdForType( nitn );
+          }
+          else
+          {
+            // If it is of an uninterpreted sort, infer subsorts.
+            d_var_types[n][n[0][i]] = d_sortCount;
+            d_sortCount++;
+          }
+          var_bound[ n[0][i] ] = n;
+        }
       }
+      use_new_visited = true;
     }
-  }
 
-  //process children
-  std::vector< Node > children;
-  std::vector< int > child_types;
-  for( size_t i=0; i<n.getNumChildren(); i++ ){
-    bool processChild = true;
-    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-      processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
-    }
-    if( processChild ){
-      children.push_back( n[i] );
-      child_types.push_back( process( n[i], var_bound ) );
+    //process children
+    std::vector< Node > children;
+    std::vector< int > child_types;
+    for( size_t i=0; i<n.getNumChildren(); i++ ){
+      bool processChild = true;
+      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+        processChild =
+            options::userPatternsQuant() == options::UserPatMode::IGNORE
+                ? i == 1
+                : i >= 1;
+      }
+      if( processChild ){
+        children.push_back( n[i] );
+        child_types.push_back( process( n[i], var_bound, use_new_visited ? new_visited : visited ) );
+      }
     }
-  }
 
-  //remove from variable bindings
-  if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-    //erase from variable bound
-    for( size_t i=0; i<n[0].getNumChildren(); i++ ){
-      var_bound.erase( n[0][i] );
-    }
-  }
-  Trace("sort-inference-debug") << "...Process " << n << std::endl;
-
-  int retType;
-  if( n.getKind()==kind::EQUAL ){
-    Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
-    //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
-    if( n[0].getType()!=n[1].getType() ){
-      //for now, assume the original types
-      for( unsigned i=0; i<2; i++ ){
-        int ct = getIdForType( n[i].getType() );
-        setEqual( child_types[i], ct );
+    //remove from variable bindings
+    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+      //erase from variable bound
+      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+        var_bound.erase( n[0][i] );
       }
-    }else{
-      //we only require that the left and right hand side must be equal
-      setEqual( child_types[0], child_types[1] );
     }
-    //int eqType = getIdForType( n[0].getType() );
-    //setEqual( child_types[0], eqType );
-    //setEqual( child_types[1], eqType );
-    retType = getIdForType( n.getType() );
-  }else if( n.getKind()==kind::APPLY_UF ){
-    Node op = n.getOperator();
-    TypeNode tn_op = op.getType();
-    if( d_op_return_types.find( op )==d_op_return_types.end() ){
-      if( n.getType().isBoolean() ){
-        //use booleans
-        d_op_return_types[op] = getIdForType( n.getType() );
+    Trace("sort-inference-debug") << "...Process " << n << std::endl;
+
+    int retType;
+    if( n.getKind()==kind::EQUAL && !n[0].getType().isBoolean() ){
+      Trace("sort-inference-debug") << "For equality " << n << ", set equal types from : " << n[0].getType() << " " << n[1].getType() << std::endl;
+      //if original types are mixed (e.g. Int/Real), don't commit type equality in either direction
+      if( n[0].getType()!=n[1].getType() ){
+        //for now, assume the original types
+        for( unsigned i=0; i<2; i++ ){
+          int ct = getIdForType( n[i].getType() );
+          setEqual( child_types[i], ct );
+        }
       }else{
-        //assign arbitrary sort for return type
-        d_op_return_types[op] = sortCount;
-        sortCount++;
-      }
-      //d_type_eq_class[sortCount].push_back( op );
-      //assign arbitrary sort for argument types
-      for( size_t i=0; i<n.getNumChildren(); i++ ){
-        d_op_arg_types[op].push_back( sortCount );
-        sortCount++;
+        //we only require that the left and right hand side must be equal
+        setEqual( child_types[0], child_types[1] );
       }
+      d_equality_types[n] = child_types[0];
+      retType = getIdForType( n.getType() );
     }
-    for( size_t i=0; i<n.getNumChildren(); i++ ){
-      //the argument of the operator must match the return type of the subterm
-      if( n[i].getType()!=tn_op[i] ){
-        //if type mismatch, assume original types
-        Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
-        Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
-        int ct1 = getIdForType( n[i].getType() );
-        setEqual( child_types[i], ct1 );
-        int ct2 = getIdForType( tn_op[i] );
-        setEqual( d_op_arg_types[op][i], ct2 );
-      }else{
-        setEqual( child_types[i], d_op_arg_types[op][i] );
+    else if (isHandledApplyUf(n.getKind()))
+    {
+      Node op = n.getOperator();
+      TypeNode tn_op = op.getType();
+      if( d_op_return_types.find( op )==d_op_return_types.end() ){
+        if( n.getType().isBoolean() ){
+          //use booleans
+          d_op_return_types[op] = getIdForType( n.getType() );
+        }else{
+          //assign arbitrary sort for return type
+          d_op_return_types[op] = d_sortCount;
+          d_sortCount++;
+        }
+        // d_type_eq_class[d_sortCount].push_back( op );
+        // assign arbitrary sort for argument types
+        for( size_t i=0; i<n.getNumChildren(); i++ ){
+          d_op_arg_types[op].push_back(d_sortCount);
+          d_sortCount++;
+        }
       }
-    }
-    //return type is the return type
-    retType = d_op_return_types[op];
-  }else{
-    std::map< Node, Node >::iterator it = var_bound.find( n );
-    if( it!=var_bound.end() ){
-      Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
-      //the return type was specified while binding
-      retType = d_var_types[it->second][n];
-    }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
-      Trace("sort-inference-debug") << n << " is a variable." << std::endl;
-      if( d_op_return_types.find( n )==d_op_return_types.end() ){
-        //assign arbitrary sort
-        d_op_return_types[n] = sortCount;
-        sortCount++;
-        //d_type_eq_class[sortCount].push_back( n );
+      for( size_t i=0; i<n.getNumChildren(); i++ ){
+        //the argument of the operator must match the return type of the subterm
+        if( n[i].getType()!=tn_op[i] ){
+          //if type mismatch, assume original types
+          Trace("sort-inference-debug") << "Argument " << i << " of " << op << " " << n[i] << " has type " << n[i].getType();
+          Trace("sort-inference-debug") << ", while operator arg has type " << tn_op[i] << std::endl;
+          int ct1 = getIdForType( n[i].getType() );
+          setEqual( child_types[i], ct1 );
+          int ct2 = getIdForType( tn_op[i] );
+          setEqual( d_op_arg_types[op][i], ct2 );
+        }else{
+          setEqual( child_types[i], d_op_arg_types[op][i] );
+        }
       }
-      retType = d_op_return_types[n];
-    //}else if( n.isConst() ){
-    //  Trace("sort-inference-debug") << n << " is a constant." << std::endl;
-      //can be any type we want
-    //  retType = sortCount;
-    //  sortCount++;
+      //return type is the return type
+      retType = d_op_return_types[op];
     }else{
-      Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
-      //it is an interpretted term
-      for( size_t i=0; i<children.size(); i++ ){
-        Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
-        //must enforce the actual type of the operator on the children
-        int ct = getIdForType( children[i].getType() );
-        setEqual( child_types[i], ct );
+      std::map< Node, Node >::iterator it = var_bound.find( n );
+      if( it!=var_bound.end() ){
+        Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
+        //the return type was specified while binding
+        retType = d_var_types[it->second][n];
+      }else if( n.isVar() ){
+        Trace("sort-inference-debug") << n << " is a variable." << std::endl;
+        if( d_op_return_types.find( n )==d_op_return_types.end() ){
+          //assign arbitrary sort
+          d_op_return_types[n] = d_sortCount;
+          d_sortCount++;
+          // d_type_eq_class[d_sortCount].push_back( n );
+        }
+        retType = d_op_return_types[n];
+      }else if( n.isConst() ){
+        Trace("sort-inference-debug") << n << " is a constant." << std::endl;
+        //can be any type we want
+        retType = d_sortCount;
+        d_sortCount++;
+      }else{
+        Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
+        //it is an interpreted term
+        for( size_t i=0; i<children.size(); i++ ){
+          Trace("sort-inference-debug") << children[i] << " forced to have " << children[i].getType() << std::endl;
+          //must enforce the actual type of the operator on the children
+          int ct = getIdForType( children[i].getType() );
+          setEqual( child_types[i], ct );
+        }
+        //return type must be the actual return type
+        retType = getIdForType( n.getType() );
       }
-      //return type must be the actual return type
-      retType = getIdForType( n.getType() );
     }
+    Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
+    printSort("sort-inference-debug", retType );
+    Trace("sort-inference-debug") << std::endl;
+    visited[n] = retType;
+    return retType;
   }
-  Trace("sort-inference-debug") << "...Type( " << n << " ) = ";
-  printSort("sort-inference-debug", retType );
-  Trace("sort-inference-debug") << std::endl;
-  return retType;
 }
 
-void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, bool typeMode ) {
-  Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
-  if( n.getKind()==kind::FORALL ){
-    //only consider variables universally if it is possible this quantified formula is asserted positively
-    if( !hasPol || pol ){
-      for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
-        var_bound[n[0][i]] = n;
+void SortInference::processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode ) {
+  int pindex = hasPol ? ( pol ? 1 : -1 ) : 0;
+  if( visited[n].find( pindex )==visited[n].end() ){
+    visited[n][pindex] = true;
+    Trace("sort-inference-debug") << "...Process monotonic " << pol << " " << hasPol << " " << n << std::endl;
+    if( n.getKind()==kind::FORALL ){
+      //only consider variables universally if it is possible this quantified formula is asserted positively
+      if( !hasPol || pol ){
+        for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
+          var_bound[n[0][i]] = n;
+        }
       }
-    }
-    processMonotonic( n[1], pol, hasPol, var_bound, typeMode );
-    if( !hasPol || pol ){
-      for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
-        var_bound.erase( n[0][i] );
+      processMonotonic( n[1], pol, hasPol, var_bound, visited, typeMode );
+      if( !hasPol || pol ){
+        for( unsigned i=0; i<n[0].getNumChildren(); i++ ){
+          var_bound.erase( n[0][i] );
+        }
       }
-    }
-    return;
-  }else if( n.getKind()==kind::EQUAL ){
-    if( !hasPol || pol ){
-      for( unsigned i=0; i<2; i++ ){
-        if( var_bound.find( n[i] )!=var_bound.end() ){
-          if( !typeMode ){
-            int sid = getSortId( var_bound[n[i]], n[i] );
-            d_non_monotonic_sorts[sid] = true;
-          }else{
-            d_non_monotonic_sorts_orig[n[i].getType()] = true;
+      return;
+    }else if( n.getKind()==kind::EQUAL ){
+      if( !hasPol || pol ){
+        for( unsigned i=0; i<2; i++ ){
+          if( var_bound.find( n[i] )!=var_bound.end() ){
+            if( !typeMode ){
+              int sid = getSortId( var_bound[n[i]], n[i] );
+              d_non_monotonic_sorts[sid] = true;
+            }else{
+              d_non_monotonic_sorts_orig[n[i].getType()] = true;
+            }
+            break;
           }
-          break;
         }
       }
     }
-  }
-  for( unsigned i=0; i<n.getNumChildren(); i++ ){
-    bool npol;
-    bool nhasPol;
-    theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
-    processMonotonic( n[i], npol, nhasPol, var_bound, typeMode );
+    for( unsigned i=0; i<n.getNumChildren(); i++ ){
+      bool npol;
+      bool nhasPol;
+      theory::QuantPhaseReq::getPolarity( n, i, hasPol, pol, nhasPol, npol );
+      processMonotonic( n[i], npol, nhasPol, var_bound, visited, typeMode );
+    }
   }
 }
 
@@ -516,8 +562,13 @@ TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
     return d_type_types[rt];
   }else{
     TypeNode retType;
-    //see if we can assign pref
-    if( !pref.isNull() && d_id_for_types.find( pref )==d_id_for_types.end() ){
+    // See if we can assign pref. This is an optimization for reusing an
+    // uninterpreted sort as the first subsort, so that fewer symbols needed
+    // to be rewritten in the sort-inferred signature. Notice we only assign
+    // pref here if it is an uninterpreted sort.
+    if (!pref.isNull() && d_id_for_types.find(pref) == d_id_for_types.end()
+        && pref.isSort())
+    {
       retType = pref;
     }else{
       //must create new type
@@ -544,7 +595,8 @@ TypeNode SortInference::getTypeForId( int t ){
 }
 
 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
-  if( tn==old.getType() ){
+  // if no sort was inferred for this node, return original
+  if( tn.isNull() || tn.isComparableTo( old.getType() ) ){
     return old;
   }else if( old.isConst() ){
     //must make constant of type tn
@@ -565,128 +617,160 @@ Node SortInference::getNewSymbol( Node old, TypeNode tn ){
   }
 }
 
-Node SortInference::simplify( Node n, std::map< Node, Node >& var_bound ){
-  Trace("sort-inference-debug2") << "Simplify " << n << std::endl;
-  std::vector< Node > children;
-  if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-    //recreate based on types of variables
-    std::vector< Node > new_children;
-    for( size_t i=0; i<n[0].getNumChildren(); i++ ){
-      TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
-      Node v = getNewSymbol( n[0][i], tn );
-      Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
-      new_children.push_back( v );
-      var_bound[ n[0][i] ] = v;
+Node SortInference::simplifyNode(
+    Node n,
+    std::map<Node, Node>& var_bound,
+    TypeNode tnn,
+    std::map<Node, Node>& model_replace_f,
+    std::map<Node, std::map<TypeNode, Node> >& visited)
+{
+  std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
+  if( itv!=visited[n].end() ){
+    return itv->second;
+  }else{
+    Trace("sort-inference-debug2") << "Simplify " << n << ", type context=" << tnn << std::endl;
+    std::vector< Node > children;
+    std::map< Node, std::map< TypeNode, Node > > new_visited;
+    bool use_new_visited = false;
+    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+      //recreate based on types of variables
+      std::vector< Node > new_children;
+      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+        TypeNode tn = getOrCreateTypeForId( d_var_types[n][ n[0][i] ], n[0][i].getType() );
+        Node v = getNewSymbol( n[0][i], tn );
+        Trace("sort-inference-debug2") << "Map variable " << n[0][i] << " to " << v << std::endl;
+        new_children.push_back( v );
+        var_bound[ n[0][i] ] = v;
+      }
+      children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
+      use_new_visited = true;
     }
-    children.push_back( NodeManager::currentNM()->mkNode( n[0].getKind(), new_children ) );
-  }
 
-  //process children
-  if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
-    children.push_back( n.getOperator() );
-  }
-  bool childChanged = false;
-  for( size_t i=0; i<n.getNumChildren(); i++ ){
-    bool processChild = true;
-    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-      processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
+    //process children
+    if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){
+      children.push_back( n.getOperator() );
     }
-    if( processChild ){
-      Node nc = simplify( n[i], var_bound );
-      Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
-      children.push_back( nc );
-      childChanged = childChanged || nc!=n[i];
+    Node op;
+    if( n.hasOperator() ){
+      op = n.getOperator();
     }
-  }
-
-  //remove from variable bindings
-  if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-    //erase from variable bound
-    for( size_t i=0; i<n[0].getNumChildren(); i++ ){
-      Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
-      var_bound.erase( n[0][i] );
+    bool childChanged = false;
+    TypeNode tnnc;
+    for( size_t i=0; i<n.getNumChildren(); i++ ){
+      bool processChild = true;
+      if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+        processChild =
+            options::userPatternsQuant() == options::UserPatMode::IGNORE
+                ? i == 1
+                : i >= 1;
+      }
+      if( processChild ){
+        if (isHandledApplyUf(n.getKind()))
+        {
+          Assert(d_op_arg_types.find(op) != d_op_arg_types.end());
+          tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
+          Assert(!tnnc.isNull());
+        }
+        else if (n.getKind() == kind::EQUAL && !n[0].getType().isBoolean()
+                 && i == 0)
+        {
+          Assert(d_equality_types.find(n) != d_equality_types.end());
+          tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
+          Assert(!tnnc.isNull());
+        }
+        Node nc = simplifyNode(n[i],
+                               var_bound,
+                               tnnc,
+                               model_replace_f,
+                               use_new_visited ? new_visited : visited);
+        Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
+        children.push_back( nc );
+        childChanged = childChanged || nc!=n[i];
+      }
     }
-    return NodeManager::currentNM()->mkNode( n.getKind(), children );
-  }else if( n.getKind()==kind::EQUAL ){
-    TypeNode tn1 = children[0].getType();
-    TypeNode tn2 = children[1].getType();
-    if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
-      if( children[0].isConst() ){
-        children[0] = getNewSymbol( children[0], children[1].getType() );
-      }else if( children[1].isConst() ){
-        children[1] = getNewSymbol( children[1], children[0].getType() );
-      }else{
+
+    //remove from variable bindings
+    Node ret;
+    if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
+      //erase from variable bound
+      for( size_t i=0; i<n[0].getNumChildren(); i++ ){
+        Trace("sort-inference-debug2") << "Remove bound for " << n[0][i] << std::endl;
+        var_bound.erase( n[0][i] );
+      }
+      ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
+    }else if( n.getKind()==kind::EQUAL ){
+      TypeNode tn1 = children[0].getType();
+      TypeNode tn2 = children[1].getType();
+      if( !tn1.isComparableTo( tn2 ) ){
         Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
         Trace("sort-inference-warn") << "  Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
-        Assert( false );
+        Assert(false);
       }
-    }
-    return NodeManager::currentNM()->mkNode( kind::EQUAL, children );
-  }else if( n.getKind()==kind::APPLY_UF ){
-    Node op = n.getOperator();
-    if( d_symbol_map.find( op )==d_symbol_map.end() ){
-      //make the new operator if necessary
-      bool opChanged = false;
-      std::vector< TypeNode > argTypes;
-      for( size_t i=0; i<n.getNumChildren(); i++ ){
-        TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
-        argTypes.push_back( tn );
-        if( tn!=n[i].getType() ){
+      ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
+    }
+    else if (isHandledApplyUf(n.getKind()))
+    {
+      if( d_symbol_map.find( op )==d_symbol_map.end() ){
+        //make the new operator if necessary
+        bool opChanged = false;
+        std::vector< TypeNode > argTypes;
+        for( size_t i=0; i<n.getNumChildren(); i++ ){
+          TypeNode tn = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
+          argTypes.push_back( tn );
+          if( tn!=n[i].getType() ){
+            opChanged = true;
+          }
+        }
+        TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
+        if( retType!=n.getType() ){
           opChanged = true;
         }
-      }
-      TypeNode retType = getOrCreateTypeForId( d_op_return_types[op], n.getType() );
-      if( retType!=n.getType() ){
-        opChanged = true;
-      }
-      if( opChanged ){
-        std::stringstream ss;
-        ss << "io_" << op;
-        TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
-        d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
-        Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
-        d_model_replace_f[op] = d_symbol_map[op];
-      }else{
-        d_symbol_map[op] = op;
-      }
-    }
-    children[0] = d_symbol_map[op];
-    //make sure all children have been taken care of
-    for( size_t i=0; i<n.getNumChildren(); i++ ){
-      TypeNode tn = children[i+1].getType();
-      TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
-      if( tn!=tna ){
-        if( n[i].isConst() ){
-          children[i+1] = getNewSymbol( n[i], tna );
+        if( opChanged ){
+          std::stringstream ss;
+          ss << "io_" << op;
+          TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
+          d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
+          Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
+          model_replace_f[op] = d_symbol_map[op];
         }else{
-          Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
-          Assert( false );
+          d_symbol_map[op] = op;
         }
       }
-    }
-    return NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
-  }else{
-    std::map< Node, Node >::iterator it = var_bound.find( n );
-    if( it!=var_bound.end() ){
-      return it->second;
-    }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
-      if( d_symbol_map.find( n )==d_symbol_map.end() ){
-        TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
-        d_symbol_map[n] = getNewSymbol( n, tn );
+      children[0] = d_symbol_map[op];
+      // make sure all children have been given proper types
+      for (size_t i = 0, size = n.getNumChildren(); i < size; i++)
+      {
+        TypeNode tn = children[i+1].getType();
+        TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
+        if (!tn.isSubtypeOf(tna))
+        {
+          Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
+          Assert(false);
+        }
       }
-      return d_symbol_map[n];
-    }else if( n.isConst() ){
-      //just return n, we will fix at higher scope
-      return n;
+      ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
     }else{
-      if( childChanged ){
-        return NodeManager::currentNM()->mkNode( n.getKind(), children );
+      std::map< Node, Node >::iterator it = var_bound.find( n );
+      if( it!=var_bound.end() ){
+        ret = it->second;
+      }else if( n.getKind() == kind::VARIABLE || n.getKind() == kind::SKOLEM ){
+        if( d_symbol_map.find( n )==d_symbol_map.end() ){
+          TypeNode tn = getOrCreateTypeForId( d_op_return_types[n], n.getType() );
+          d_symbol_map[n] = getNewSymbol( n, tn );
+        }
+        ret = d_symbol_map[n];
+      }else if( n.isConst() ){
+        //type is determined by context
+        ret = getNewSymbol( n, tnn );
+      }else if( childChanged ){
+        ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
       }else{
-        return n;
+        ret = n;
       }
     }
+    visited[n][tnn] = ret;
+    return ret;
   }
-
 }
 
 Node SortInference::mkInjection( TypeNode tn1, TypeNode tn2 ) {
@@ -728,14 +812,16 @@ void SortInference::setSkolemVar( Node f, Node v, Node sk ){
   if( isWellSortedFormula( f ) && d_var_types.find( f )==d_var_types.end() ){
     //calculate the sort for variables if not done so already
     std::map< Node, Node > var_bound;
-    process( f, var_bound );
+    std::map< Node, int > visited;
+    process( f, var_bound, visited );
   }
   d_op_return_types[sk] = getSortId( f, v );
   Trace("sort-inference-temp") << "Set skolem sort id for " << sk << " to " << d_op_return_types[sk] << std::endl;
 }
 
 bool SortInference::isWellSortedFormula( Node n ) {
-  if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
+  if (n.getType().isBoolean() && !isHandledApplyUf(n.getKind()))
+  {
     for( unsigned i=0; i<n.getNumChildren(); i++ ){
       if( !isWellSortedFormula( n[i] ) ){
         return false;
@@ -751,7 +837,8 @@ bool SortInference::isWellSorted( Node n ) {
   if( getSortId( n )==0 ){
     return false;
   }else{
-    if( n.getKind()==kind::APPLY_UF ){
+    if (isHandledApplyUf(n.getKind()))
+    {
       for( unsigned i=0; i<n.getNumChildren(); i++ ){
         int s1 = getSortId( n[i] );
         int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
@@ -767,18 +854,14 @@ bool SortInference::isWellSorted( Node n ) {
   }
 }
 
-void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
-  if( n.getKind()==kind::APPLY_UF ){
-    for( unsigned i=0; i<n.getNumChildren(); i++ ){
-      getSortConstraints( n[i], uf );
-      uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
-    }
-  }
-}
-
 bool SortInference::isMonotonic( TypeNode tn ) {
-  Assert( tn.isSort() );
+  Assert(tn.isSort());
   return d_non_monotonic_sorts_orig.find( tn )==d_non_monotonic_sorts_orig.end();
 }
 
+bool SortInference::isHandledApplyUf(Kind k) const
+{
+  return k == APPLY_UF && !options::ufHo();
+}
+
 }/* CVC4 namespace */