Make sort inference a preprocessing pass (#2309)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 15 Aug 2018 18:02:46 +0000 (13:02 -0500)
committerGitHub <noreply@github.com>
Wed, 15 Aug 2018 18:02:46 +0000 (13:02 -0500)
src/Makefile.am
src/preprocessing/passes/sort_infer.cpp [new file with mode: 0644]
src/preprocessing/passes/sort_infer.h [new file with mode: 0644]
src/smt/smt_engine.cpp
src/theory/sort_inference.cpp
src/theory/sort_inference.h
test/regress/Makefile.tests
test/regress/regress0/fmf/sort-inf-int.smt2 [new file with mode: 0644]
test/regress/regress1/fmf/ALG008-1.smt2

index c2a620f578c9cf1f42ce5464f2a33ca6f729d41c..43aa70174351ae40a8d9fa120989322240c2ba06 100644 (file)
@@ -85,6 +85,8 @@ libcvc4_la_SOURCES = \
        preprocessing/passes/rewrite.h \
        preprocessing/passes/sep_skolem_emp.cpp \
        preprocessing/passes/sep_skolem_emp.h \
+       preprocessing/passes/sort_infer.cpp \
+       preprocessing/passes/sort_infer.h \
        preprocessing/passes/static_learning.cpp \
        preprocessing/passes/static_learning.h \
        preprocessing/passes/symmetry_breaker.cpp \
diff --git a/src/preprocessing/passes/sort_infer.cpp b/src/preprocessing/passes/sort_infer.cpp
new file mode 100644 (file)
index 0000000..e2b0bfb
--- /dev/null
@@ -0,0 +1,85 @@
+/*********************                                                        */
+/*! \file sort_infer.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Sort inference preprocessing pass
+ **/
+
+#include "preprocessing/passes/sort_infer.h"
+
+#include "options/smt_options.h"
+#include "options/uf_options.h"
+#include "theory/rewriter.h"
+
+using namespace std;
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+SortInferencePass::SortInferencePass(PreprocessingPassContext* preprocContext,
+                                     SortInference* si)
+    : PreprocessingPass(preprocContext, "sort-inference"), d_si(si)
+{
+}
+
+PreprocessingPassResult SortInferencePass::applyInternal(
+    AssertionPipeline* assertionsToPreprocess)
+{
+  if (options::sortInference())
+  {
+    d_si->initialize(assertionsToPreprocess->ref());
+    std::map<Node, Node> model_replace_f;
+    std::map<Node, std::map<TypeNode, Node> > visited;
+    for (unsigned i = 0, size = assertionsToPreprocess->size(); i < size; i++)
+    {
+      Node prev = (*assertionsToPreprocess)[i];
+      Node next = d_si->simplify(prev, model_replace_f, visited);
+      if (next != prev)
+      {
+        next = theory::Rewriter::rewrite(next);
+        assertionsToPreprocess->replace(i, next);
+        Trace("sort-infer-preprocess")
+            << "*** Preprocess SortInferencePass " << prev << endl;
+        Trace("sort-infer-preprocess")
+            << "   ...got " << (*assertionsToPreprocess)[i] << endl;
+      }
+    }
+    std::vector<Node> newAsserts;
+    d_si->getNewAssertions(newAsserts);
+    for (const Node& na : newAsserts)
+    {
+      Node nar = theory::Rewriter::rewrite(na);
+      Trace("sort-infer-preprocess")
+          << "*** Preprocess SortInferencePass : new constraint " << nar
+          << endl;
+      assertionsToPreprocess->push_back(nar);
+    }
+    // indicate correspondence between the functions
+    // TODO (#2308): move this to a better place
+    SmtEngine* smt = smt::currentSmtEngine();
+    for (const std::pair<const Node, Node>& mrf : model_replace_f)
+    {
+      smt->setPrintFuncInModel(mrf.first.toExpr(), false);
+      smt->setPrintFuncInModel(mrf.second.toExpr(), true);
+    }
+  }
+  // only need to compute monotonicity on the resulting formula if we are
+  // using this option
+  if (options::ufssFairnessMonotone())
+  {
+    d_si->computeMonotonicity(assertionsToPreprocess->ref());
+  }
+  return PreprocessingPassResult::NO_CONFLICT;
+}
+
+}  // namespace passes
+}  // namespace preprocessing
+}  // namespace CVC4
diff --git a/src/preprocessing/passes/sort_infer.h b/src/preprocessing/passes/sort_infer.h
new file mode 100644 (file)
index 0000000..e56d7ab
--- /dev/null
@@ -0,0 +1,58 @@
+/*********************                                                        */
+/*! \file sort_infer.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Sort inference preprocessing pass
+ **/
+
+#ifndef __CVC4__PREPROCESSING__PASSES__SORT_INFERENCE_PASS_H_
+#define __CVC4__PREPROCESSING__PASSES__SORT_INFERENCE_PASS_H_
+
+#include <map>
+#include <string>
+#include <vector>
+#include "expr/node.h"
+
+#include "preprocessing/preprocessing_pass.h"
+#include "preprocessing/preprocessing_pass_context.h"
+#include "theory/sort_inference.h"
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+/** SortInferencePass
+ *
+ * This preprocessing pass runs sort inference techniques on the input formula.
+ * For details on these techniques, see theory/sort_inference.h.
+ */
+class SortInferencePass : public PreprocessingPass
+{
+ public:
+  SortInferencePass(PreprocessingPassContext* preprocContext,
+                    SortInference* si);
+
+ protected:
+  PreprocessingPassResult applyInternal(
+      AssertionPipeline* assertionsToPreprocess) override;
+
+ private:
+  /**
+   * Pointer to the sort inference module. This should be the sort inference
+   * belonging to the theory engine of the current SMT engine.
+   */
+  SortInference* d_si;
+};
+
+}  // namespace passes
+}  // namespace preprocessing
+}  // namespace CVC4
+
+#endif /* __CVC4__PREPROCESSING__PASSES__SORT_INFERENCE_PASS_H_ */
index 38f6a2d5e7a89a16592d4a1f67a0d811de949247..cc6f09801928b065afb3570a73e8eaaef33fb2d8 100644 (file)
@@ -80,6 +80,7 @@
 #include "preprocessing/passes/real_to_int.h"
 #include "preprocessing/passes/rewrite.h"
 #include "preprocessing/passes/sep_skolem_emp.h"
