Introduce quantifiers inference manager (#5821)
[cvc5.git] / src / theory / sort_inference.cpp
index 73a50bc7acc521f5f7553c6d43a5964cad0efd06..9fa216e67214c81f56a2e202bc1a4b1dda8e5607 100644 (file)
@@ -2,10 +2,10 @@
 /*! \file sort_inference.cpp
  ** \verbatim
  ** Top contributors (to current version):
- **   Andrew Reynolds, Paul Meng, Morgan Deters
+ **   Andrew Reynolds, Paul Meng, Mathias Preiner
  ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
  ** All rights reserved.  See the file COPYING in the top-level source
  ** directory for licensing information.\endverbatim
  **
 #include "options/quantifiers_options.h"
 #include "options/smt_options.h"
 #include "options/uf_options.h"
-#include "proof/proof_manager.h"
 #include "theory/rewriter.h"
 #include "theory/quantifiers/quant_util.h"
 
 using namespace CVC4;
+using namespace CVC4::kind;
 using namespace std;
 
 namespace CVC4 {
@@ -102,7 +102,7 @@ void SortInference::reset() {
   d_non_monotonic_sorts.clear();
   d_type_sub_sorts.clear();
   //reset info
-  sortCount = 1;
+  d_sortCount = 1;
   d_type_union_find.clear();
   d_type_types.clear();
   d_id_for_types.clear();
@@ -114,159 +114,195 @@ void SortInference::reset() {
   d_const_map.clear();
 }
 
-void SortInference::simplify( std::vector< Node >& assertions, bool doSortInference, bool doMonotonicyInference ){
-  if( doSortInference ){
-    Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
-    //process all assertions
-    std::map< Node, int > visited;
-    for( unsigned i=0; i<assertions.size(); i++ ){
-      Trace("sort-inference-debug") << "Process " << assertions[i] << std::endl;
-      std::map< Node, Node > var_bound;
-      process( assertions[i], var_bound, visited );
-    }
-    Trace("sort-inference-proc") << "...done" << std::endl;
-    for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
-      Trace("sort-inference") << it->first << " : ";
-      TypeNode retTn = it->first.getType();
-      if( !d_op_arg_types[ it->first ].empty() ){
-        Trace("sort-inference") << "( ";
-        for( size_t i=0; i<d_op_arg_types[ it->first ].size(); i++ ){
-          recordSubsort( retTn[i], d_op_arg_types[ it->first ][i] );
-          printSort( "sort-inference", d_op_arg_types[ it->first ][i] );
-          Trace("sort-inference") << " ";
-        }
-        Trace("sort-inference") << ") -> ";
-        retTn = retTn[(int)retTn.getNumChildren()-1];
+void SortInference::initialize(const std::vector<Node>& assertions)
+{
+  Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
+  // process all assertions
+  std::map<Node, int> visited;
+  NodeManager * nm = NodeManager::currentNM();
+  int btId = getIdForType( nm->booleanType() );
+  for (const Node& a : assertions)
+  {
+    Trace("sort-inference-debug") << "Process " << a << std::endl;
+    std::map<Node, Node> var_bound;
+    int pid = process(a, var_bound, visited);
+    // the type of the topmost term must be Boolean
+    setEqual( pid, btId );
+  }
+  Trace("sort-inference-proc") << "...done" << std::endl;
+  for (const std::pair<const Node, int>& rt : d_op_return_types)
+  {
+    Trace("sort-inference") << rt.first << " : ";
+    TypeNode retTn = rt.first.getType();
+    if (!d_op_arg_types[rt.first].empty())
+    {
+      Trace("sort-inference") << "( ";
+      for (size_t i = 0; i < d_op_arg_types[rt.first].size(); i++)
+      {
+        recordSubsort(retTn[i], d_op_arg_types[rt.first][i]);
+        printSort("sort-inference", d_op_arg_types[rt.first][i]);
+        Trace("sort-inference") << " ";
       }
-      recordSubsort( retTn, it->second );
-      printSort( "sort-inference", it->second );
-      Trace("sort-inference") << std::endl;
+      Trace("sort-inference") << ") -> ";
+      retTn = retTn[(int)retTn.getNumChildren() - 1];
     }
-    for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){
-      Trace("sort-inference") << "Quantified formula : " << it->first << " : " << std::endl;
-      for( unsigned i=0; i<it->first[0].getNumChildren(); i++ ){
-        recordSubsort( it->first[0][i].getType(), it->second[it->first[0][i]] );
-        printSort( "sort-inference", it->second[it->first[0][i]] );
-        Trace("sort-inference") << std::endl;
-      }
+    recordSubsort(retTn, rt.second);
+    printSort("sort-inference", rt.second);
+    Trace("sort-inference") << std::endl;
+  }
+  for (std::pair<const Node, std::map<Node, int> >& vt : d_var_types)
+  {
+    Trace("sort-inference")
+        << "Quantified formula : " << vt.first << " : " << std::endl;
+    for (const Node& v : vt.first[0])
+    {
+      recordSubsort(v.getType(), vt.second[v]);
+      printSort("sort-inference", vt.second[v]);
       Trace("sort-inference") << std::endl;
     }
+    Trace("sort-inference") << std::endl;
+  }
 
-    if( !options::ufssSymBreak() ){
-      bool rewritten = false;
-      //determine monotonicity of sorts
-      Trace("sort-inference-proc") << "Calculating monotonicty for subsorts..." << std::endl;
-      std::map< Node, std::map< int, bool > > visited;
-      for( unsigned i=0; i<assertions.size(); i++ ){
-        Trace("sort-inference-debug") << "Process monotonicity for " << assertions[i] << std::endl;
-        std::map< Node, Node > var_bound;
-        processMonotonic( assertions[i], true, true, var_bound, visited );
-      }
-      Trace("sort-inference-proc") << "...done" << std::endl;
-
-      Trace("sort-inference") << "We have " << d_sub_sorts.size() << " sub-sorts : " << std::endl;
-      for( unsigned i=0; i<d_sub_sorts.size(); i++ ){
-        printSort( "sort-inference", d_sub_sorts[i] );
-        if( d_type_types.find( d_sub_sorts[i] )!=d_type_types.end() ){
-          Trace("sort-inference") << " is interpreted." << std::endl;
-        }else if( d_non_monotonic_sorts.find( d_sub_sorts[i] )==d_non_monotonic_sorts.end() ){
-          Trace("sort-inference") << " is monotonic." << std::endl;
-        }else{
-          Trace("sort-inference") << " is not monotonic." << std::endl;
-        }
-      }
+  // determine monotonicity of sorts
+  Trace("sort-inference-proc")
+      << "Calculating monotonicty for subsorts..." << std::endl;
+  std::map<Node, std::map<int, bool> > visitedm;
+  for (const Node& a : assertions)
+  {
+    Trace("sort-inference-debug")
+        << "Process monotonicity for " << a << std::endl;
+    std::map<Node, Node> var_bound;
+    processMonotonic(a, true, true, var_bound, visitedm);
+  }
+  Trace("sort-inference-proc") << "...done" << std::endl;
+
+  Trace("sort-inference") << "We have " << d_sub_sorts.size()
+                          << " sub-sorts : " << std::endl;
+  for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++)
+  {
+    printSort("sort-inference", d_sub_sorts[i]);
+    if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end())
+    {
+      Trace("sort-inference") << " is interpreted." << std::endl;
+    }
+    else if (d_non_monotonic_sorts.find(d_sub_sorts[i])
+             == d_non_monotonic_sorts.end())
+    {
+      Trace("sort-inference") << " is monotonic." << std::endl;
+    }
+    else
+    {
+      Trace("sort-inference") << " is not monotonic." << std::endl;
+    }
+  }
+}
 
-      //simplify all assertions by introducing new symbols wherever necessary
-      Trace("sort-inference-proc") << "Perform simplification..." << std::endl;
-      std::map< Node, std::map< TypeNode, Node > > visited2;
-      for( unsigned i=0; i<assertions.size(); i++ ){
-        Node prev = assertions[i];
-        std::map< Node, Node > var_bound;
-        Trace("sort-inference-debug") << "Simplify " << assertions[i] << std::endl;
-        TypeNode tnn;
-        Node curr = simplifyNode( assertions[i], var_bound, tnn, visited2 );
-        Trace("sort-inference-debug") << "Done." << std::endl;
-        if( curr!=assertions[i] ){
-          Trace("sort-inference-debug") << "Rewrite " << curr << std::endl;
-          curr = theory::Rewriter::rewrite( curr );
-          rewritten = true;
-          Trace("sort-inference-rewrite") << assertions << std::endl;
-          Trace("sort-inference-rewrite") << " --> " << curr << std::endl;
-          PROOF( ProofManager::currentPM()->addDependence(curr, assertions[i]); );
-          assertions[i] = curr;
-        }
-      }
-      Trace("sort-inference-proc") << "...done" << std::endl;
-      //now, ensure constants are distinct
-      for( std::map< TypeNode, std::map< Node, Node > >::iterator it = d_const_map.begin(); it != d_const_map.end(); ++it ){
-        std::vector< Node > consts;
-        for( std::map< Node, Node >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
-          Assert( it2->first.isConst() );
-          consts.push_back( it2->second );
-        }
-        //add lemma enforcing introduced constants to be distinct
-        if( consts.size()>1 ){
-          Node distinct_const = NodeManager::currentNM()->mkNode( kind::DISTINCT, consts );
-          Trace("sort-inference-rewrite") << "Add the constant distinctness lemma: " << std::endl;
-          Trace("sort-inference-rewrite") << "  " << distinct_const << std::endl;
-          assertions.push_back( distinct_const );
-          rewritten = true;
-        }
-      }
+Node SortInference::simplify(Node n,
+                             std::map<Node, Node>& model_replace_f,
+                             std::map<Node, std::map<TypeNode, Node> >& visited)
+{
+  Trace("sort-inference-debug") << "Simplify " << n << std::endl;
+  std::map<Node, Node> var_bound;
+  TypeNode tnn;
+  Node ret = simplifyNode(n, var_bound, tnn, model_replace_f, visited);
+  ret = theory::Rewriter::rewrite(ret);
+  return ret;
+}
 
-      //enforce constraints based on monotonicity
-      Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
-      for( std::map< TypeNode, std::vector< int > >::iterator it = d_type_sub_sorts.begin(); it != d_type_sub_sorts.end(); ++it ){
-        int nmonSort = -1;
-        for( unsigned i=0; i<it->second.size(); i++ ){
-          if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
-            nmonSort = it->second[i];
-            break;
+void SortInference::getNewAssertions(std::vector<Node>& new_asserts)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  // now, ensure constants are distinct
+  for (const std::pair<const TypeNode, std::map<Node, Node> >& cm : d_const_map)
+  {
+    std::vector<Node> consts;
+    for (const std::pair<const Node, Node>& c : cm.second)
+    {
+      Assert(c.first.isConst());
+      consts.push_back(c.second);
+    }
+    // add lemma enforcing introduced constants to be distinct
+    if (consts.size() > 1)
+    {
+      Node distinct_const = nm->mkNode(kind::DISTINCT, consts);
+      Trace("sort-inference-rewrite")
+          << "Add the constant distinctness lemma: " << std::endl;
+      Trace("sort-inference-rewrite") << "  " << distinct_const << std::endl;
+      new_asserts.push_back(distinct_const);
+    }
+  }
+
+  // enforce constraints based on monotonicity
+  Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
+
+  for (const std::pair<const TypeNode, std::vector<int> >& tss :
+       d_type_sub_sorts)
+  {
+    int nmonSort = -1;
+    unsigned nsorts = tss.second.size();
+    for (unsigned i = 0; i < nsorts; i++)
+    {
+      if (d_non_monotonic_sorts.find(tss.second[i])
+          != d_non_monotonic_sorts.end())
+      {
+        nmonSort = tss.second[i];
+        break;
+      }
+    }
+    if (nmonSort != -1)
+    {
+      std::vector<Node> injections;
+      TypeNode base_tn = getOrCreateTypeForId(nmonSort, tss.first);
+      for (unsigned i = 0; i < nsorts; i++)
+      {
+        if (tss.second[i] != nmonSort)
+        {
+          TypeNode new_tn = getOrCreateTypeForId(tss.second[i], tss.first);
+          // make injection to nmonSort
+          Node a1 = mkInjection(new_tn, base_tn);
+          injections.push_back(a1);
+          if (d_non_monotonic_sorts.find(tss.second[i])
+              != d_non_monotonic_sorts.end())
+          {
+            // also must make injection from nmonSort to this
+            Node a2 = mkInjection(base_tn, new_tn);
+            injections.push_back(a2);
           }
         }
-        if( nmonSort!=-1 ){
-          std::vector< Node > injections;
-          TypeNode base_tn = getOrCreateTypeForId( nmonSort, it->first );
-          for( unsigned i=0; i<it->second.size(); i++ ){
-            if( it->second[i]!=nmonSort ){
-              TypeNode new_tn = getOrCreateTypeForId( it->second[i], it->first );
-              //make injection to nmonSort
-              Node a1 = mkInjection( new_tn, base_tn );
-              injections.push_back( a1 );
-              if( d_non_monotonic_sorts.find( it->second[i] )!=d_non_monotonic_sorts.end() ){
-                //also must make injection from nmonSort to this
-                Node a2 = mkInjection( base_tn, new_tn );
-                injections.push_back( a2 );
-              }
-            }
-          }
-          Trace("sort-inference-rewrite") << "Add the following injections for " << it->first << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
-          for( unsigned j=0; j<injections.size(); j++ ){
-            Trace("sort-inference-rewrite") << "   " << injections[j] << std::endl;
-          }
-          assertions.insert( assertions.end(), injections.begin(), injections.end() );
-          if( !injections.empty() ){
-            rewritten = true;
-          }
+      }
+      if (Trace.isOn("sort-inference-rewrite"))
+      {
+        Trace("sort-inference-rewrite")
+            << "Add the following injections for " << tss.first
+            << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
+        for (const Node& i : injections)
+        {
+          Trace("sort-inference-rewrite") << "   " << i << std::endl;
         }
       }
-      Trace("sort-inference-proc") << "...done" << std::endl;
-      //no sub-sort information is stored
-      reset();
-      Trace("sort-inference-debug") << "Finished sort inference, rewritten = " << rewritten << std::endl;
+      new_asserts.insert(
+          new_asserts.end(), injections.begin(), injections.end());
     }
-    initialSortCount = sortCount;
   }
-  if( doMonotonicyInference ){
-    std::map< Node, std::map< int, bool > > visited;
-    Trace("sort-inference-proc") << "Calculating monotonicty for types..." << std::endl;
-    for( unsigned i=0; i<assertions.size(); i++ ){
-      Trace("sort-inference-debug") << "Process type monotonicity for " << assertions[i] << std::endl;
-      std::map< Node, Node > var_bound;
-      processMonotonic( assertions[i], true, true, var_bound, visited, true );
-    }
-    Trace("sort-inference-proc") << "...done" << std::endl;
+  Trace("sort-inference-proc") << "...done" << std::endl;
+  // no sub-sort information is stored
+  reset();
+  Trace("sort-inference-debug") << "Finished sort inference" << std::endl;
+}
+
+void SortInference::computeMonotonicity(const std::vector<Node>& assertions)
+{
+  std::map<Node, std::map<int, bool> > visitedmt;
+  Trace("sort-inference-proc")
+      << "Calculating monotonicty for types..." << std::endl;
+  for (const Node& a : assertions)
+  {
+    Trace("sort-inference-debug")
+        << "Process type monotonicity for " << a << std::endl;
+    std::map<Node, Node> var_bound;
+    processMonotonic(a, true, true, var_bound, visitedmt, true);
   }
+  Trace("sort-inference-proc") << "...done" << std::endl;
 }
 
 void SortInference::setEqual( int t1, int t2 ){
@@ -313,10 +349,10 @@ int SortInference::getIdForType( TypeNode tn ){
   //register the return type
   std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
   if( it==d_id_for_types.end() ){
-    int sc = sortCount;
-    d_type_types[ sortCount ] = tn;
-    d_id_for_types[ tn ] = sortCount;
-    sortCount++;
+    int sc = d_sortCount;
+    d_type_types[d_sortCount] = tn;
+    d_id_for_types[tn] = d_sortCount;
+    d_sortCount++;
     return sc;
   }else{
     return it->second;
@@ -335,12 +371,21 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
       if( d_var_types.find( n )!=d_var_types.end() ){
         return getIdForType( n.getType() );
       }else{
+        //apply sort inference to quantified variables
         for( size_t i=0; i<n[0].getNumChildren(); i++ ){
-          //apply sort inference to quantified variables
-          d_var_types[n][ n[0][i] ] = sortCount;
-          sortCount++;
-
-          //type of the quantified variable must be the same
+          TypeNode nitn = n[0][i].getType();
+          if( !nitn.isSort() )
+          {
+            // If the variable is of an interpreted sort, we assume the
+            // the sort of the variable will stay the same sort.
+            d_var_types[n][n[0][i]] = getIdForType( nitn );
+          }
+          else
+          {
+            // If it is of an uninterpreted sort, infer subsorts.
+            d_var_types[n][n[0][i]] = d_sortCount;
+            d_sortCount++;
+          }
           var_bound[ n[0][i] ] = n;
         }
       }
@@ -353,7 +398,10 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
     for( size_t i=0; i<n.getNumChildren(); i++ ){
       bool processChild = true;
       if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
+        processChild =
+            options::userPatternsQuant() == options::UserPatMode::IGNORE
+                ? i == 1
+                : i >= 1;
       }
       if( processChild ){
         children.push_back( n[i] );
@@ -386,7 +434,9 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
       }
       d_equality_types[n] = child_types[0];
       retType = getIdForType( n.getType() );
-    }else if( n.getKind()==kind::APPLY_UF ){
+    }
+    else if (isHandledApplyUf(n.getKind()))
+    {
       Node op = n.getOperator();
       TypeNode tn_op = op.getType();
       if( d_op_return_types.find( op )==d_op_return_types.end() ){
@@ -395,14 +445,14 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
           d_op_return_types[op] = getIdForType( n.getType() );
         }else{
           //assign arbitrary sort for return type
-          d_op_return_types[op] = sortCount;
-          sortCount++;
+          d_op_return_types[op] = d_sortCount;
+          d_sortCount++;
         }
-        //d_type_eq_class[sortCount].push_back( op );
-        //assign arbitrary sort for argument types
+        // d_type_eq_class[d_sortCount].push_back( op );
+        // assign arbitrary sort for argument types
         for( size_t i=0; i<n.getNumChildren(); i++ ){
-          d_op_arg_types[op].push_back( sortCount );
-          sortCount++;
+          d_op_arg_types[op].push_back(d_sortCount);
+          d_sortCount++;
         }
       }
       for( size_t i=0; i<n.getNumChildren(); i++ ){
@@ -427,20 +477,20 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
         Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
         //the return type was specified while binding
         retType = d_var_types[it->second][n];
-      }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
+      }else if( n.isVar() ){
         Trace("sort-inference-debug") << n << " is a variable." << std::endl;
         if( d_op_return_types.find( n )==d_op_return_types.end() ){
           //assign arbitrary sort
-          d_op_return_types[n] = sortCount;
-          sortCount++;
-          //d_type_eq_class[sortCount].push_back( n );
+          d_op_return_types[n] = d_sortCount;
+          d_sortCount++;
+          // d_type_eq_class[d_sortCount].push_back( n );
         }
         retType = d_op_return_types[n];
       }else if( n.isConst() ){
         Trace("sort-inference-debug") << n << " is a constant." << std::endl;
         //can be any type we want
-        retType = sortCount;
-        sortCount++;
+        retType = d_sortCount;
+        d_sortCount++;
       }else{
         Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
         //it is an interpreted term
@@ -512,8 +562,13 @@ TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
     return d_type_types[rt];
   }else{
     TypeNode retType;
-    //see if we can assign pref
-    if( !pref.isNull() && d_id_for_types.find( pref )==d_id_for_types.end() ){
+    // See if we can assign pref. This is an optimization for reusing an
+    // uninterpreted sort as the first subsort, so that fewer symbols needed
+    // to be rewritten in the sort-inferred signature. Notice we only assign
+    // pref here if it is an uninterpreted sort.
+    if (!pref.isNull() && d_id_for_types.find(pref) == d_id_for_types.end()
+        && pref.isSort())
+    {
       retType = pref;
     }else{
       //must create new type
@@ -540,7 +595,8 @@ TypeNode SortInference::getTypeForId( int t ){
 }
 
 Node SortInference::getNewSymbol( Node old, TypeNode tn ){
-  if( tn.isNull() || tn==old.getType() ){
+  // if no sort was inferred for this node, return original
+  if( tn.isNull() || tn.isComparableTo( old.getType() ) ){
     return old;
   }else if( old.isConst() ){
     //must make constant of type tn
@@ -561,7 +617,13 @@ Node SortInference::getNewSymbol( Node old, TypeNode tn ){
   }
 }
 
-Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited ){
+Node SortInference::simplifyNode(
+    Node n,
+    std::map<Node, Node>& var_bound,
+    TypeNode tnn,
+    std::map<Node, Node>& model_replace_f,
+    std::map<Node, std::map<TypeNode, Node> >& visited)
+{
   std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
   if( itv!=visited[n].end() ){
     return itv->second;
@@ -597,19 +659,30 @@ Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, Typ
     for( size_t i=0; i<n.getNumChildren(); i++ ){
       bool processChild = true;
       if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
+        processChild =
+            options::userPatternsQuant() == options::UserPatMode::IGNORE
+                ? i == 1
+                : i >= 1;
       }
       if( processChild ){
-        if( n.getKind()==kind::APPLY_UF ){
-          Assert( d_op_arg_types.find( op )!=d_op_arg_types.end() );
+        if (isHandledApplyUf(n.getKind()))
+        {
+          Assert(d_op_arg_types.find(op) != d_op_arg_types.end());
           tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
-          Assert( !tnnc.isNull() );
-        }else if( n.getKind()==kind::EQUAL && i==0 ){
-          Assert( d_equality_types.find( n )!=d_equality_types.end() );
+          Assert(!tnnc.isNull());
+        }
+        else if (n.getKind() == kind::EQUAL && !n[0].getType().isBoolean()
+                 && i == 0)
+        {
+          Assert(d_equality_types.find(n) != d_equality_types.end());
           tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
-          Assert( !tnnc.isNull() );
+          Assert(!tnnc.isNull());
         }
-        Node nc = simplifyNode( n[i], var_bound, tnnc, use_new_visited ? new_visited : visited );
+        Node nc = simplifyNode(n[i],
+                               var_bound,
+                               tnnc,
+                               model_replace_f,
+                               use_new_visited ? new_visited : visited);
         Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
         children.push_back( nc );
         childChanged = childChanged || nc!=n[i];
@@ -628,13 +701,15 @@ Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, Typ
     }else if( n.getKind()==kind::EQUAL ){
       TypeNode tn1 = children[0].getType();
       TypeNode tn2 = children[1].getType();
-      if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
+      if( !tn1.isComparableTo( tn2 ) ){
         Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
         Trace("sort-inference-warn") << "  Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
-        Assert( false );
+        Assert(false);
       }
       ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
-    }else if( n.getKind()==kind::APPLY_UF ){
+    }
+    else if (isHandledApplyUf(n.getKind()))
+    {
       if( d_symbol_map.find( op )==d_symbol_map.end() ){
         //make the new operator if necessary
         bool opChanged = false;
@@ -656,19 +731,21 @@ Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, Typ
           TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
           d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
           Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
-          d_model_replace_f[op] = d_symbol_map[op];
+          model_replace_f[op] = d_symbol_map[op];
         }else{
           d_symbol_map[op] = op;
         }
       }
       children[0] = d_symbol_map[op];
-      //make sure all children have been taken care of
-      for( size_t i=0; i<n.getNumChildren(); i++ ){
+      // make sure all children have been given proper types
+      for (size_t i = 0, size = n.getNumChildren(); i < size; i++)
+      {
         TypeNode tn = children[i+1].getType();
         TypeNode tna = getTypeForId( d_op_arg_types[op][i] );
-        if( tn!=tna ){
+        if (!tn.isSubtypeOf(tna))
+        {
           Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
-          Assert( false );
+          Assert(false);
         }
       }
       ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
@@ -743,7 +820,8 @@ void SortInference::setSkolemVar( Node f, Node v, Node sk ){
 }
 
 bool SortInference::isWellSortedFormula( Node n ) {
-  if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
+  if (n.getType().isBoolean() && !isHandledApplyUf(n.getKind()))
+  {
     for( unsigned i=0; i<n.getNumChildren(); i++ ){
       if( !isWellSortedFormula( n[i] ) ){
         return false;
@@ -759,7 +837,8 @@ bool SortInference::isWellSorted( Node n ) {
   if( getSortId( n )==0 ){
     return false;
   }else{
-    if( n.getKind()==kind::APPLY_UF ){
+    if (isHandledApplyUf(n.getKind()))
+    {
       for( unsigned i=0; i<n.getNumChildren(); i++ ){
         int s1 = getSortId( n[i] );
         int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
@@ -775,18 +854,14 @@ bool SortInference::isWellSorted( Node n ) {
   }
 }
 
-void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
-  if( n.getKind()==kind::APPLY_UF ){
-    for( unsigned i=0; i<n.getNumChildren(); i++ ){
-      getSortConstraints( n[i], uf );
-      uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
-    }
-  }
-}
-
 bool SortInference::isMonotonic( TypeNode tn ) {
-  Assert( tn.isSort() );
+  Assert(tn.isSort());
   return d_non_monotonic_sorts_orig.find( tn )==d_non_monotonic_sorts_orig.end();
 }
 
+bool SortInference::isHandledApplyUf(Kind k) const
+{
+  return k == APPLY_UF && !options::ufHo();
+}
+
 }/* CVC4 namespace */