Refactor static learning preprocessing pass (#1857)
authoryoni206 <yoni206@users.noreply.github.com>
Wed, 16 May 2018 01:17:01 +0000 (18:17 -0700)
committerGitHub <noreply@github.com>
Wed, 16 May 2018 01:17:01 +0000 (18:17 -0700)
src/Makefile.am
src/preprocessing/passes/static_learning.cpp [new file with mode: 0644]
src/preprocessing/passes/static_learning.h [new file with mode: 0644]
src/smt/smt_engine.cpp

index 17deeba812e8f97c9e13237e60d3e8963955422c..569bc3c4804288761bb5ddaebcbc39e50de1edfc 100644 (file)
@@ -80,6 +80,8 @@ libcvc4_la_SOURCES = \
        preprocessing/passes/bv_to_bool.h \
        preprocessing/passes/real_to_int.cpp \
        preprocessing/passes/real_to_int.h \
+       preprocessing/passes/static_learning.cpp \
+       preprocessing/passes/static_learning.h \
        preprocessing/passes/symmetry_breaker.cpp \
        preprocessing/passes/symmetry_breaker.h \
        preprocessing/passes/symmetry_detect.cpp \
diff --git a/src/preprocessing/passes/static_learning.cpp b/src/preprocessing/passes/static_learning.cpp
new file mode 100644 (file)
index 0000000..0a792b5
--- /dev/null
@@ -0,0 +1,55 @@
+/*********************                                                        */
+/*! \file static_learning.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Yoni Zohar
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The static learning preprocessing pass
+ **
+ **/
+
+#include "preprocessing/passes/static_learning.h"
+
+#include <string>
+
+#include "expr/node.h"
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+StaticLearning::StaticLearning(PreprocessingPassContext* preprocContext)
+    : PreprocessingPass(preprocContext, "static-learning"){};
+
+PreprocessingPassResult StaticLearning::applyInternal(
+    AssertionPipeline* assertionsToPreprocess)
+{
+  NodeManager::currentResourceManager()->spendResource(
+      options::preprocessStep());
+
+  for (unsigned i = 0; i < assertionsToPreprocess->size(); ++i)
+  {
+    NodeBuilder<> learned(kind::AND);
+    learned << (*assertionsToPreprocess)[i];
+    d_preprocContext->getTheoryEngine()->ppStaticLearn(
+        (*assertionsToPreprocess)[i], learned);
+    if (learned.getNumChildren() == 1)
+    {
+      learned.clear();
+    }
+    else
+    {
+      assertionsToPreprocess->replace(i, learned);
+    }
+  }
+  return PreprocessingPassResult::NO_CONFLICT;
+}
+
+}  // namespace passes
+}  // namespace preprocessing
+}  // namespace CVC4
diff --git a/src/preprocessing/passes/static_learning.h b/src/preprocessing/passes/static_learning.h
new file mode 100644 (file)
index 0000000..ade1f5a
--- /dev/null
@@ -0,0 +1,42 @@
+/*********************                                                        */
+/*! \file static_learning.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Yoni Zohar
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The static learning preprocessing pass
+ **
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef __CVC4__PREPROCESSING__PASSES__STATIC_LEARNING_H
+#define __CVC4__PREPROCESSING__PASSES__STATIC_LEARNING_H
+
+#include "preprocessing/preprocessing_pass.h"
+#include "preprocessing/preprocessing_pass_context.h"
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+class StaticLearning : public PreprocessingPass
+{
+ public:
+  StaticLearning(PreprocessingPassContext* preprocContext);
+
+ protected:
+  PreprocessingPassResult applyInternal(
+      AssertionPipeline* assertionsToPreprocess) override;
+};
+
+}  // namespace passes
+}  // namespace preprocessing
+}  // namespace CVC4
+
+#endif /* __CVC4__PREPROCESSING__PASSES__STATIC_LEARNING_H */
index 7f34bb39e31596a5643051d02f3fff8c6801b3b8..82147c0948887c148e0734f7d33aa6ff5de21301 100644 (file)
@@ -77,6 +77,7 @@
 #include "preprocessing/passes/int_to_bv.h"
 #include "preprocessing/passes/pseudo_boolean_processor.h"
 #include "preprocessing/passes/real_to_int.h"
+#include "preprocessing/passes/static_learning.h"
 #include "preprocessing/passes/symmetry_breaker.h"
 #include "preprocessing/passes/symmetry_detect.h"
 #include "preprocessing/preprocessing_pass.h"
