Refactor nlExtPurify preprocessing pass (#1963)
authorHaniel Barbosa <hanielbbarbosa@gmail.com>
Sat, 25 Aug 2018 01:19:14 +0000 (20:19 -0500)
committerGitHub <noreply@github.com>
Sat, 25 Aug 2018 01:19:14 +0000 (20:19 -0500)
src/Makefile.am
src/preprocessing/passes/nl_ext_purify.cpp [new file with mode: 0644]
src/preprocessing/passes/nl_ext_purify.h [new file with mode: 0644]
src/smt/smt_engine.cpp
test/regress/Makefile.tests
test/regress/regress0/nl/nlExtPurify-test.smt2 [new file with mode: 0644]

index 3b8a12fa501ebce9be529d0fdaf4b7a6e5833ea6..d399602cbda6ae5fb8c6d6df012f231f86001b0e 100644 (file)
@@ -85,6 +85,8 @@ libcvc4_la_SOURCES = \
        preprocessing/passes/ite_removal.h \
        preprocessing/passes/ite_simp.cpp \
        preprocessing/passes/ite_simp.h \
+       preprocessing/passes/nl_ext_purify.cpp \
+       preprocessing/passes/nl_ext_purify.h \
        preprocessing/passes/pseudo_boolean_processor.cpp \
        preprocessing/passes/pseudo_boolean_processor.h \
        preprocessing/passes/bool_to_bv.cpp \
diff --git a/src/preprocessing/passes/nl_ext_purify.cpp b/src/preprocessing/passes/nl_ext_purify.cpp
new file mode 100644 (file)
index 0000000..afb0925
--- /dev/null
@@ -0,0 +1,130 @@
+/*********************                                                        */
+/*! \file nl_ext_purify.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Haniel Barbosa
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The NlExtPurify preprocessing pass
+ **
+ ** Purifies non-linear terms
+ **/
+
+#include "preprocessing/passes/nl_ext_purify.h"
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+using namespace CVC4::theory;
+
+Node NlExtPurify::purifyNlTerms(TNode n,
+                                NodeMap& cache,
+                                NodeMap& bcache,
+                                std::vector<Node>& var_eq,
+                                bool beneathMult)
+{
+  if (beneathMult)
+  {
+    NodeMap::iterator find = bcache.find(n);
+    if (find != bcache.end())
+    {
+      return (*find).second;
+    }
+  }
+  else
+  {
+    NodeMap::iterator find = cache.find(n);
+    if (find != cache.end())
+    {
+      return (*find).second;
+    }
+  }
+  Node ret = n;
+  if (n.getNumChildren() > 0)
+  {
+    if (beneathMult
+        && (n.getKind() == kind::PLUS || n.getKind() == kind::MINUS))
+    {
+      // don't do it if it rewrites to a constant
+      Node nr = Rewriter::rewrite(n);
+      if (nr.isConst())
+      {
+        // return the rewritten constant
+        ret = nr;
+      }
+      else
+      {
+        // new variable
+        ret = NodeManager::currentNM()->mkSkolem(
+            "__purifyNl_var",
+            n.getType(),
+            "Variable introduced in purifyNl pass");
+        Node np = purifyNlTerms(n, cache, bcache, var_eq, false);
+        var_eq.push_back(np.eqNode(ret));
+        Trace("nl-ext-purify") << "Purify : " << ret << " -> " << np
+                               << std::endl;
+      }
+    }
+    else
+    {
+      bool beneathMultNew = beneathMult || n.getKind() == kind::MULT;
+      bool childChanged = false;
+      std::vector<Node> children;
+      for (unsigned i = 0, size = n.getNumChildren(); i < size; ++i)
+      {
+        Node nc = purifyNlTerms(n[i], cache, bcache, var_eq, beneathMultNew);
+        childChanged = childChanged || nc != n[i];
+        children.push_back(nc);
+      }
+      if (childChanged)
+      {
+        ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
+      }
+    }
+  }
+  if (beneathMult)
+  {
+    bcache[n] = ret;
+  }
+  else
+  {
+    cache[n] = ret;
+  }
+  return ret;
+}
+
+NlExtPurify::NlExtPurify(PreprocessingPassContext* preprocContext)
+    : PreprocessingPass(preprocContext, "nl-ext-purify"){};
+
+PreprocessingPassResult NlExtPurify::applyInternal(
+    AssertionPipeline* assertionsToPreprocess)
+{
+  unordered_map<Node, Node, NodeHashFunction> cache;
+  unordered_map<Node, Node, NodeHashFunction> bcache;
+  std::vector<Node> var_eq;
+  unsigned size = assertionsToPreprocess->size();
+  for (unsigned i = 0; i < size; ++i)
+  {
+    Node a = (*assertionsToPreprocess)[i];
+    assertionsToPreprocess->replace(i, purifyNlTerms(a, cache, bcache, var_eq));
+    Trace("nl-ext-purify") << "Purify : " << a << " -> "
+                           << (*assertionsToPreprocess)[i] << "\n";
+  }
+  if (!var_eq.empty())
+  {
+    unsigned lastIndex = size - 1;
+    var_eq.insert(var_eq.begin(), (*assertionsToPreprocess)[lastIndex]);
+    assertionsToPreprocess->replace(
+        lastIndex, NodeManager::currentNM()->mkNode(kind::AND, var_eq));
+  }
+  return PreprocessingPassResult::NO_CONFLICT;
+}
+
+}  // namespace passes
+}  // namespace preprocessing
+}  // namespace CVC4
diff --git a/src/preprocessing/passes/nl_ext_purify.h b/src/preprocessing/passes/nl_ext_purify.h
new file mode 100644 (file)
index 0000000..8d28b07
--- /dev/null
@@ -0,0 +1,57 @@
+/*********************                                                        */
+/*! \file nl_ext_purify.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Haniel Barbosa
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The NlExtPurify preprocessing pass
+ **
+ ** Purifies non-linear terms by replacing sums under multiplications by fresh
+ ** variables
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef __CVC4__PREPROCESSING__PASSES__NL_EXT_PURIFY_H
+#define __CVC4__PREPROCESSING__PASSES__NL_EXT_PURIFY_H
+
+#include <unordered_map>
+#include <vector>
+
+#include "expr/node.h"
+#include "preprocessing/preprocessing_pass.h"
+#include "preprocessing/preprocessing_pass_context.h"
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+using NodeMap = std::unordered_map<Node, Node, NodeHashFunction>;
+
+class NlExtPurify : public PreprocessingPass
+{
+ public:
+  NlExtPurify(PreprocessingPassContext* preprocContext);
+
+ protected:
+  PreprocessingPassResult applyInternal(
+      AssertionPipeline* assertionsToPreprocess) override;
+
+ private:
+  Node purifyNlTerms(TNode n,
+                     NodeMap& cache,
+                     NodeMap& bcache,
+                     std::vector<Node>& var_eq,
+                     bool beneathMult = false);
+};
+
+}  // namespace passes
+}  // namespace preprocessing
+}  // namespace CVC4
+
+#endif /* __CVC4__PREPROCESSING__PASSES__NL_EXT_PURIFY_H */
index deafcc96c9d7dd28a88a6749ac5de4aaae3efa2c..70e575487244774d55c8f71999429c43802b14ab 100644 (file)
@@ -83,6 +83,7 @@
 #include "preprocessing/passes/int_to_bv.h"
 #include "preprocessing/passes/ite_removal.h"
 #include "preprocessing/passes/ite_simp.h"