+#include "preprocessing/passes/sort_infer.h"
 #include "preprocessing/passes/static_learning.h"
 #include "preprocessing/passes/symmetry_breaker.h"
 #include "preprocessing/passes/symmetry_detect.h"
@@ -2735,15 +2736,18 @@ void SmtEnginePrivate::finishInit()
       new RealToInt(d_preprocessingPassContext.get()));
   std::unique_ptr<Rewrite> rewrite(
       new Rewrite(d_preprocessingPassContext.get()));
+  std::unique_ptr<SortInferencePass> sortInfer(
+      new SortInferencePass(d_preprocessingPassContext.get(),
+                            d_smt.d_theoryEngine->getSortInference()));
   std::unique_ptr<StaticLearning> staticLearning(
       new StaticLearning(d_preprocessingPassContext.get()));
   std::unique_ptr<SymBreakerPass> sbProc(
       new SymBreakerPass(d_preprocessingPassContext.get()));
   std::unique_ptr<SynthRewRulesPass> srrProc(
       new SynthRewRulesPass(d_preprocessingPassContext.get()));
- std::unique_ptr<SepSkolemEmp> sepSkolemEmp(
 std::unique_ptr<SepSkolemEmp> sepSkolemEmp(
       new SepSkolemEmp(d_preprocessingPassContext.get()));
-   d_preprocessingPassRegistry.registerPass("apply-substs",
+  d_preprocessingPassRegistry.registerPass("apply-substs",
                                            std::move(applySubsts));
   d_preprocessingPassRegistry.registerPass("bool-to-bv", std::move(boolToBv));
   d_preprocessingPassRegistry.registerPass("bv-abstraction",
@@ -2761,6 +2765,8 @@ void SmtEnginePrivate::finishInit()
   d_preprocessingPassRegistry.registerPass("rewrite", std::move(rewrite));
   d_preprocessingPassRegistry.registerPass("sep-skolem-emp",
                                            std::move(sepSkolemEmp));
+  d_preprocessingPassRegistry.registerPass("sort-inference",
+                                           std::move(sortInfer));
   d_preprocessingPassRegistry.registerPass("static-learning", 
                                            std::move(staticLearning));
   d_preprocessingPassRegistry.registerPass("sym-break", std::move(sbProc));
@@ -4332,13 +4338,7 @@ void SmtEnginePrivate::processAssertions() {
   }
 
   if( options::sortInference() || options::ufssFairnessMonotone() ){
-    //sort inference technique
-    SortInference * si = d_smt.d_theoryEngine->getSortInference();
-    si->simplify( d_assertions.ref(), options::sortInference(), options::ufssFairnessMonotone() );
-    for( std::map< Node, Node >::iterator it = si->d_model_replace_f.begin(); it != si->d_model_replace_f.end(); ++it ){
-      d_smt.setPrintFuncInModel( it->first.toExpr(), false );
-      d_smt.setPrintFuncInModel( it->second.toExpr(), true );
-    }
+    d_preprocessingPassRegistry.getPass("sort-inference")->apply(&d_assertions);
   }
 
   if( options::pbRewrites() ){
index 96e1e3a383f14d502fef3e2e4d578b9f3f0f95f7..b6e8f7553ecebe25ac0ef60bc3cef860ae8d1a87 100644 (file)
@@ -102,7 +102,7 @@ void SortInference::reset() {
   d_non_monotonic_sorts.clear();
   d_type_sub_sorts.clear();
   //reset info
-  sortCount = 1;
+  d_sortCount = 1;
   d_type_union_find.clear();
   d_type_types.clear();
   d_id_for_types.clear();
@@ -114,203 +114,191 @@ void SortInference::reset() {
   d_const_map.clear();
 }
 
-void SortInference::simplify( std::vector< Node >& assertions, bool doSortInference, bool doMonotonicyInference ){
-  if( doSortInference ){
-    Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
-    NodeManager* nm = NodeManager::currentNM();
-    //process all assertions
-    std::map< Node, int > visited;
-    for( unsigned i=0; i<assertions.size(); i++ ){
-      Trace("sort-inference-debug") << "Process " << assertions[i] << std::endl;
-      std::map< Node, Node > var_bound;
-      process( assertions[i], var_bound, visited );
-    }
-    Trace("sort-inference-proc") << "...done" << std::endl;
-    for( std::map< Node, int >::iterator it = d_op_return_types.begin(); it != d_op_return_types.end(); ++it ){
-      Trace("sort-inference") << it->first << " : ";
-      TypeNode retTn = it->first.getType();
-      if( !d_op_arg_types[ it->first ].empty() ){
-        Trace("sort-inference") << "( ";
-        for( size_t i=0; i<d_op_arg_types[ it->first ].size(); i++ ){
-          recordSubsort( retTn[i], d_op_arg_types[ it->first ][i] );
-          printSort( "sort-inference", d_op_arg_types[ it->first ][i] );
-          Trace("sort-inference") << " ";
-        }
-        Trace("sort-inference") << ") -> ";
-        retTn = retTn[(int)retTn.getNumChildren()-1];
+void SortInference::initialize(const std::vector<Node>& assertions)
+{
+  Trace("sort-inference-proc") << "Calculating sort inference..." << std::endl;
+  // process all assertions
+  std::map<Node, int> visited;
+  for (const Node& a : assertions)
+  {
+    Trace("sort-inference-debug") << "Process " << a << std::endl;
+    std::map<Node, Node> var_bound;
+    process(a, var_bound, visited);
+  }
+  Trace("sort-inference-proc") << "...done" << std::endl;
+  for (const std::pair<const Node, int>& rt : d_op_return_types)
+  {
+    Trace("sort-inference") << rt.first << " : ";
+    TypeNode retTn = rt.first.getType();
+    if (!d_op_arg_types[rt.first].empty())
+    {
+      Trace("sort-inference") << "( ";
+      for (size_t i = 0; i < d_op_arg_types[rt.first].size(); i++)
+      {
+        recordSubsort(retTn[i], d_op_arg_types[rt.first][i]);
+        printSort("sort-inference", d_op_arg_types[rt.first][i]);
+        Trace("sort-inference") << " ";
       }
-      recordSubsort( retTn, it->second );
-      printSort( "sort-inference", it->second );
-      Trace("sort-inference") << std::endl;
+      Trace("sort-inference") << ") -> ";
+      retTn = retTn[(int)retTn.getNumChildren() - 1];
     }
-    for( std::map< Node, std::map< Node, int > >::iterator it = d_var_types.begin(); it != d_var_types.end(); ++it ){
-      Trace("sort-inference") << "Quantified formula : " << it->first << " : " << std::endl;
-      for( unsigned i=0; i<it->first[0].getNumChildren(); i++ ){
-        recordSubsort( it->first[0][i].getType(), it->second[it->first[0][i]] );
-        printSort( "sort-inference", it->second[it->first[0][i]] );
-        Trace("sort-inference") << std::endl;
-      }
+    recordSubsort(retTn, rt.second);
+    printSort("sort-inference", rt.second);
+    Trace("sort-inference") << std::endl;
+  }
+  for (std::pair<const Node, std::map<Node, int> >& vt : d_var_types)
+  {
+    Trace("sort-inference")
+        << "Quantified formula : " << vt.first << " : " << std::endl;
+    for (const Node& v : vt.first[0])
+    {
+      recordSubsort(v.getType(), vt.second[v]);
+      printSort("sort-inference", vt.second[v]);
       Trace("sort-inference") << std::endl;
     }
+    Trace("sort-inference") << std::endl;
+  }
 
-    bool rewritten = false;
-    // determine monotonicity of sorts
-    Trace("sort-inference-proc") << "Calculating monotonicty for subsorts..."
-                                 << std::endl;
-    std::map<Node, std::map<int, bool> > visitedm;
-    for (const Node& a : assertions)
+  // determine monotonicity of sorts
+  Trace("sort-inference-proc")
+      << "Calculating monotonicty for subsorts..." << std::endl;
+  std::map<Node, std::map<int, bool> > visitedm;
+  for (const Node& a : assertions)
+  {
+    Trace("sort-inference-debug")
+        << "Process monotonicity for " << a << std::endl;
+    std::map<Node, Node> var_bound;
+    processMonotonic(a, true, true, var_bound, visitedm);
+  }
+  Trace("sort-inference-proc") << "...done" << std::endl;
+
+  Trace("sort-inference") << "We have " << d_sub_sorts.size()
+                          << " sub-sorts : " << std::endl;
+  for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++)
+  {
+    printSort("sort-inference", d_sub_sorts[i]);
+    if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end())
     {
-      Trace("sort-inference-debug") << "Process monotonicity for " << a
-                                    << std::endl;
-      std::map<Node, Node> var_bound;
-      processMonotonic(a, true, true, var_bound, visitedm);
+      Trace("sort-inference") << " is interpreted." << std::endl;
     }
-    Trace("sort-inference-proc") << "...done" << std::endl;
-
-    Trace("sort-inference") << "We have " << d_sub_sorts.size()
-                            << " sub-sorts : " << std::endl;
-    for (unsigned i = 0, size = d_sub_sorts.size(); i < size; i++)
+    else if (d_non_monotonic_sorts.find(d_sub_sorts[i])
+             == d_non_monotonic_sorts.end())
     {
-      printSort("sort-inference", d_sub_sorts[i]);
-      if (d_type_types.find(d_sub_sorts[i]) != d_type_types.end())
-      {
-        Trace("sort-inference") << " is interpreted." << std::endl;
-      }
-      else if (d_non_monotonic_sorts.find(d_sub_sorts[i])
-               == d_non_monotonic_sorts.end())
-      {
-        Trace("sort-inference") << " is monotonic." << std::endl;
-      }
-      else
-      {
-        Trace("sort-inference") << " is not monotonic." << std::endl;
-      }
+      Trace("sort-inference") << " is monotonic." << std::endl;
     }
+    else
+    {
+      Trace("sort-inference") << " is not monotonic." << std::endl;
+    }
+  }
+}
 
-    // simplify all assertions by introducing new symbols wherever necessary
-    Trace("sort-inference-proc") << "Perform simplification..." << std::endl;
-    std::map<Node, std::map<TypeNode, Node> > visited2;
-    for (unsigned i = 0, size = assertions.size(); i < size; i++)
+Node SortInference::simplify(Node n,
+                             std::map<Node, Node>& model_replace_f,
+                             std::map<Node, std::map<TypeNode, Node> >& visited)
+{
+  Trace("sort-inference-debug") << "Simplify " << n << std::endl;
+  std::map<Node, Node> var_bound;
+  TypeNode tnn;
+  Node ret = simplifyNode(n, var_bound, tnn, model_replace_f, visited);
+  ret = theory::Rewriter::rewrite(ret);
+  return ret;
+}
+
+void SortInference::getNewAssertions(std::vector<Node>& new_asserts)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  // now, ensure constants are distinct
+  for (const std::pair<const TypeNode, std::map<Node, Node> >& cm : d_const_map)
+  {
+    std::vector<Node> consts;
+    for (const std::pair<const Node, Node>& c : cm.second)
     {
-      Node prev = assertions[i];
-      std::map<Node, Node> var_bound;
-      Trace("sort-inference-debug") << "Simplify " << prev << std::endl;
-      TypeNode tnn;
-      Node curr = simplifyNode(assertions[i], var_bound, tnn, visited2);
-      Trace("sort-inference-debug") << "Done." << std::endl;
-      if (curr != assertions[i])
-      {
-        Trace("sort-inference-debug") << "Rewrite " << curr << std::endl;
-        curr = theory::Rewriter::rewrite(curr);
-        rewritten = true;
-        Trace("sort-inference-rewrite") << assertions << std::endl;
-        Trace("sort-inference-rewrite") << " --> " << curr << std::endl;
-        PROOF(ProofManager::currentPM()->addDependence(curr, assertions[i]););
-        assertions[i] = curr;
-      }
+      Assert(c.first.isConst());
+      consts.push_back(c.second);
     }
-    Trace("sort-inference-proc") << "...done" << std::endl;
-    // now, ensure constants are distinct
-    for (std::map<TypeNode, std::map<Node, Node> >::iterator it =
-             d_const_map.begin();
-         it != d_const_map.end();
-         ++it)
+    // add lemma enforcing introduced constants to be distinct
+    if (consts.size() > 1)
     {
-      std::vector<Node> consts;
-      for (std::map<Node, Node>::iterator it2 = it->second.begin();
-           it2 != it->second.end();
-           ++it2)
-      {
-        Assert(it2->first.isConst());
-        consts.push_back(it2->second);
-      }
-      // add lemma enforcing introduced constants to be distinct
-      if (consts.size() > 1)
-      {
-        Node distinct_const = nm->mkNode(kind::DISTINCT, consts);
-        Trace("sort-inference-rewrite")
-            << "Add the constant distinctness lemma: " << std::endl;
-        Trace("sort-inference-rewrite") << "  " << distinct_const << std::endl;
-        assertions.push_back(distinct_const);
-        rewritten = true;
-      }
+      Node distinct_const = nm->mkNode(kind::DISTINCT, consts);
+      Trace("sort-inference-rewrite")
+          << "Add the constant distinctness lemma: " << std::endl;
+      Trace("sort-inference-rewrite") << "  " << distinct_const << std::endl;
+      new_asserts.push_back(distinct_const);
     }
+  }
+
+  // enforce constraints based on monotonicity
+  Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
 
-    // enforce constraints based on monotonicity
-    Trace("sort-inference-proc") << "Enforce monotonicity..." << std::endl;
-    for (std::map<TypeNode, std::vector<int> >::iterator it =
-             d_type_sub_sorts.begin();
-         it != d_type_sub_sorts.end();
-         ++it)
+  for (const std::pair<const TypeNode, std::vector<int> >& tss :
+       d_type_sub_sorts)
+  {
+    int nmonSort = -1;
+    unsigned nsorts = tss.second.size();
+    for (unsigned i = 0; i < nsorts; i++)
     {
-      int nmonSort = -1;
-      unsigned nsorts = it->second.size();
-      for (unsigned i = 0; i < nsorts; i++)
+      if (d_non_monotonic_sorts.find(tss.second[i])
+          != d_non_monotonic_sorts.end())
       {
-        if (d_non_monotonic_sorts.find(it->second[i])
-            != d_non_monotonic_sorts.end())
-        {
-          nmonSort = it->second[i];
-          break;
-        }
+        nmonSort = tss.second[i];
+        break;
       }
-      if (nmonSort != -1)
+    }
+    if (nmonSort != -1)
+    {
+      std::vector<Node> injections;
+      TypeNode base_tn = getOrCreateTypeForId(nmonSort, tss.first);
+      for (unsigned i = 0; i < nsorts; i++)
       {
-        std::vector<Node> injections;
-        TypeNode base_tn = getOrCreateTypeForId(nmonSort, it->first);
-        for (unsigned i = 0; i < nsorts; i++)
-        {
-          if (it->second[i] != nmonSort)
-          {
-            TypeNode new_tn = getOrCreateTypeForId(it->second[i], it->first);
-            // make injection to nmonSort
-            Node a1 = mkInjection(new_tn, base_tn);
-            injections.push_back(a1);
-            if (d_non_monotonic_sorts.find(it->second[i])
-                != d_non_monotonic_sorts.end())
-            {
-              // also must make injection from nmonSort to this
-              Node a2 = mkInjection(base_tn, new_tn);
-              injections.push_back(a2);
-            }
-          }
-        }
-        if (Trace.isOn("sort-inference-rewrite"))
+        if (tss.second[i] != nmonSort)
         {
-          Trace("sort-inference-rewrite") << "Add the following injections for " << it->first << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
-          for (const Node& i : injections)
+          TypeNode new_tn = getOrCreateTypeForId(tss.second[i], tss.first);
+          // make injection to nmonSort
+          Node a1 = mkInjection(new_tn, base_tn);
+          injections.push_back(a1);
+          if (d_non_monotonic_sorts.find(tss.second[i])
+              != d_non_monotonic_sorts.end())
           {
-            Trace("sort-inference-rewrite") << "   " << i << std::endl;
+            // also must make injection from nmonSort to this
+            Node a2 = mkInjection(base_tn, new_tn);
+            injections.push_back(a2);
           }
         }
-        assertions.insert(
-            assertions.end(), injections.begin(), injections.end());
-        if (!injections.empty())
+      }
+      if (Trace.isOn("sort-inference-rewrite"))
+      {
+        Trace("sort-inference-rewrite")
+            << "Add the following injections for " << tss.first
+            << " to ensure consistency wrt non-monotonic sorts : " << std::endl;
+        for (const Node& i : injections)
         {
-          rewritten = true;
+          Trace("sort-inference-rewrite") << "   " << i << std::endl;
         }
       }
+      new_asserts.insert(
+          new_asserts.end(), injections.begin(), injections.end());
     }
-    Trace("sort-inference-proc") << "...done" << std::endl;
-    // no sub-sort information is stored
-    reset();
-    Trace("sort-inference-debug")
-        << "Finished sort inference, rewritten = " << rewritten << std::endl;
-
-    initialSortCount = sortCount;
   }
-  if( doMonotonicyInference ){
-    std::map<Node, std::map<int, bool> > visitedmt;
-    Trace("sort-inference-proc") << "Calculating monotonicty for types..." << std::endl;
-    for (const Node& a : assertions)
-    {
-      Trace("sort-inference-debug") << "Process type monotonicity for " << a
-                                    << std::endl;
-      std::map< Node, Node > var_bound;
-      processMonotonic(a, true, true, var_bound, visitedmt, true);
-    }
-    Trace("sort-inference-proc") << "...done" << std::endl;
+  Trace("sort-inference-proc") << "...done" << std::endl;
+  // no sub-sort information is stored
+  reset();
+  Trace("sort-inference-debug") << "Finished sort inference" << std::endl;
+}
+
+void SortInference::computeMonotonicity(const std::vector<Node>& assertions)
+{
+  std::map<Node, std::map<int, bool> > visitedmt;
+  Trace("sort-inference-proc")
+      << "Calculating monotonicty for types..." << std::endl;
+  for (const Node& a : assertions)
+  {
+    Trace("sort-inference-debug")
+        << "Process type monotonicity for " << a << std::endl;
+    std::map<Node, Node> var_bound;
+    processMonotonic(a, true, true, var_bound, visitedmt, true);
   }
+  Trace("sort-inference-proc") << "...done" << std::endl;
 }
 
 void SortInference::setEqual( int t1, int t2 ){
@@ -357,10 +345,10 @@ int SortInference::getIdForType( TypeNode tn ){
   //register the return type
   std::map< TypeNode, int >::iterator it = d_id_for_types.find( tn );
   if( it==d_id_for_types.end() ){
-    int sc = sortCount;
-    d_type_types[ sortCount ] = tn;
-    d_id_for_types[ tn ] = sortCount;
-    sortCount++;
+    int sc = d_sortCount;
+    d_type_types[d_sortCount] = tn;
+    d_id_for_types[tn] = d_sortCount;
+    d_sortCount++;
     return sc;
   }else{
     return it->second;
@@ -381,8 +369,8 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
       }else{
         for( size_t i=0; i<n[0].getNumChildren(); i++ ){
           //apply sort inference to quantified variables
-          d_var_types[n][ n[0][i] ] = sortCount;
-          sortCount++;
+          d_var_types[n][n[0][i]] = d_sortCount;
+          d_sortCount++;
 
           //type of the quantified variable must be the same
           var_bound[ n[0][i] ] = n;
@@ -439,14 +427,14 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
           d_op_return_types[op] = getIdForType( n.getType() );
         }else{
           //assign arbitrary sort for return type
-          d_op_return_types[op] = sortCount;
-          sortCount++;
+          d_op_return_types[op] = d_sortCount;
+          d_sortCount++;
         }
-        //d_type_eq_class[sortCount].push_back( op );
-        //assign arbitrary sort for argument types
+        // d_type_eq_class[d_sortCount].push_back( op );
+        // assign arbitrary sort for argument types
         for( size_t i=0; i<n.getNumChildren(); i++ ){
-          d_op_arg_types[op].push_back( sortCount );
-          sortCount++;
+          d_op_arg_types[op].push_back(d_sortCount);
+          d_sortCount++;
         }
       }
       for( size_t i=0; i<n.getNumChildren(); i++ ){
@@ -475,16 +463,16 @@ int SortInference::process( Node n, std::map< Node, Node >& var_bound, std::map<
         Trace("sort-inference-debug") << n << " is a variable." << std::endl;
         if( d_op_return_types.find( n )==d_op_return_types.end() ){
           //assign arbitrary sort
-          d_op_return_types[n] = sortCount;
-          sortCount++;
-          //d_type_eq_class[sortCount].push_back( n );
+          d_op_return_types[n] = d_sortCount;
+          d_sortCount++;
+          // d_type_eq_class[d_sortCount].push_back( n );
         }
         retType = d_op_return_types[n];
       }else if( n.isConst() ){
         Trace("sort-inference-debug") << n << " is a constant." << std::endl;
         //can be any type we want
-        retType = sortCount;
-        sortCount++;
+        retType = d_sortCount;
+        d_sortCount++;
       }else{
         Trace("sort-inference-debug") << n << " is a interpreted symbol." << std::endl;
         //it is an interpreted term
@@ -556,8 +544,13 @@ TypeNode SortInference::getOrCreateTypeForId( int t, TypeNode pref ){
     return d_type_types[rt];
   }else{
     TypeNode retType;
-    //see if we can assign pref
-    if( !pref.isNull() && d_id_for_types.find( pref )==d_id_for_types.end() ){
+    // See if we can assign pref. This is an optimization for reusing an
+    // uninterpreted sort as the first subsort, so that fewer symbols needed
+    // to be rewritten in the sort-inferred signature. Notice we only assign
+    // pref here if it is an uninterpreted sort.
+    if (!pref.isNull() && d_id_for_types.find(pref) == d_id_for_types.end()
+        && pref.isSort())
+    {
       retType = pref;
     }else{
       //must create new type
@@ -606,7 +599,13 @@ Node SortInference::getNewSymbol( Node old, TypeNode tn ){
   }
 }
 
-Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited ){
+Node SortInference::simplifyNode(
+    Node n,
+    std::map<Node, Node>& var_bound,
+    TypeNode tnn,
+    std::map<Node, Node>& model_replace_f,
+    std::map<Node, std::map<TypeNode, Node> >& visited)
+{
   std::map< TypeNode, Node >::iterator itv = visited[n].find( tnn );
   if( itv!=visited[n].end() ){
     return itv->second;
@@ -654,7 +653,11 @@ Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, Typ
           tnnc = getOrCreateTypeForId( d_equality_types[n], n[0].getType() );
           Assert( !tnnc.isNull() );
         }
-        Node nc = simplifyNode( n[i], var_bound, tnnc, use_new_visited ? new_visited : visited );
+        Node nc = simplifyNode(n[i],
+                               var_bound,
+                               tnnc,
+                               model_replace_f,
+                               use_new_visited ? new_visited : visited);
         Trace("sort-inference-debug2") << "Simplify " << i << " " << n[i] << " returned " << nc << std::endl;
         children.push_back( nc );
         childChanged = childChanged || nc!=n[i];
@@ -701,7 +704,7 @@ Node SortInference::simplifyNode( Node n, std::map< Node, Node >& var_bound, Typ
           TypeNode typ = NodeManager::currentNM()->mkFunctionType( argTypes, retType );
           d_symbol_map[op] = NodeManager::currentNM()->mkSkolem( ss.str(), typ, "op created during sort inference" );
           Trace("setp-model") << "Function " << op << " is replaced with " << d_symbol_map[op] << std::endl;
-          d_model_replace_f[op] = d_symbol_map[op];
+          model_replace_f[op] = d_symbol_map[op];
         }else{
           d_symbol_map[op] = op;
         }
index 6daf6157a8accd10b81be00f11df7f108dcb268d..b93d5531c0e98254b001bf2b887a0b16cb7383b7 100644 (file)
 
 namespace CVC4 {
 
+/** sort inference
+ *
+ * This class implements sort inference techniques, which rewrites a
+ * formula F into an equisatisfiable formula F', where the symbols g in F are
+ * replaced by others g', possibly of different types. For details, see e.g.:
+ *   "Sort it out with Monotonicity" Claessen 2011
+ *   "Non-Cyclic Sorts for First-Order Satisfiability" Korovin 2013.
+ */
 class SortInference {
 private:
   //all subsorts
@@ -52,9 +60,10 @@ public:
     bool areEqual( int t1, int t2 ) { return getRepresentative( t1 )==getRepresentative( t2 ); }
     bool isValid();
   };
-private:
-  int sortCount;
-  int initialSortCount;
+
+ private:
+  /** the id count for all subsorts we have allocated */
+  int d_sortCount;
   UnionFind d_type_union_find;
   std::map< int, TypeNode > d_type_types;
   std::map< TypeNode, int > d_id_for_types;
@@ -70,8 +79,8 @@ private:
   void printSort( const char* c, int t );
   //process
   int process( Node n, std::map< Node, Node >& var_bound, std::map< Node, int >& visited );
-//for monotonicity inference
-private:
+  // for monotonicity inference
+ private:
   void processMonotonic( Node n, bool pol, bool hasPol, std::map< Node, Node >& var_bound, std::map< Node, std::map< int, bool > >& visited, bool typeMode = false );
 
 //for rewriting
@@ -85,17 +94,56 @@ private:
   TypeNode getTypeForId( int t );
   Node getNewSymbol( Node old, TypeNode tn );
   //simplify
-  Node simplifyNode( Node n, std::map< Node, Node >& var_bound, TypeNode tnn, std::map< Node, std::map< TypeNode, Node > >& visited );
+  Node simplifyNode(Node n,
+                    std::map<Node, Node>& var_bound,
+                    TypeNode tnn,
+                    std::map<Node, Node>& model_replace_f,
+                    std::map<Node, std::map<TypeNode, Node> >& visited);
   //make injection
   Node mkInjection( TypeNode tn1, TypeNode tn2 );
   //reset
   void reset();
 
  public:
-  SortInference() : sortCount(1), initialSortCount() {}
+  SortInference() : d_sortCount(1) {}
   ~SortInference(){}
 
-  void simplify( std::vector< Node >& assertions, bool doSortInference, bool doMonotonicyInference );
+  /** initialize
+   *
+   * This initializes this class. The input formula is indicated by assertions.
+   */
+  void initialize(const std::vector<Node>& assertions);
+  /** simplify
+   *
+   * This returns the simplified form of formula n, based on the information
+   * computed during initialization. The argument model_replace_f stores the
+   * mapping between functions and their analog in the sort-inferred signature.
+   * The argument visited is a cache of the internal results of simplifying
+   * previous nodes with this class.
+   *
+   * Must call initialize() before this function.
+   */
+  Node simplify(Node n,
+                std::map<Node, Node>& model_replace_f,
+                std::map<Node, std::map<TypeNode, Node> >& visited);
+  /** get new constraints
+   *
+   * This adds constraints to new_asserts that ensure the following.
+   * Let F be the conjunction of assertions from the input. Let F' be the
+   * conjunction of the simplified form of each conjunct in F. Let C be the
+   * conjunction of formulas adding to new_asserts. Then, F and F' ^ C are
+   * equisatisfiable.
+   */
+  void getNewAssertions(std::vector<Node>& new_asserts);
+  /** compute monotonicity
+   *
+   * This computes whether sorts are monotonic (see e.g. Claessen 2011). If
+   * this function is called, then calls to isMonotonic() can subsequently be
+   * used to query whether sorts are monotonic.
+   */
+  void computeMonotonicity(const std::vector<Node>& assertions);
+  /** return true if tn was inferred to be monotonic */
+  bool isMonotonic(TypeNode tn);
   //get sort id for term n
   int getSortId( Node n );
   //get sort id for variable of quantified formula f
@@ -108,16 +156,9 @@ public:
   bool isWellSorted( Node n );
   //get constraints for being well-typed according to computed sub-types
   void getSortConstraints( Node n, SortInference::UnionFind& uf );
-public:
-  //list of all functions and the uninterpreted symbols they were replaced with
-  std::map< Node, Node > d_model_replace_f;
-
 private:
   // store monotonicity for original sorts as well
-  std::map< TypeNode, bool > d_non_monotonic_sorts_orig;  
-public:
-  //is monotonic
-  bool isMonotonic( TypeNode tn );  
+ std::map<TypeNode, bool> d_non_monotonic_sorts_orig;
 };
 
 }
index 543fbd158a0917a661e12ce0d5e6939d4670b549..cd79fe050e254d6678d4cd46d420da9c7872fa92 100644 (file)
@@ -433,6 +433,7 @@ REG0_TESTS = \
        regress0/fmf/quant_real_univ.cvc \
        regress0/fmf/sat-logic.smt2 \
        regress0/fmf/sc_bad_model_1221.smt2 \
+       regress0/fmf/sort-inf-int.smt2 \
        regress0/fmf/syn002-si-real-int.smt2 \
        regress0/fmf/tail_rec.smt2 \
        regress0/fp/simple.smt2 \
diff --git a/test/regress/regress0/fmf/sort-inf-int.smt2 b/test/regress/regress0/fmf/sort-inf-int.smt2
new file mode 100644 (file)
index 0000000..e4a8978
--- /dev/null
@@ -0,0 +1,13 @@
+; COMMAND-LINE: --finite-model-find --sort-inference --no-check-models
+; EXPECT: sat
+(set-logic UFLIRA)
+(set-info :status sat)
+(declare-fun f (Int) Int)
+(declare-fun g (Int) Int)
+(declare-fun h (Int) Int)
+(assert (forall ((x Int)) (or (= (f x) (h x)) (= (f x) (g x)))))
+(assert (not (= (f 3) (h 3))))
+(assert (not (= (f 5) (g 5))))
+(assert (= (f 4) (g 8)))
+
+(check-sat)
index 2c3bab80d4a396e391d90ba8ee2c18923cdade04..5bf36a715602d80f50a2489d0426279fdc787962 100644 (file)
@@ -1,4 +1,5 @@
 ; COMMAND-LINE: --finite-model-find
+; COMMAND-LINE: --finite-model-find --sort-inference --no-check-models
 ; EXPECT: sat
 ;%--------------------------------------------------------------------------
 ;% File     : ALG008-1 : TPTP v5.4.0. Released v2.2.0.