Introduce quantifiers inference manager (#5821)
[cvc5.git] / src / theory / sort_inference.cpp
index c4c0a8b473fc1fd04d488620cfd1677d3ff96ec5..9fa216e67214c81f56a2e202bc1a4b1dda8e5607 100644 (file)
@@ -2,10 +2,10 @@
 /*! \file sort_inference.cpp
  ** \verbatim
  ** Top contributors (to current version):
- **   Andrew Reynolds, Paul Meng, Tim King
+ **   Andrew Reynolds, Paul Meng, Mathias Preiner
  ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
- ** in the top-level source directory) and their institutional affiliations.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory and their institutional affiliations.
  ** All rights reserved.  See the file COPYING in the top-level source
  ** directory for licensing information.\endverbatim
  **
 #include "options/quantifiers_options.h"
 #include "options/smt_options.h"
 #include "options/uf_options.h"
-#include "proof/proof_manager.h"
 #include "theory/rewriter.h"
 #include "theory/quantifiers/quant_util.h"
 
 using namespace CVC4;
+using namespace CVC4::kind;
 using namespace std;
 
 namespace CVC4 {
@@ -119,11 +119,15 @@ void SortInference::initialize(const std::vector<Node>& assertions)
   Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
   // process all assertions
   std::map<Node, int> visited;
+  NodeManager * nm = NodeManager::currentNM();
+  int btId = getIdForType( nm->booleanType() );
   for (const Node& a : assertions)
   {
     Trace("sort-inference-debug") << "Process " << a << std::endl;
     std::map<Node, Node> var_bound;
-    process(a, var_bound, visited);
+    int pid = process(a, var_bound, visited);
+    // the type of the topmost term must be Boolean
+    setEqual( pid, btId );
   }
   Trace("sort-inference-proc") << "...done" << std::endl;
   for (const std::pair<const Node, int>& rt : d_op_return_types)
@@ -367,12 +371,21 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
       if( d_var_types.find( n )!=d_var_types.end() ){
         return getIdForType( n.getType() );
       }else{
+        //apply sort inference to quantified variables
         for( size_t i=0; i<n[0].getNumChildren(); i++ ){
-          //apply sort inference to quantified variables
-          d_var_types[n][n[0][i]] = d_sortCount;
-          d_sortCount++;
-
-          //type of the quantified variable must be the same
+          TypeNode nitn = n[0][i].getType();
+          if( !nitn.isSort() )
+          {
+            // If the variable is of an interpreted sort, we assume the
+            // the sort of the variable will stay the same sort.
+            d_var_types[n][n[0][i]] = getIdForType( nitn );
+          }
+          else
+          {
+            // If it is of an uninterpreted sort, infer subsorts.
+            d_var_types[n][n[0][i]] = d_sortCount;
+            d_sortCount++;
+          }
           var_bound[ n[0][i] ] = n;
         }
       }
@@ -385,7 +398,10 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
     for( size_t i=0; i<n.getNumChildren(); i++ ){
       bool processChild = true;
       if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
+        processChild =
+            options::userPatternsQuant() == options::UserPatMode::IGNORE
+                ? i == 1
+                : i >= 1;
       }
       if( processChild ){
         children.push_back( n[i] );
@@ -418,7 +434,9 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
       }
       d_equality_types[n] = child_types[0];
       retType = getIdForType( n.getType() );
-    }else if( n.getKind()==kind::APPLY_UF ){
+    }
+    else if (isHandledApplyUf(n.getKind()))
+    {
       Node op = n.getOperator();
       TypeNode tn_op = op.getType();
       if( d_op_return_types.find( op )==d_op_return_types.end() ){
@@ -459,7 +477,7 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
         Trace("sort-inference-debug") << n << " is a bound variable." << std::endl;
         //the return type was specified while binding
         retType = d_var_types[it->second][n];
-      }else if( n.getKind() == kind::VARIABLE || n.getKind()==kind::SKOLEM ){
+      }else if( n.isVar() ){
         Trace("sort-inference-debug") << n << " is a variable." << std::endl;
         if( d_op_return_types.find( n )==d_op_return_types.end() ){
           //assign arbitrary sort
@@ -641,17 +659,24 @@ Node SortInference::simplifyNode(
     for( size_t i=0; i<n.getNumChildren(); i++ ){
       bool processChild = true;
       if( n.getKind()==kind::FORALL || n.getKind()==kind::EXISTS ){
-        processChild = options::userPatternsQuant()==theory::quantifiers::USER_PAT_MODE_IGNORE ? i==1 : i>=1;
+        processChild =
+            options::userPatternsQuant() == options::UserPatMode::IGNORE
+                ? i == 1
+                : i >= 1;
       }
       if( processChild ){
-        if( n.getKind()==kind::APPLY_UF ){
-          Assert( d_op_arg_types.find( op )!=d_op_arg_types.end() );
+        if (isHandledApplyUf(n.getKind()))
+        {
+          Assert(d_op_arg_types.find(op) != d_op_arg_types.end());
           tnnc = getOrCreateTypeForId( d_op_arg_types[op][i], n[i].getType() );
-          Assert( !tnnc.isNull() );
-        }else if( n.getKind()==kind::EQUAL && i==0 ){
-          Assert( d_equality_types.find( n )!=d_equality_types.end() );
+          Assert(!tnnc.isNull());
+        }
+        else if (n.getKind() == kind::EQUAL && !n[0].getType().isBoolean()
+                 && i == 0)
+        {
+          Assert(d_equality_types.find(n) != d_equality_types.end());
           tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
-          Assert( !tnnc.isNull() );
+          Assert(!tnnc.isNull());
         }
         Node nc = simplifyNode(n[i],
                                var_bound,
@@ -676,13 +701,15 @@ Node SortInference::simplifyNode(
     }else if( n.getKind()==kind::EQUAL ){
       TypeNode tn1 = children[0].getType();
       TypeNode tn2 = children[1].getType();
-      if( !tn1.isSubtypeOf( tn2 ) && !tn2.isSubtypeOf( tn1 ) ){
+      if( !tn1.isComparableTo( tn2 ) ){
         Trace("sort-inference-warn") << "Sort inference created bad equality: " << children[0] << " = " << children[1] << std::endl;
         Trace("sort-inference-warn") << "  Types : " << children[0].getType() << " " << children[1].getType() << std::endl;
-        Assert( false );
+        Assert(false);
       }
       ret = NodeManager::currentNM()->mkNode( kind::EQUAL, children );
-    }else if( n.getKind()==kind::APPLY_UF ){
+    }
+    else if (isHandledApplyUf(n.getKind()))
+    {
       if( d_symbol_map.find( op )==d_symbol_map.end() ){
         //make the new operator if necessary
         bool opChanged = false;
@@ -718,7 +745,7 @@ Node SortInference::simplifyNode(
         if (!tn.isSubtypeOf(tna))
         {
           Trace("sort-inference-warn") << "Sort inference created bad child: " << n << " " << n[i] << " " << tn << " " << tna << std::endl;
-          Assert( false );
+          Assert(false);
         }
       }
       ret = NodeManager::currentNM()->mkNode( kind::APPLY_UF, children );
@@ -793,7 +820,8 @@ void SortInference::setSkolemVar( Node f, Node v, Node sk ){
 }
 
 bool SortInference::isWellSortedFormula( Node n ) {
-  if( n.getType().isBoolean() && n.getKind()!=kind::APPLY_UF ){
+  if (n.getType().isBoolean() && !isHandledApplyUf(n.getKind()))
+  {
     for( unsigned i=0; i<n.getNumChildren(); i++ ){
       if( !isWellSortedFormula( n[i] ) ){
         return false;
@@ -809,7 +837,8 @@ bool SortInference::isWellSorted( Node n ) {
   if( getSortId( n )==0 ){
     return false;
   }else{
-    if( n.getKind()==kind::APPLY_UF ){
+    if (isHandledApplyUf(n.getKind()))
+    {
       for( unsigned i=0; i<n.getNumChildren(); i++ ){
         int s1 = getSortId( n[i] );
         int s2 = d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] );
@@ -825,18 +854,14 @@ bool SortInference::isWellSorted( Node n ) {
   }
 }
 
-void SortInference::getSortConstraints( Node n, UnionFind& uf ) {
-  if( n.getKind()==kind::APPLY_UF ){
-    for( unsigned i=0; i<n.getNumChildren(); i++ ){
-      getSortConstraints( n[i], uf );
-      uf.setEqual( getSortId( n[i] ), d_type_union_find.getRepresentative( d_op_arg_types[ n.getOperator() ][i] ) );
-    }
-  }
-}
-
 bool SortInference::isMonotonic( TypeNode tn ) {
-  Assert( tn.isSort() );
+  Assert(tn.isSort());
   return d_non_monotonic_sorts_orig.find( tn )==d_non_monotonic_sorts_orig.end();
 }
 
+bool SortInference::isHandledApplyUf(Kind k) const
+{
+  return k == APPLY_UF && !options::ufHo();
+}
+
 }/* CVC4 namespace */