From 7b9c2529c149a9cd046083af401cbdeadf406804 Mon Sep 17 00:00:00 2001 From: Haniel Barbosa Date: Fri, 24 Aug 2018 20:19:14 -0500 Subject: [PATCH] Refactor nlExtPurify preprocessing pass (#1963) --- src/Makefile.am | 2 + src/preprocessing/passes/nl_ext_purify.cpp | 130 ++++++++++++++++++ src/preprocessing/passes/nl_ext_purify.h | 57 ++++++++ src/smt/smt_engine.cpp | 102 ++------------ test/regress/Makefile.tests | 1 + .../regress/regress0/nl/nlExtPurify-test.smt2 | 15 ++ 6 files changed, 217 insertions(+), 90 deletions(-) create mode 100644 src/preprocessing/passes/nl_ext_purify.cpp create mode 100644 src/preprocessing/passes/nl_ext_purify.h create mode 100644 test/regress/regress0/nl/nlExtPurify-test.smt2 diff --git a/src/Makefile.am b/src/Makefile.am index 3b8a12fa5..d399602cb 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -85,6 +85,8 @@ libcvc4_la_SOURCES = \ preprocessing/passes/ite_removal.h \ preprocessing/passes/ite_simp.cpp \ preprocessing/passes/ite_simp.h \ + preprocessing/passes/nl_ext_purify.cpp \ + preprocessing/passes/nl_ext_purify.h \ preprocessing/passes/pseudo_boolean_processor.cpp \ preprocessing/passes/pseudo_boolean_processor.h \ preprocessing/passes/bool_to_bv.cpp \ diff --git a/src/preprocessing/passes/nl_ext_purify.cpp b/src/preprocessing/passes/nl_ext_purify.cpp new file mode 100644 index 000000000..afb092571 --- /dev/null +++ b/src/preprocessing/passes/nl_ext_purify.cpp @@ -0,0 +1,130 @@ +/********************* */ +/*! \file nl_ext_purify.cpp + ** \verbatim + ** Top contributors (to current version): + ** Haniel Barbosa + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The NlExtPurify preprocessing pass + ** + ** Purifies non-linear terms + **/ + +#include "preprocessing/passes/nl_ext_purify.h" + +namespace CVC4 { +namespace preprocessing { +namespace passes { + +using namespace CVC4::theory; + +Node NlExtPurify::purifyNlTerms(TNode n, + NodeMap& cache, + NodeMap& bcache, + std::vector& var_eq, + bool beneathMult) +{ + if (beneathMult) + { + NodeMap::iterator find = bcache.find(n); + if (find != bcache.end()) + { + return (*find).second; + } + } + else + { + NodeMap::iterator find = cache.find(n); + if (find != cache.end()) + { + return (*find).second; + } + } + Node ret = n; + if (n.getNumChildren() > 0) + { + if (beneathMult + && (n.getKind() == kind::PLUS || n.getKind() == kind::MINUS)) + { + // don't do it if it rewrites to a constant + Node nr = Rewriter::rewrite(n); + if (nr.isConst()) + { + // return the rewritten constant + ret = nr; + } + else + { + // new variable + ret = NodeManager::currentNM()->mkSkolem( + "__purifyNl_var", + n.getType(), + "Variable introduced in purifyNl pass"); + Node np = purifyNlTerms(n, cache, bcache, var_eq, false); + var_eq.push_back(np.eqNode(ret)); + Trace("nl-ext-purify") << "Purify : " << ret << " -> " << np + << std::endl; + } + } + else + { + bool beneathMultNew = beneathMult || n.getKind() == kind::MULT; + bool childChanged = false; + std::vector children; + for (unsigned i = 0, size = n.getNumChildren(); i < size; ++i) + { + Node nc = purifyNlTerms(n[i], cache, bcache, var_eq, beneathMultNew); + childChanged = childChanged || nc != n[i]; + children.push_back(nc); + } + if (childChanged) + { + ret = NodeManager::currentNM()->mkNode(n.getKind(), children); + } + } + } + if (beneathMult) + { + bcache[n] = ret; + } + else + { + cache[n] = ret; + } + return ret; +} + +NlExtPurify::NlExtPurify(PreprocessingPassContext* preprocContext) + : PreprocessingPass(preprocContext, "nl-ext-purify"){}; + +PreprocessingPassResult NlExtPurify::applyInternal( + AssertionPipeline* assertionsToPreprocess) +{ + unordered_map cache; + unordered_map bcache; + std::vector var_eq; + unsigned size = assertionsToPreprocess->size(); + for (unsigned i = 0; i < size; ++i) + { + Node a = (*assertionsToPreprocess)[i]; + assertionsToPreprocess->replace(i, purifyNlTerms(a, cache, bcache, var_eq)); + Trace("nl-ext-purify") << "Purify : " << a << " -> " + << (*assertionsToPreprocess)[i] << "\n"; + } + if (!var_eq.empty()) + { + unsigned lastIndex = size - 1; + var_eq.insert(var_eq.begin(), (*assertionsToPreprocess)[lastIndex]); + assertionsToPreprocess->replace( + lastIndex, NodeManager::currentNM()->mkNode(kind::AND, var_eq)); + } + return PreprocessingPassResult::NO_CONFLICT; +} + +} // namespace passes +} // namespace preprocessing +} // namespace CVC4 diff --git a/src/preprocessing/passes/nl_ext_purify.h b/src/preprocessing/passes/nl_ext_purify.h new file mode 100644 index 000000000..8d28b0742 --- /dev/null +++ b/src/preprocessing/passes/nl_ext_purify.h @@ -0,0 +1,57 @@ +/********************* */ +/*! \file nl_ext_purify.h + ** \verbatim + ** Top contributors (to current version): + ** Haniel Barbosa + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief The NlExtPurify preprocessing pass + ** + ** Purifies non-linear terms by replacing sums under multiplications by fresh + ** variables + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__PREPROCESSING__PASSES__NL_EXT_PURIFY_H +#define __CVC4__PREPROCESSING__PASSES__NL_EXT_PURIFY_H + +#include +#include + +#include "expr/node.h" +#include "preprocessing/preprocessing_pass.h" +#include "preprocessing/preprocessing_pass_context.h" + +namespace CVC4 { +namespace preprocessing { +namespace passes { + +using NodeMap = std::unordered_map; + +class NlExtPurify : public PreprocessingPass +{ + public: + NlExtPurify(PreprocessingPassContext* preprocContext); + + protected: + PreprocessingPassResult applyInternal( + AssertionPipeline* assertionsToPreprocess) override; + + private: + Node purifyNlTerms(TNode n, + NodeMap& cache, + NodeMap& bcache, + std::vector& var_eq, + bool beneathMult = false); +}; + +} // namespace passes +} // namespace preprocessing +} // namespace CVC4 + +#endif /* __CVC4__PREPROCESSING__PASSES__NL_EXT_PURIFY_H */ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index deafcc96c..70e575487 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -83,6 +83,7 @@ #include "preprocessing/passes/int_to_bv.h" #include "preprocessing/passes/ite_removal.h" #include "preprocessing/passes/ite_simp.h" +#include "preprocessing/passes/nl_ext_purify.h" #include "preprocessing/passes/pseudo_boolean_processor.h" #include "preprocessing/passes/quantifiers_preprocess.h" #include "preprocessing/passes/real_to_int.h" @@ -566,14 +567,6 @@ class SmtEnginePrivate : public NodeManagerListener { */ bool nonClausalSimplify(); - /** - * Performs static learning on the assertions. - */ - void staticLearning(); - - Node realToInt(TNode n, NodeToNodeHashMap& cache, std::vector< Node >& var_eq); - Node purifyNlTerms(TNode n, NodeToNodeHashMap& cache, NodeToNodeHashMap& bcache, std::vector< Node >& var_eq, bool beneathMult = false); - /** * Helper function to fix up assertion list to restore invariants needed after * ite removal. @@ -790,7 +783,7 @@ class SmtEnginePrivate : public NodeManagerListener { /** Process a user push. */ void notifyPush() { - + } /** @@ -872,13 +865,13 @@ class SmtEnginePrivate : public NodeManagerListener { std::ostream* getReplayLog() const { return d_managedReplayLog.getReplayLog(); } - + //------------------------------- expression names // implements setExpressionName, as described in smt_engine.h void setExpressionName(Expr e, std::string name) { d_exprNames[Node::fromExpr(e)] = name; } - + // implements getExpressionName, as described in smt_engine.h bool getExpressionName(Expr e, std::string& name) const { context::CDHashMap< Node, std::string, NodeHashFunction >::const_iterator it = d_exprNames.find(e); @@ -2657,6 +2650,8 @@ void SmtEnginePrivate::finishInit() new IntToBV(d_preprocessingPassContext.get())); std::unique_ptr iteSimp( new ITESimp(d_preprocessingPassContext.get())); + std::unique_ptr nlExtPurify( + new NlExtPurify(d_preprocessingPassContext.get())); std::unique_ptr quantifiersPreprocess( new QuantifiersPreprocess(d_preprocessingPassContext.get())); std::unique_ptr pbProc( @@ -2700,6 +2695,8 @@ void SmtEnginePrivate::finishInit() std::move(globalNegate)); d_preprocessingPassRegistry.registerPass("int-to-bv", std::move(intToBV)); d_preprocessingPassRegistry.registerPass("ite-simp", std::move(iteSimp)); + d_preprocessingPassRegistry.registerPass("nl-ext-purify", + std::move(nlExtPurify)); d_preprocessingPassRegistry.registerPass("quantifiers-preprocess", std::move(quantifiersPreprocess)); d_preprocessingPassRegistry.registerPass("pseudo-boolean-processor", @@ -2712,7 +2709,7 @@ void SmtEnginePrivate::finishInit() std::move(sepSkolemEmp)); d_preprocessingPassRegistry.registerPass("sort-inference", std::move(sortInfer)); - d_preprocessingPassRegistry.registerPass("static-learning", + d_preprocessingPassRegistry.registerPass("static-learning", std::move(staticLearning)); d_preprocessingPassRegistry.registerPass("sygus-infer", std::move(sygusInfer)); @@ -2903,68 +2900,6 @@ Node SmtEnginePrivate::expandDefinitions(TNode n, unordered_map NodeMap; - -Node SmtEnginePrivate::purifyNlTerms(TNode n, NodeMap& cache, NodeMap& bcache, std::vector< Node >& var_eq, bool beneathMult) { - if( beneathMult ){ - NodeMap::iterator find = bcache.find(n); - if (find != bcache.end()) { - return (*find).second; - } - }else{ - NodeMap::iterator find = cache.find(n); - if (find != cache.end()) { - return (*find).second; - } - } - Node ret = n; - if( n.getNumChildren()>0 ){ - if (beneathMult - && (n.getKind() == kind::PLUS || n.getKind() == kind::MINUS)) - { - // don't do it if it rewrites to a constant - Node nr = Rewriter::rewrite(n); - if (nr.isConst()) - { - // return the rewritten constant - ret = nr; - } - else - { - // new variable - ret = NodeManager::currentNM()->mkSkolem( - "__purifyNl_var", - n.getType(), - "Variable introduced in purifyNl pass"); - Node np = purifyNlTerms(n, cache, bcache, var_eq, false); - var_eq.push_back(np.eqNode(ret)); - Trace("nl-ext-purify") - << "Purify : " << ret << " -> " << np << std::endl; - } - } - else - { - bool beneathMultNew = beneathMult || n.getKind()==kind::MULT; - bool childChanged = false; - std::vector< Node > children; - for( unsigned i=0; imkNode( n.getKind(), children ); - } - } - } - if( beneathMult ){ - bcache[n] = ret; - }else{ - cache[n] = ret; - } - return ret; -} - // do dumping (before/after any preprocessing pass) static void dumpAssertions(const char* key, const AssertionPipeline& assertionList) { @@ -4037,20 +3972,7 @@ void SmtEnginePrivate::processAssertions() { } if( options::nlExtPurify() ){ - unordered_map cache; - unordered_map bcache; - std::vector< Node > var_eq; - for (unsigned i = 0; i < d_assertions.size(); ++ i) { - Node a = d_assertions[i]; - d_assertions.replace(i, purifyNlTerms(a, cache, bcache, var_eq)); - Trace("nl-ext-purify") - << "Purify : " << a << " -> " << d_assertions[i] << std::endl; - } - if( !var_eq.empty() ){ - unsigned lastIndex = d_assertions.size()-1; - var_eq.insert( var_eq.begin(), d_assertions[lastIndex] ); - d_assertions.replace(lastIndex, NodeManager::currentNM()->mkNode( kind::AND, var_eq ) ); - } + d_preprocessingPassRegistry.getPass("nl-ext-purify")->apply(&d_assertions); } if( options::ceGuidedInst() ){ @@ -5527,7 +5449,7 @@ Expr SmtEngine::doQuantifierElimination(const Expr& e, bool doFull, bool strict) Assert( inst_qs.size()<=1 ); Node ret_n; if( inst_qs.size()==1 ){ - Node top_q = inst_qs[0]; + Node top_q = inst_qs[0]; //Node top_q = Rewriter::rewrite( nn_e ).negate(); Assert( top_q.getKind()==kind::FORALL ); Trace("smt-qe") << "Get qe for " << top_q << std::endl; @@ -5950,7 +5872,7 @@ void SmtEngine::setReplayStream(ExprStream* replayStream) { AlwaysAssert(!d_fullyInited, "Cannot set replay stream once fully initialized"); d_replayStream = replayStream; -} +} bool SmtEngine::getExpressionName(Expr e, std::string& name) const { return d_private->getExpressionName(e, name); diff --git a/test/regress/Makefile.tests b/test/regress/Makefile.tests index 2922085ca..f707da219 100644 --- a/test/regress/Makefile.tests +++ b/test/regress/Makefile.tests @@ -503,6 +503,7 @@ REG0_TESTS = \ regress0/nl/magnitude-wrong-1020-m.smt2 \ regress0/nl/mult-po.smt2 \ regress0/nl/nia-wrong-tl.smt2 \ + regress0/nl/nlExtPurify-test.smt2 \ regress0/nl/nta/cos-sig-value.smt2 \ regress0/nl/nta/exp-n0.5-lb.smt2 \ regress0/nl/nta/exp-n0.5-ub.smt2 \ diff --git a/test/regress/regress0/nl/nlExtPurify-test.smt2 b/test/regress/regress0/nl/nlExtPurify-test.smt2 new file mode 100644 index 000000000..1a2391c3b --- /dev/null +++ b/test/regress/regress0/nl/nlExtPurify-test.smt2 @@ -0,0 +1,15 @@ +; COMMAND-LINE: --nl-ext-purify +; EXPECT: sat +(set-info :smt-lib-version 2.6) +(set-logic QF_NRA) +(set-info :category "crafted") +(set-info :status sat) +(declare-fun skoX () Real) +(declare-fun skoS3 () Real) +(declare-fun skoSX () Real) + +(assert (and (not (<= skoX 0)) (and (not (<= (* (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX) (+ skoS3 skoSX)) 0)) (not (<= skoS3 0))))) + + +(check-sat) +(exit) -- 2.30.2