+#include "preprocessing/passes/nl_ext_purify.h"
 #include "preprocessing/passes/pseudo_boolean_processor.h"
 #include "preprocessing/passes/quantifiers_preprocess.h"
 #include "preprocessing/passes/real_to_int.h"
@@ -566,14 +567,6 @@ class SmtEnginePrivate : public NodeManagerListener {
    */
   bool nonClausalSimplify();
 
-  /**
-   * Performs static learning on the assertions.
-   */
-  void staticLearning();
-
-  Node realToInt(TNode n, NodeToNodeHashMap& cache, std::vector< Node >& var_eq);
-  Node purifyNlTerms(TNode n, NodeToNodeHashMap& cache, NodeToNodeHashMap& bcache, std::vector< Node >& var_eq, bool beneathMult = false);
-
   /**
    * Helper function to fix up assertion list to restore invariants needed after
    * ite removal.
@@ -790,7 +783,7 @@ class SmtEnginePrivate : public NodeManagerListener {
   /** Process a user push.
   */
   void notifyPush() {
-  
+
   }
 
   /**
@@ -872,13 +865,13 @@ class SmtEnginePrivate : public NodeManagerListener {
   std::ostream* getReplayLog() const {
     return d_managedReplayLog.getReplayLog();
   }
-  
+
   //------------------------------- expression names
   // implements setExpressionName, as described in smt_engine.h
   void setExpressionName(Expr e, std::string name) {
     d_exprNames[Node::fromExpr(e)] = name;
   }
-  
+
   // implements getExpressionName, as described in smt_engine.h
   bool getExpressionName(Expr e, std::string& name) const {
     context::CDHashMap< Node, std::string, NodeHashFunction >::const_iterator it = d_exprNames.find(e);
@@ -2657,6 +2650,8 @@ void SmtEnginePrivate::finishInit()
       new IntToBV(d_preprocessingPassContext.get()));
   std::unique_ptr<ITESimp> iteSimp(
       new ITESimp(d_preprocessingPassContext.get()));
+  std::unique_ptr<NlExtPurify> nlExtPurify(
+      new NlExtPurify(d_preprocessingPassContext.get()));
   std::unique_ptr<QuantifiersPreprocess> quantifiersPreprocess(
       new QuantifiersPreprocess(d_preprocessingPassContext.get()));
   std::unique_ptr<PseudoBooleanProcessor> pbProc(
@@ -2700,6 +2695,8 @@ void SmtEnginePrivate::finishInit()
                                            std::move(globalNegate));
   d_preprocessingPassRegistry.registerPass("int-to-bv", std::move(intToBV));
   d_preprocessingPassRegistry.registerPass("ite-simp", std::move(iteSimp));
+  d_preprocessingPassRegistry.registerPass("nl-ext-purify",
+                                           std::move(nlExtPurify));
   d_preprocessingPassRegistry.registerPass("quantifiers-preprocess",
                                            std::move(quantifiersPreprocess));
   d_preprocessingPassRegistry.registerPass("pseudo-boolean-processor",
@@ -2712,7 +2709,7 @@ void SmtEnginePrivate::finishInit()
                                            std::move(sepSkolemEmp));
   d_preprocessingPassRegistry.registerPass("sort-inference",
                                            std::move(sortInfer));
-  d_preprocessingPassRegistry.registerPass("static-learning", 
+  d_preprocessingPassRegistry.registerPass("static-learning",
                                            std::move(staticLearning));
   d_preprocessingPassRegistry.registerPass("sygus-infer",
                                            std::move(sygusInfer));
@@ -2903,68 +2900,6 @@ Node SmtEnginePrivate::expandDefinitions(TNode n, unordered_map<Node, Node, Node
   return result.top();
 }
 
-typedef std::unordered_map<Node, Node, NodeHashFunction> NodeMap;
-
-Node SmtEnginePrivate::purifyNlTerms(TNode n, NodeMap& cache, NodeMap& bcache, std::vector< Node >& var_eq, bool beneathMult) {
-  if( beneathMult ){
-    NodeMap::iterator find = bcache.find(n);
-    if (find != bcache.end()) {
-      return (*find).second;
-    }
-  }else{
-    NodeMap::iterator find = cache.find(n);
-    if (find != cache.end()) {
-      return (*find).second;
-    }
-  }
-  Node ret = n;
-  if( n.getNumChildren()>0 ){
-    if (beneathMult
-        && (n.getKind() == kind::PLUS || n.getKind() == kind::MINUS))
-    {
-      // don't do it if it rewrites to a constant
-      Node nr = Rewriter::rewrite(n);
-      if (nr.isConst())
-      {
-        // return the rewritten constant
-        ret = nr;
-      }
-      else
-      {
-        // new variable
-        ret = NodeManager::currentNM()->mkSkolem(
-            "__purifyNl_var",
-            n.getType(),
-            "Variable introduced in purifyNl pass");
-        Node np = purifyNlTerms(n, cache, bcache, var_eq, false);
-        var_eq.push_back(np.eqNode(ret));
-        Trace("nl-ext-purify")
-            << "Purify : " << ret << " -> " << np << std::endl;
-      }
-    }
-    else
-    {
-      bool beneathMultNew = beneathMult || n.getKind()==kind::MULT;
-      bool childChanged = false;
-      std::vector< Node > children;
-      for( unsigned i=0; i<n.getNumChildren(); i++ ){
-        Node nc = purifyNlTerms( n[i], cache, bcache, var_eq, beneathMultNew );
-        childChanged = childChanged || nc!=n[i];
-        children.push_back( nc );
-      }
-      if( childChanged ){
-        ret = NodeManager::currentNM()->mkNode( n.getKind(), children );
-      }
-    }
-  }
-  if( beneathMult ){
-    bcache[n] = ret;
-  }else{
-    cache[n] = ret;
-  }
-  return ret;
-}
-
 // do dumping (before/after any preprocessing pass)
 static void dumpAssertions(const char* key,
                            const AssertionPipeline& assertionList) {
@@ -4037,20 +3972,7 @@ void SmtEnginePrivate::processAssertions() {
   }
 
   if( options::nlExtPurify() ){
-    unordered_map<Node, Node, NodeHashFunction> cache;
-    unordered_map<Node, Node, NodeHashFunction> bcache;
-    std::vector< Node > var_eq;
-    for (unsigned i = 0; i < d_assertions.size(); ++ i) {
-      Node a = d_assertions[i];
-      d_assertions.replace(i, purifyNlTerms(a, cache, bcache, var_eq));
-      Trace("nl-ext-purify")
-          << "Purify : " << a << " -> " << d_assertions[i] << std::endl;
-    }
-    if( !var_eq.empty() ){
-      unsigned lastIndex = d_assertions.size()-1;
-      var_eq.insert( var_eq.begin(), d_assertions[lastIndex] );
-      d_assertions.replace(lastIndex, NodeManager::currentNM()->mkNode( kind::AND, var_eq ) );
-    }
+    d_preprocessingPassRegistry.getPass("nl-ext-purify")->apply(&d_assertions);
   }
 
   if( options::ceGuidedInst() ){
@@ -5527,7 +5449,7 @@ Expr SmtEngine::doQuantifierElimination(const Expr& e, bool doFull, bool strict)
     Assert( inst_qs.size()<=1 );
     Node ret_n;
     if( inst_qs.size()==1 ){
-      Node top_q = inst_qs[0]; 
+      Node top_q = inst_qs[0];
       //Node top_q = Rewriter::rewrite( nn_e ).negate();
       Assert( top_q.getKind()==kind::FORALL );
       Trace("smt-qe") << "Get qe for " << top_q << std::endl;
@@ -5950,7 +5872,7 @@ void SmtEngine::setReplayStream(ExprStream* replayStream) {
   AlwaysAssert(!d_fullyInited,
                "Cannot set replay stream once fully initialized");
   d_replayStream = replayStream;
-}  
+}
 
 bool SmtEngine::getExpressionName(Expr e, std::string& name) const {
   return d_private->getExpressionName(e, name);
index 2922085cafd07bc4b9e356df17190eb7a22783a4..f707da2191d077b8f78394ad32ccfd5811bb0d3a 100644 (file)
@@ -503,6 +503,7 @@ REG0_TESTS = \
        regress0/nl/magnitude-wrong-1020-m.smt2 \
        regress0/nl/mult-po.smt2 \
        regress0/nl/nia-wrong-tl.smt2 \
+       regress0/nl/nlExtPurify-test.smt2 \
        regress0/nl/nta/cos-sig-value.smt2 \
        regress0/nl/nta/exp-n0.5-lb.smt2 \
        regress0/nl/nta/exp-n0.5-ub.smt2 \
diff --git a/test/regress/regress0/nl/nlExtPurify-test.smt2 b/test/regress/regress0/nl/nlExtPurify-test.smt2
new file mode 100644 (file)
index 0000000..1a2391c
--- /dev/null
@@ -0,0 +1,15 @@
+; COMMAND-LINE: --nl-ext-purify
+; EXPECT: sat
+(set-info :smt-lib-version 2.6)
+(set-logic QF_NRA)
+(set-info :category "crafted")
+(set-info :status sat)
+(declare-fun skoX () Real)
+(declare-fun skoS3 () Real)
+(declare-fun skoSX () Real)
+
+(assert (and (not (<= skoX 0)) (and (not (<= (* (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX)) 0)) (not (<= skoS3 0)))))
+
+
+(check-sat)
+(exit)