namespace passes {
StaticLearning::StaticLearning(PreprocessingPassContext* preprocContext)
- : PreprocessingPass(preprocContext, "static-learning"){};
+ : PreprocessingPass(preprocContext, "static-learning"),
+ d_cache(userContext()){};
PreprocessingPassResult StaticLearning::applyInternal(
AssertionPipeline* assertionsToPreprocess)
{
d_preprocContext->spendResource(Resource::PreprocessStep);
- for (unsigned i = 0; i < assertionsToPreprocess->size(); ++i)
+ std::vector<TNode> toProcess;
+
+ for (size_t i = 0, size = assertionsToPreprocess->size(); i < size; ++i)
{
+ const Node& n = (*assertionsToPreprocess)[i];
+
+ /* Already processed in this context. */
+ if (d_cache.find(n) != d_cache.end())
+ {
+ continue;
+ }
+
NodeBuilder learned(kind::AND);
- learned << (*assertionsToPreprocess)[i];
- d_preprocContext->getTheoryEngine()->ppStaticLearn(
- (*assertionsToPreprocess)[i], learned);
+ learned << n;
+
+ /* Process all assertions in nested AND terms. */
+ std::vector<TNode> assertions;
+ flattenAnd(n, assertions);
+ for (TNode a : assertions)
+ {
+ d_preprocContext->getTheoryEngine()->ppStaticLearn(a, learned);
+ }
+
if (learned.getNumChildren() == 1)
{
learned.clear();
return PreprocessingPassResult::NO_CONFLICT;
}
+void StaticLearning::flattenAnd(TNode node, std::vector<TNode>& children)
+{
+ std::vector<TNode> visit = {node};
+ do
+ {
+ TNode cur = visit.back();
+ visit.pop_back();
+
+ if (d_cache.find(cur) != d_cache.end())
+ {
+ continue;
+ }
+ d_cache.insert(cur);
+
+ if (cur.getKind() == kind::AND)
+ {
+ visit.insert(visit.end(), cur.begin(), cur.end());
+ }
+ else
+ {
+ children.push_back(cur);
+ }
+ } while (!visit.empty());
+}
} // namespace passes
} // namespace preprocessing
#ifndef CVC5__PREPROCESSING__PASSES__STATIC_LEARNING_H
#define CVC5__PREPROCESSING__PASSES__STATIC_LEARNING_H
+#include "context/cdhashset.h"
#include "preprocessing/preprocessing_pass.h"
namespace cvc5 {
protected:
PreprocessingPassResult applyInternal(
AssertionPipeline* assertionsToPreprocess) override;
+
+ private:
+ /** Collect children of flattened AND term. */
+ void flattenAnd(TNode node, std::vector<TNode>& children);
+
+ /** CD-cache for visiting nodes used by `flattenAnd`. */
+ context::CDHashSet<Node> d_cache;
};
} // namespace passes