@@ -190,8 +191,6 @@ struct SmtEngineStatistics {
   IntStat d_numMiplibAssertionsRemoved;
   /** number of constant propagations found during nonclausal simp */
   IntStat d_numConstantProps;
-  /** time spent in static learning */
-  TimerStat d_staticLearningTime;
   /** time spent in simplifying ITEs */
   TimerStat d_simpITETime;
   /** time spent in simplifying ITEs */
@@ -233,7 +232,6 @@ struct SmtEngineStatistics {
     d_miplibPassTime("smt::SmtEngine::miplibPassTime"),
     d_numMiplibAssertionsRemoved("smt::SmtEngine::numMiplibAssertionsRemoved", 0),
     d_numConstantProps("smt::SmtEngine::numConstantProps", 0),
-    d_staticLearningTime("smt::SmtEngine::staticLearningTime"),
     d_simpITETime("smt::SmtEngine::simpITETime"),
     d_unconstrainedSimpTime("smt::SmtEngine::unconstrainedSimpTime"),
     d_iteRemovalTime("smt::SmtEngine::iteRemovalTime"),
@@ -258,7 +256,6 @@ struct SmtEngineStatistics {
     smtStatisticsRegistry()->registerStat(&d_miplibPassTime);
     smtStatisticsRegistry()->registerStat(&d_numMiplibAssertionsRemoved);
     smtStatisticsRegistry()->registerStat(&d_numConstantProps);
-    smtStatisticsRegistry()->registerStat(&d_staticLearningTime);
     smtStatisticsRegistry()->registerStat(&d_simpITETime);
     smtStatisticsRegistry()->registerStat(&d_unconstrainedSimpTime);
     smtStatisticsRegistry()->registerStat(&d_iteRemovalTime);
@@ -284,7 +281,6 @@ struct SmtEngineStatistics {
     smtStatisticsRegistry()->unregisterStat(&d_miplibPassTime);
     smtStatisticsRegistry()->unregisterStat(&d_numMiplibAssertionsRemoved);
     smtStatisticsRegistry()->unregisterStat(&d_numConstantProps);
-    smtStatisticsRegistry()->unregisterStat(&d_staticLearningTime);
     smtStatisticsRegistry()->unregisterStat(&d_simpITETime);
     smtStatisticsRegistry()->unregisterStat(&d_unconstrainedSimpTime);
     smtStatisticsRegistry()->unregisterStat(&d_iteRemovalTime);
@@ -2618,10 +2614,10 @@ void SmtEnginePrivate::finishInit() {
   // actually assembling preprocessing pipelines).
   std::unique_ptr<BoolToBV> boolToBv(
       new BoolToBV(d_preprocessingPassContext.get()));
-  std::unique_ptr<BVAckermann> bvAckermann(
-      new BVAckermann(d_preprocessingPassContext.get()));
   std::unique_ptr<BvAbstraction> bvAbstract(
       new BvAbstraction(d_preprocessingPassContext.get()));
+  std::unique_ptr<BVAckermann> bvAckermann(
+      new BVAckermann(d_preprocessingPassContext.get()));
   std::unique_ptr<BVGauss> bvGauss(
       new BVGauss(d_preprocessingPassContext.get()));
   std::unique_ptr<BvIntroPow2> bvIntroPow2(
@@ -2634,6 +2630,8 @@ void SmtEnginePrivate::finishInit() {
       new PseudoBooleanProcessor(d_preprocessingPassContext.get()));
   std::unique_ptr<RealToInt> realToInt(
       new RealToInt(d_preprocessingPassContext.get()));
+  std::unique_ptr<StaticLearning> staticLearning(
+      new StaticLearning(d_preprocessingPassContext.get()));
   std::unique_ptr<SymBreakerPass> sbProc(
       new SymBreakerPass(d_preprocessingPassContext.get()));
   d_preprocessingPassRegistry.registerPass("bool-to-bv", std::move(boolToBv));
@@ -2649,6 +2647,8 @@ void SmtEnginePrivate::finishInit() {
   d_preprocessingPassRegistry.registerPass("pseudo-boolean-processor",
                                            std::move(pbProc));
   d_preprocessingPassRegistry.registerPass("real-to-int", std::move(realToInt));
+  d_preprocessingPassRegistry.registerPass("static-learning", 
+                                           std::move(staticLearning));
   d_preprocessingPassRegistry.registerPass("sym-break", std::move(sbProc));
 }
 
@@ -2884,26 +2884,7 @@ void SmtEnginePrivate::removeITEs() {
   }
 }
 
-void SmtEnginePrivate::staticLearning() {
-  d_smt.finalOptionsAreSet();
-  spendResource(options::preprocessStep());
-
-  TimerStat::CodeTimer staticLearningTimer(d_smt.d_stats->d_staticLearningTime);
 
-  Trace("simplify") << "SmtEnginePrivate::staticLearning()" << endl;
-
-  for (unsigned i = 0; i < d_assertions.size(); ++ i) {
-
-    NodeBuilder<> learned(kind::AND);
-    learned << d_assertions[i];
-    d_smt.d_theoryEngine->ppStaticLearn(d_assertions[i], learned);
-    if(learned.getNumChildren() == 1) {
-      learned.clear();
-    } else {
-      d_assertions.replace(i, learned);
-    }
-  }
-}
 
 // do dumping (before/after any preprocessing pass)
 static void dumpAssertions(const char* key,
@@ -4264,21 +4245,12 @@ void SmtEnginePrivate::processAssertions() {
     d_preprocessingPassRegistry.getPass("sym-break")->apply(&d_assertions);
   }
 
-  dumpAssertions("pre-static-learning", d_assertions);
   if(options::doStaticLearning()) {
-    Trace("smt-proc") << "SmtEnginePrivate::processAssertions() : pre-static-learning" << endl;
-    // Perform static learning
-    Chat() << "doing static learning..." << endl;
-    Trace("simplify") << "SmtEnginePrivate::simplify(): "
-                      << "performing static learning" << endl;
-    staticLearning();
-    Trace("smt-proc") << "SmtEnginePrivate::processAssertions() : post-static-learning" << endl;
+    d_preprocessingPassRegistry.getPass("static-learning")
+        ->apply(&d_assertions);
   }
-  dumpAssertions("post-static-learning", d_assertions);
-
   Debug("smt") << " d_assertions     : " << d_assertions.size() << endl;
 
-
   Trace("smt-proc") << "SmtEnginePrivate::processAssertions() : pre-ite-removal" << endl;
   dumpAssertions("pre-ite-removal", d_assertions);
   {