Synthesize candidate-rewrites from standard inputs (#1918)
author Andrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 27 Jun 2018 19:12:17 +0000 (14:12 -0500)
committer GitHub <noreply@github.com>
Wed, 27 Jun 2018 19:12:17 +0000 (14:12 -0500)
14 files changed:
src/Makefile.am
src/options/smt_options.toml
src/preprocessing/passes/synth_rew_rules.cpp [new file with mode: 0644]
src/preprocessing/passes/synth_rew_rules.h [new file with mode: 0644]
src/smt/smt_engine.cpp
src/theory/quantifiers/candidate_rewrite_database.cpp
src/theory/quantifiers/candidate_rewrite_database.h
src/theory/quantifiers/dynamic_rewrite.cpp
src/theory/quantifiers/dynamic_rewrite.h
src/theory/quantifiers/sygus/ce_guided_conjecture.cpp
src/theory/quantifiers/sygus/term_database_sygus.cpp
src/theory/quantifiers/sygus/term_database_sygus.h
src/theory/quantifiers/sygus_sampler.cpp
src/theory/quantifiers/sygus_sampler.h

index b36c453e1fdcb4665e63830a1830c0851af9043d..b81d93081b75838ae705b50a95b32bc0833ef8d9 100644 (file)
@@ -86,6 +86,8 @@ libcvc4_la_SOURCES = \
        preprocessing/passes/symmetry_breaker.h \
        preprocessing/passes/symmetry_detect.cpp \
        preprocessing/passes/symmetry_detect.h \
+       preprocessing/passes/synth_rew_rules.cpp \
+       preprocessing/passes/synth_rew_rules.h \
        preprocessing/preprocessing_pass.cpp \
        preprocessing/preprocessing_pass.h \
        preprocessing/preprocessing_pass_context.cpp \
index 822f5c022c7526136478736b60cbbb669661ce60..ce7b3eebadd12482fae350f5ce84341f658e0430 100644 (file)
@@ -295,6 +295,22 @@ header = "options/smt_options.h"
   default    = "false"
   help       = "use aggressive extended rewriter as a preprocessing pass"
 
+[[option]]
+  name       = "synthRrPrep"
+  category   = "regular"
+  long       = "synth-rr-prep"
+  type       = "bool"
+  default    = "false"
+  help       = "synthesize and output rewrite rules during preprocessing"
+
+[[option]]
+  name       = "synthRrPrepExtRew"
+  category   = "regular"
+  long       = "synth-rr-prep-ext-rew"
+  type       = "bool"
+  default    = "false"
+  help       = "use the extended rewriter for synthRrPrep"
+
 [[option]]
   name       = "simplifyWithCareEnabled"
   category   = "regular"
diff --git a/src/preprocessing/passes/synth_rew_rules.cpp b/src/preprocessing/passes/synth_rew_rules.cpp
new file mode 100644 (file)
index 0000000..e3e3a54
--- /dev/null
@@ -0,0 +1,159 @@
+/*********************                                                        */
+/*! \file synth_rew_rules.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **  Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief A technique for synthesizing candidate rewrites of the form t1 = t2,
+ ** where t1 and t2 are subterms of the input.
+ **/
+
+#include "preprocessing/passes/synth_rew_rules.h"
+
+#include "options/base_options.h"
+#include "options/quantifiers_options.h"
+#include "printer/printer.h"
+#include "theory/quantifiers/candidate_rewrite_database.h"
+
+using namespace std;
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+/**
+ * Tag type for the attribute recording whether rewrite rules have already
+ * been computed for a term.
+ *
+ * The attribute is currently required to be global: once rewrites have been
+ * computed for a term, they should not be computed again for that term in a
+ * subcall to another SmtEngine (for instance, when using "exact" equivalence
+ * checking).
+ */
+struct SynthRrComputedAttributeId {};
+using SynthRrComputedAttribute =
+    expr::Attribute<SynthRrComputedAttributeId, bool>;
+
+/**
+ * Constructs the pass, forwarding the preprocessing context and the pass name
+ * "synth-rr" to the PreprocessingPass base class.
+ */
+SynthRewRulesPass::SynthRewRulesPass(PreprocessingPassContext* preprocContext)
+    : PreprocessingPass(preprocContext, "synth-rr")
+{
+}
+
+PreprocessingPassResult SynthRewRulesPass::applyInternal(
+    AssertionPipeline* assertionsToPreprocess)
+{
+  Trace("synth-rr-pass") << "Synthesize rewrite rules from assertions..."
+                         << std::endl;
+  std::vector<Node>& assertions = assertionsToPreprocess->ref();
+
+  // compute the variables we will be sampling
+  std::vector<Node> vars;
+  unsigned nsamples = options::sygusSamples();
+
+  Options& nodeManagerOptions = NodeManager::currentNM()->getOptions();
+
+  // attribute to mark processed terms
+  SynthRrComputedAttribute srrca;
+
+  // initialize the candidate rewrite
+  std::unique_ptr<theory::quantifiers::CandidateRewriteDatabaseGen> crdg;
+  std::unordered_map<TNode, bool, TNodeHashFunction> visited;
+  std::unordered_map<TNode, bool, TNodeHashFunction>::iterator it;
+  std::vector<TNode> visit;
+  // two passes: the first collects the variables, the second registers the
+  // terms
+  for (unsigned r = 0; r < 2; r++)
+  {
+    visited.clear();
+    visit.clear();
+    TNode cur;
+    for (const Node& a : assertions)
+    {
+      visit.push_back(a);
+      do
+      {
+        cur = visit.back();
+        visit.pop_back();
+        it = visited.find(cur);
+        // if already processed, ignore
+        if (cur.getAttribute(SynthRrComputedAttribute()))
+        {
+          Trace("synth-rr-pass-debug")
+              << "...already processed " << cur << std::endl;
+        }
+        else if (it == visited.end())
+        {
+          Trace("synth-rr-pass-debug") << "...preprocess " << cur << std::endl;
+          visited[cur] = false;
+          Kind k = cur.getKind();
+          bool isQuant = k == kind::FORALL || k == kind::EXISTS
+                         || k == kind::LAMBDA || k == kind::CHOICE;
+          // we recurse on this node if it is not a quantified formula
+          if (!isQuant)
+          {
+            visit.push_back(cur);
+            for (const Node& cc : cur)
+            {
+              visit.push_back(cc);
+            }
+          }
+        }
+        else if (!it->second)
+        {
+          Trace("synth-rr-pass-debug") << "...postprocess " << cur << std::endl;
+          // check if all of the children are valid
+          // this ensures we do not register terms that have e.g. quantified
+          // formulas as subterms
+          bool childrenValid = true;
+          for (const Node& cc : cur)
+          {
+            Assert(visited.find(cc) != visited.end());
+            if (!visited[cc])
+            {
+              childrenValid = false;
+            }
+          }
+          if (childrenValid)
+          {
+            Trace("synth-rr-pass-debug")
+                << "...children are valid, check rewrites..." << std::endl;
+            if (r == 0)
+            {
+              if (cur.isVar())
+              {
+                vars.push_back(cur);
+              }
+            }
+            else
+            {
+              Trace("synth-rr-pass-debug") << "Add term " << cur << std::endl;
+              // mark as processed
+              cur.setAttribute(srrca, true);
+              bool ret = crdg->addTerm(cur, *nodeManagerOptions.getOut());
+              Trace("synth-rr-pass-debug") << "...return " << ret << std::endl;
+              // if we want only rewrites of minimal size terms, we would set
+              // childrenValid to false if ret is false here.
+            }
+          }
+          visited[cur] = childrenValid;
+        }
+      } while (!visit.empty());
+    }
+    if (r == 0)
+    {
+      Trace("synth-rr-pass-debug")
+          << "Initialize with " << nsamples
+          << " samples and variables : " << vars << std::endl;
+      crdg = std::unique_ptr<theory::quantifiers::CandidateRewriteDatabaseGen>(
+          new theory::quantifiers::CandidateRewriteDatabaseGen(vars, nsamples));
+    }
+  }
+
+  Trace("synth-rr-pass") << "...finished " << std::endl;
+  return PreprocessingPassResult::NO_CONFLICT;
+}
+
+}  // namespace passes
+}  // namespace preprocessing
+}  // namespace CVC4
diff --git a/src/preprocessing/passes/synth_rew_rules.h b/src/preprocessing/passes/synth_rew_rules.h
new file mode 100644 (file)
index 0000000..cf0b491
--- /dev/null
@@ -0,0 +1,48 @@
+/*********************                                                        */
+/*! \file synth_rew_rules.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **  Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief A technique for synthesizing candidate rewrites of the form t1 = t2,
+ ** where t1 and t2 are subterms of the input.
+ **/
+
+#ifndef __CVC4__PREPROCESSING__PASSES__SYNTH_REW_RULES_H
+#define __CVC4__PREPROCESSING__PASSES__SYNTH_REW_RULES_H
+
+#include "preprocessing/preprocessing_pass.h"
+#include "preprocessing/preprocessing_pass_context.h"
+
+namespace CVC4 {
+namespace preprocessing {
+namespace passes {
+
+/**
+ * This class computes candidate rewrite rules of the form t1 = t2, where
+ * t1 and t2 are subterms of assertionsToPreprocess. It prints
+ * "candidate-rewrite" messages on the output stream of options.
+ *
+ * In contrast to other preprocessing passes, this pass does not modify
+ * the set of assertions.
+ */
+class SynthRewRulesPass : public PreprocessingPass
+{
+ public:
+  /**
+   * preprocContext : the preprocessing pass context handed to the base class.
+   * The constructor is explicit so that a context pointer is never implicitly
+   * converted into a pass object.
+   */
+  explicit SynthRewRulesPass(PreprocessingPassContext* preprocContext);
+
+ protected:
+  /**
+   * Runs the pass on the given assertions.
+   *
+   * assertionsToPreprocess : the assertions whose subterms are considered.
+   */
+  PreprocessingPassResult applyInternal(
+      AssertionPipeline* assertionsToPreprocess) override;
+};
+
+}  // namespace passes
+}  // namespace preprocessing
+}  // namespace CVC4
+
+#endif /* __CVC4__PREPROCESSING__PASSES__SYNTH_REW_RULES_H */
index 5652eeaa64edbc5389ce93997d0ab6f5a85bc3fe..ae0e80512d91c7df2db6716c33bbc2da32c30ae6 100644 (file)
@@ -80,6 +80,7 @@
 #include "preprocessing/passes/static_learning.h"
 #include "preprocessing/passes/symmetry_breaker.h"
 #include "preprocessing/passes/symmetry_detect.h"
+#include "preprocessing/passes/synth_rew_rules.h"
 #include "preprocessing/preprocessing_pass.h"
 #include "preprocessing/preprocessing_pass_context.h"
 #include "preprocessing/preprocessing_pass_registry.h"
@@ -2727,6 +2728,8 @@ void SmtEnginePrivate::finishInit() {
       new StaticLearning(d_preprocessingPassContext.get()));
   std::unique_ptr<SymBreakerPass> sbProc(
       new SymBreakerPass(d_preprocessingPassContext.get()));
+  std::unique_ptr<SynthRewRulesPass> srrProc(
+      new SynthRewRulesPass(d_preprocessingPassContext.get()));
   d_preprocessingPassRegistry.registerPass("bool-to-bv", std::move(boolToBv));
   d_preprocessingPassRegistry.registerPass("bv-abstraction",
                                            std::move(bvAbstract));
@@ -2743,6 +2746,7 @@ void SmtEnginePrivate::finishInit() {
   d_preprocessingPassRegistry.registerPass("static-learning", 
                                            std::move(staticLearning));
   d_preprocessingPassRegistry.registerPass("sym-break", std::move(sbProc));
+  d_preprocessingPassRegistry.registerPass("synth-rr", std::move(srrProc));
 }
 
 Node SmtEnginePrivate::expandDefinitions(TNode n, unordered_map<Node, Node, NodeHashFunction>& cache, bool expandOnly)
@@ -4323,6 +4327,12 @@ void SmtEnginePrivate::processAssertions() {
         ->apply(&d_assertions);
   }
 
+  if (options::synthRrPrep())
+  {
+    // do candidate rewrite rule synthesis
+    d_preprocessingPassRegistry.getPass("synth-rr")->apply(&d_assertions);
+  }
+
   Trace("smt-proc") << "SmtEnginePrivate::processAssertions() : pre-simplify" << endl;
   dumpAssertions("pre-simplify", d_assertions);
   Chat() << "simplifying assertions..." << endl;
index 03c39f718947d1e8b87113a73dd1b6a053608a77..9bbb88699f1c8f4207c7222ab250be0cca2141bf 100644 (file)
@@ -32,37 +32,93 @@ namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
-CandidateRewriteDatabase::CandidateRewriteDatabase() : d_qe(nullptr) {}
-void CandidateRewriteDatabase::initialize(QuantifiersEngine* qe,
-                                          Node f,
+// the number of d_drewrite objects we have allocated (to avoid name conflicts)
+static unsigned drewrite_counter = 0;
+
+/**
+ * Constructs a database that is not yet initialized. When
+ * options::sygusRewSynthFilterCong() is set, a dynamic rewriter with a fresh
+ * name is allocated on the dummy context of this class and passed to the
+ * sampler.
+ */
+CandidateRewriteDatabase::CandidateRewriteDatabase()
+    : d_qe(nullptr),
+      d_tds(nullptr),
+      d_ext_rewrite(nullptr),
+      d_using_sygus(false)
+{
+  if (!options::sygusRewSynthFilterCong())
+  {
+    // no dynamic rewriter is required
+    return;
+  }
+  // the counter value makes the name differ from those allocated earlier
+  std::stringstream nameStream;
+  nameStream << "_dyn_rewriter_" << drewrite_counter++;
+  d_drewrite.reset(new DynamicRewriter(nameStream.str(), &d_fake_context));
+  d_sampler.setDynamicRewriter(d_drewrite.get());
+}
+void CandidateRewriteDatabase::initialize(ExtendedRewriter* er,
+                                          TypeNode tn,
+                                          std::vector<Node>& vars,
                                           unsigned nsamples,
-                                          bool useSygusType)
+                                          bool unique_type_ids)
+{
+  d_candidate = Node::null();
+  d_type = tn;
+  d_using_sygus = false;
+  d_qe = nullptr;
+  d_tds = nullptr;
+  d_ext_rewrite = er;
+  d_sampler.initialize(tn, vars, nsamples, unique_type_ids);
+}
+
+void CandidateRewriteDatabase::initializeSygus(QuantifiersEngine* qe,
+                                               Node f,
+                                               unsigned nsamples,
+                                               bool useSygusType)
 {
-  d_qe = qe;
   d_candidate = f;
-  d_sampler.initializeSygusExt(d_qe, f, nsamples, useSygusType);
+  d_type = f.getType();
+  Assert(d_type.isDatatype());
+  Assert(static_cast<DatatypeType>(d_type.toType()).getDatatype().isSygus());
+  d_using_sygus = true;
+  d_qe = qe;
+  d_tds = d_qe->getTermDatabaseSygus();
+  d_ext_rewrite = d_tds->getExtRewriter();
+  d_sampler.initializeSygus(d_tds, f, nsamples, useSygusType);
 }
 
-bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out)
+bool CandidateRewriteDatabase::addTerm(Node sol,
+                                       std::ostream& out,
+                                       bool& rew_print)
 {
   bool is_unique_term = true;
-  TermDbSygus* sygusDb = d_qe->getTermDatabaseSygus();
   Node eq_sol = d_sampler.registerTerm(sol);
   // eq_sol is a candidate solution that is equivalent to sol
   if (eq_sol != sol)
   {
-    CegInstantiation* cei = d_qe->getCegInstantiation();
     is_unique_term = false;
     // if eq_sol is null, then we have an uninteresting candidate rewrite,
     // e.g. one that is alpha-equivalent to another.
-    bool success = true;
     if (!eq_sol.isNull())
     {
-      ExtendedRewriter* er = sygusDb->getExtRewriter();
-      Node solb = sygusDb->sygusToBuiltin(sol);
-      Node solbr = er->extendedRewrite(solb);
-      Node eq_solb = sygusDb->sygusToBuiltin(eq_sol);
-      Node eq_solr = er->extendedRewrite(eq_solb);
+      // get the actual term
+      Node solb = sol;
+      Node eq_solb = eq_sol;
+      if (d_using_sygus)
+      {
+        Assert(d_tds != nullptr);
+        solb = d_tds->sygusToBuiltin(sol);
+        eq_solb = d_tds->sygusToBuiltin(eq_sol);
+      }
+      // get the rewritten form
+      Node solbr;
+      Node eq_solr;
+      if (d_ext_rewrite != nullptr)
+      {
+        solbr = d_ext_rewrite->extendedRewrite(solb);
+        eq_solr = d_ext_rewrite->extendedRewrite(eq_solb);
+      }
+      else
+      {
+        solbr = Rewriter::rewrite(solb);
+        eq_solr = Rewriter::rewrite(eq_solb);
+      }
       bool verified = false;
       Trace("rr-check") << "Check candidate rewrite..." << std::endl;
       // verify it if applicable
@@ -108,27 +164,36 @@ bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out)
         if (r.asSatisfiabilityResult().isSat() == Result::SAT)
         {
           Trace("rr-check") << "...rewrite does not hold for: " << std::endl;
-          success = false;
           is_unique_term = true;
           std::vector<Node> vars;
           d_sampler.getVariables(vars);
           std::vector<Node> pt;
           for (const Node& v : vars)
           {
-            std::map<Node, unsigned>::iterator itf = fv_index.find(v);
             Node val;
-            if (itf == fv_index.end())
+            Node refv = v;
+            // if a bound variable, map to the skolem we introduce before
+            // looking up the model value
+            if (v.getKind() == BOUND_VARIABLE)
             {
-              // not in conjecture, can use arbitrary value
-              val = v.getType().mkGroundTerm();
+              std::map<Node, unsigned>::iterator itf = fv_index.find(v);
+              if (itf == fv_index.end())
+              {
+                // not in conjecture, can use arbitrary value
+                val = v.getType().mkGroundTerm();
+              }
+              else
+              {
+                // get the model value of its skolem
+                refv = sks[itf->second];
+              }
             }
-            else
+            if (val.isNull())
             {
-              // get the model value of its skolem
-              Node sk = sks[itf->second];
-              val = Node::fromExpr(rrChecker.getValue(sk.toExpr()));
-              Trace("rr-check") << "  " << v << " -> " << val << std::endl;
+              Assert(!refv.isNull() && refv.getKind() != BOUND_VARIABLE);
+              val = Node::fromExpr(rrChecker.getValue(refv.toExpr()));
             }
+            Trace("rr-check") << "  " << v << " -> " << val << std::endl;
             pt.push_back(val);
           }
           d_sampler.addSamplePoint(pt);
@@ -145,22 +210,29 @@ bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out)
       else
       {
         // just insist that constants are not relevant pairs
-        success = !solb.isConst() || !eq_solb.isConst();
+        is_unique_term = solb.isConst() && eq_solb.isConst();
       }
-      if (success)
+      if (!is_unique_term)
       {
         // register this as a relevant pair (helps filtering)
         d_sampler.registerRelevantPair(sol, eq_sol);
         // The analog of terms sol and eq_sol are equivalent under
         // sample points but do not rewrite to the same term. Hence,
         // this indicates a candidate rewrite.
-        Printer* p = Printer::getPrinter(options::outputLanguage());
         out << "(" << (verified ? "" : "candidate-") << "rewrite ";
-        p->toStreamSygus(out, sol);
-        out << " ";
-        p->toStreamSygus(out, eq_sol);
+        if (d_using_sygus)
+        {
+          Printer* p = Printer::getPrinter(options::outputLanguage());
+          p->toStreamSygus(out, sol);
+          out << " ";
+          p->toStreamSygus(out, eq_sol);
+        }
+        else
+        {
+          out << sol << " " << eq_sol;
+        }
         out << ")" << std::endl;
-        ++(cei->d_statistics.d_candidate_rewrites_print);
+        rew_print = true;
         // debugging information
         if (Trace.isOn("sygus-rr-debug"))
         {
@@ -169,32 +241,33 @@ bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out)
           Trace("sygus-rr-debug")
               << "; candidate #2 ext-rewrites to: " << eq_solr << std::endl;
         }
-        if (options::sygusRewSynthAccel())
+        if (options::sygusRewSynthAccel() && d_using_sygus)
         {
+          Assert(d_tds != nullptr);
           // Add a symmetry breaking clause that excludes the larger
           // of sol and eq_sol. This effectively states that we no longer
           // wish to enumerate any term that contains sol (resp. eq_sol)
           // as a subterm.
           Node exc_sol = sol;
-          unsigned sz = sygusDb->getSygusTermSize(sol);
-          unsigned eqsz = sygusDb->getSygusTermSize(eq_sol);
+          unsigned sz = d_tds->getSygusTermSize(sol);
+          unsigned eqsz = d_tds->getSygusTermSize(eq_sol);
           if (eqsz > sz)
           {
             sz = eqsz;
             exc_sol = eq_sol;
           }
           TypeNode ptn = d_candidate.getType();
-          Node x = sygusDb->getFreeVar(ptn, 0);
-          Node lem =
-              sygusDb->getExplain()->getExplanationForEquality(x, exc_sol);
+          Node x = d_tds->getFreeVar(ptn, 0);
+          Node lem = d_tds->getExplain()->getExplanationForEquality(x, exc_sol);
           lem = lem.negate();
           Trace("sygus-rr-sb") << "Symmetry breaking lemma : " << lem
                                << std::endl;
-          sygusDb->registerSymBreakLemma(d_candidate, lem, ptn, sz);
+          d_tds->registerSymBreakLemma(d_candidate, lem, ptn, sz);
         }
       }
     }
     // We count this as a rewrite if we did not explicitly rule it out.
+    // The value of is_unique_term is false iff this call resulted in a rewrite.
     // Notice that when --sygus-rr-synth-check is enabled,
     // statistics on number of candidate rewrite rules is
     // an accurate count of (#enumerated_terms-#unique_terms) only if
@@ -203,14 +276,52 @@ bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out)
     // rule is not useful since its variables are unordered, whereby
     // it discards it as a redundant candidate rewrite rule before
     // checking its correctness.
-    if (success)
-    {
-      ++(cei->d_statistics.d_candidate_rewrites);
-    }
   }
   return is_unique_term;
 }
 
+/**
+ * Convenience overload for callers that do not need to know whether a
+ * rewrite was printed; forwards to the three-argument addTerm.
+ */
+bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out)
+{
+  // receives the "printed" flag, which is discarded on return
+  bool printedRewrite = false;
+  const bool isUniqueTerm = addTerm(sol, out, printedRewrite);
+  return isUniqueTerm;
+}
+
+/**
+ * Stores a copy of the variables and the number of samples; the per-type
+ * databases are created later by addTerm. The quantifiers engine pointer is
+ * set to null so that it never holds an indeterminate value.
+ */
+CandidateRewriteDatabaseGen::CandidateRewriteDatabaseGen(
+    std::vector<Node>& vars, unsigned nsamples)
+    : d_qe(nullptr), d_vars(vars.begin(), vars.end()), d_nsamples(nsamples)
+{
+}
+
+bool CandidateRewriteDatabaseGen::addTerm(Node n, std::ostream& out)
+{
+  ExtendedRewriter* er = nullptr;
+  if (options::synthRrPrepExtRew())
+  {
+    er = &d_ext_rewrite;
+  }
+  Node nr;
+  if (er == nullptr)
+  {
+    nr = Rewriter::rewrite(n);
+  }
+  else
+  {
+    nr = er->extendedRewrite(n);
+  }
+  TypeNode tn = nr.getType();
+  std::map<TypeNode, CandidateRewriteDatabase>::iterator itc = d_cdbs.find(tn);
+  if (itc == d_cdbs.end())
+  {
+    Trace("synth-rr-dbg") << "Initialize database for " << tn << std::endl;
+    // initialize with the extended rewriter owned by this class
+    d_cdbs[tn].initialize(er, tn, d_vars, d_nsamples, true);
+    itc = d_cdbs.find(tn);
+    Trace("synth-rr-dbg") << "...finish." << std::endl;
+  }
+  Trace("synth-rr-dbg") << "Add term " << nr << " for " << tn << std::endl;
+  return itc->second.addTerm(nr, out);
+}
+
 } /* CVC4::theory::quantifiers namespace */
 } /* CVC4::theory namespace */
 } /* CVC4 namespace */
index 9ca946d26ab6e68b6b0c466656e09226c72b3ee6..a2a6c5745c2ab7d00a2af384da359f1944670f10 100644 (file)
@@ -18,6 +18,9 @@
 #define __CVC4__THEORY__QUANTIFIERS__CANDIDATE_REWRITE_DATABASE_H
 
 #include <map>
+#include <memory>
+#include <unordered_set>
+#include <vector>
 #include "theory/quantifiers/sygus_sampler.h"
 
 namespace CVC4 {
@@ -43,7 +46,32 @@ class CandidateRewriteDatabase
   ~CandidateRewriteDatabase() {}
   /**  Initialize this class
    *
-   * qe : pointer to quantifiers engine,
+   * er : pointer to the extended rewriter (if any) we are using to compute
+   * candidate rewrites,
+   * tn : the return type of terms we will be testing with this class,
+   * vars : the variables we are testing substitutions for,
+   * nsamples : number of sample points this class will test,
+   * unique_type_ids : if this is set to true, then each variable is treated
+   * as unique. This affects whether or not a rewrite rule is considered
+   * redundant or not. For example the rewrite f(y)=y is redundant if
+   * f(x)=x has also been printed as a rewrite and x and y have the same type
+   * id (see SygusSampler for details). On the other hand, when a candidate
+   * rewrite database is initialized with sygus below, the type ids of the
+   * (sygus formal argument list) variables are always computed and used.
+   */
+  void initialize(ExtendedRewriter* er,
+                  TypeNode tn,
+                  std::vector<Node>& vars,
+                  unsigned nsamples,
+                  bool unique_type_ids = false);
+  /**  Initialize this class
+   *
+   * Serves the same purpose as the above function, but we will be using
+   * sygus to enumerate terms and generate samples.
+   *
+   * qe : pointer to quantifiers engine. We use the sygus term database of this
+   * quantifiers engine, and the extended rewriter of the corresponding term
+   * database when computing candidate rewrites,
    * f : a term of some SyGuS datatype type whose values we will be
    * testing under the free variables in the grammar of f. This is the
    * "candidate variable" CegConjecture::d_candidates,
@@ -55,28 +83,44 @@ class CandidateRewriteDatabase
    *
    * These arguments are used to initialize the sygus sampler class.
    */
-  void initialize(QuantifiersEngine* qe,
-                  Node f,
-                  unsigned nsamples,
-                  bool useSygusType);
+  void initializeSygus(QuantifiersEngine* qe,
+                       Node f,
+                       unsigned nsamples,
+                       bool useSygusType);
   /** add term
    *
    * Notifies this class that the solution sol was enumerated. This may
    * cause a candidate-rewrite to be printed on the output stream out.
+   * We return true if the term sol is distinct (up to equivalence) from
+   * all previous terms added to this class. The argument rew_print is set to
+   * true if this class printed a rewrite.
    */
+  bool addTerm(Node sol, std::ostream& out, bool& rew_print);
   bool addTerm(Node sol, std::ostream& out);
 
  private:
   /** reference to quantifier engine */
   QuantifiersEngine* d_qe;
-  /** the function-to-synthesize we are testing */
+  /** pointer to the sygus term database of d_qe */
+  TermDbSygus* d_tds;
+  /** pointer to the extended rewriter object we are using */
+  ExtendedRewriter* d_ext_rewrite;
+  /** the (sygus or builtin) type of terms we are testing */
+  TypeNode d_type;
+  /** the function-to-synthesize we are testing (if sygus) */
   Node d_candidate;
+  /** whether we are using sygus */
+  bool d_using_sygus;
   /** sygus sampler objects for each program variable
    *
    * This is used for the sygusRewSynth() option to synthesize new candidate
    * rewrite rules.
    */
   SygusSamplerExt d_sampler;
+  /** a (dummy) user context, used for d_drewrite */
+  context::UserContext d_fake_context;
+  /** dynamic rewriter class */
+  std::unique_ptr<DynamicRewriter> d_drewrite;
   /**
    * Cache of skolems for each free variable that appears in a synthesis check
    * (for --sygus-rr-synth-check).
@@ -84,6 +128,41 @@ class CandidateRewriteDatabase
   std::map<Node, Node> d_fv_to_skolem;
 };
 
+/**
+ * This class generates and stores candidate rewrite databases for multiple
+ * types as needed.
+ *
+ * One CandidateRewriteDatabase is kept per type (see d_cdbs). Each of them is
+ * initialized with the variables and the number of samples that were given
+ * to the constructor of this class.
+ */
+class CandidateRewriteDatabaseGen
+{
+ public:
+  /** constructor
+   *
+   * vars : the variables we are testing substitutions for, for all types,
+   * nsamples : number of sample points this class will test for all types.
+   *
+   * The variables are copied into this class.
+   */
+  CandidateRewriteDatabaseGen(std::vector<Node>& vars, unsigned nsamples);
+  /** add term
+   *
+   * This registers term n with this class. We generate the candidate rewrite
+   * database of the appropriate type (if not allocated already), and register
+   * n with this database. This may result in "candidate-rewrite" being
+   * printed on the output stream out.
+   *
+   * The term is rewritten first, and the database is selected by the type of
+   * the rewritten term. The return value is the one returned by
+   * CandidateRewriteDatabase::addTerm for that database.
+   */
+  bool addTerm(Node n, std::ostream& out);
+
+ private:
+  /** pointer to a quantifiers engine
+   *
+   * NOTE(review): addTerm does not use this member; confirm whether it is
+   * still needed.
+   */
+  QuantifiersEngine* d_qe;
+  /** the variables */
+  std::vector<Node> d_vars;
+  /** the number of samples */
+  unsigned d_nsamples;
+  /** candidate rewrite databases for each type */
+  std::map<TypeNode, CandidateRewriteDatabase> d_cdbs;
+  /** an extended rewriter object
+   *
+   * Used by addTerm when options::synthRrPrepExtRew() is set.
+   */
+  ExtendedRewriter d_ext_rewrite;
+};
+
 } /* CVC4::theory::quantifiers namespace */
 } /* CVC4::theory namespace */
 } /* CVC4 namespace */
index 352d6892fae46b50df58fe56f27249e4f50b2c10..ef1cb3a9d61aa1f9dd52435a2ed33859b10ffe53 100644 (file)
@@ -23,9 +23,9 @@ namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
-DynamicRewriter::DynamicRewriter(const std::string& name, QuantifiersEngine* qe)
-    : d_equalityEngine(qe->getUserContext(), "DynamicRewriter::" + name, true),
-      d_rewrites(qe->getUserContext())
+DynamicRewriter::DynamicRewriter(const std::string& name,
+                                 context::UserContext* u)
+    : d_equalityEngine(u, "DynamicRewriter::" + name, true), d_rewrites(u)
 {
   d_equalityEngine.addFunctionKind(kind::APPLY_UF);
 }
@@ -42,6 +42,11 @@ void DynamicRewriter::addRewrite(Node a, Node b)
   // add to the equality engine
   Node ai = toInternal(a);
   Node bi = toInternal(b);
+  if (ai.isNull() || bi.isNull())
+  {
+    Trace("dyn-rewrite") << "...not internalizable." << std::endl;
+    return;
+  }
   Trace("dyn-rewrite-debug") << "Internal : " << ai << " " << bi << std::endl;
 
   Trace("dyn-rewrite-debug") << "assert eq..." << std::endl;
@@ -58,11 +63,19 @@ bool DynamicRewriter::areEqual(Node a, Node b)
   {
     return true;
   }
+  Trace("dyn-rewrite-debug") << "areEqual? : " << a << " " << b << std::endl;
   // add to the equality engine
   Node ai = toInternal(a);
   Node bi = toInternal(b);
+  if (ai.isNull() || bi.isNull())
+  {
+    Trace("dyn-rewrite") << "...not internalizable." << std::endl;
+    return false;
+  }
+  Trace("dyn-rewrite-debug") << "internal : " << ai << " " << bi << std::endl;
   d_equalityEngine.addTerm(ai);
   d_equalityEngine.addTerm(bi);
+  Trace("dyn-rewrite-debug") << "...added terms" << std::endl;
   return d_equalityEngine.areEqual(ai, bi);
 }
 
@@ -84,6 +97,12 @@ Node DynamicRewriter::toInternal(Node a)
       if (a.getKind() != APPLY_UF)
       {
         op = d_ois_trie[op].getSymbol(a);
+        // if this term involves an argument that is not of first class type,
+        // we cannot reason about it. This includes operators like str.in-re.
+        if (op.isNull())
+        {
+          return Node::null();
+        }
       }
       children.push_back(op);
     }
@@ -120,6 +139,11 @@ Node DynamicRewriter::OpInternalSymTrie::getSymbol(Node n)
   OpInternalSymTrie* curr = this;
   for (unsigned i = 0, size = ctypes.size(); i < size; i++)
   {
+    // cannot handle certain types (e.g. regular expressions or functions)
+    if (!ctypes[i].isFirstClass())
+    {
+      return Node::null();
+    }
     curr = &(curr->d_children[ctypes[i]]);
   }
   if (!curr->d_sym.isNull())
index 0c115d8a17654d98d5e00cf5dd5eb2200ce35aa4..75f668b1130acf4c93dbde1e486cf310e7f0f1a1 100644 (file)
@@ -20,7 +20,6 @@
 #include <map>
 
 #include "context/cdlist.h"
-#include "theory/quantifiers_engine.h"
 #include "theory/uf/equality_engine.h"
 
 namespace CVC4 {
@@ -55,7 +54,7 @@ class DynamicRewriter
   typedef context::CDList<Node> NodeList;
 
  public:
-  DynamicRewriter(const std::string& name, QuantifiersEngine* qe);
+  DynamicRewriter(const std::string& name, context::UserContext* u);
   ~DynamicRewriter() {}
   /** inform this class that the equality a = b holds. */
   void addRewrite(Node a, Node b);
index 61869b3551d84ac2e356eb187765c5606f3d8169..3bb0fc51a1119f67b0f6843939d3579e1050e6d3 100644 (file)
@@ -658,11 +658,20 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
             d_crrdb.find(prog);
         if (its == d_crrdb.end())
         {
-          d_crrdb[prog].initialize(
+          d_crrdb[prog].initializeSygus(
               d_qe, d_candidates[i], options::sygusSamples(), true);
           its = d_crrdb.find(prog);
         }
-        is_unique_term = d_crrdb[prog].addTerm(sol, out);
+        bool rew_print = false;
+        is_unique_term = d_crrdb[prog].addTerm(sol, out, rew_print);
+        if (rew_print)
+        {
+          ++(cei->d_statistics.d_candidate_rewrites_print);
+        }
+        if (!is_unique_term)
+        {
+          ++(cei->d_statistics.d_candidate_rewrites);
+        }
       }
       if (is_unique_term)
       {
index 26f26a14519a3625dc13e6f847b35441ed36e774..c6976ac62b8c5aa112f4f36b5bc5c3f0c16b9dc1 100644 (file)
 #include "theory/quantifiers/sygus/term_database_sygus.h"
 
 #include "base/cvc4_check.h"
+#include "options/base_options.h"
 #include "options/quantifiers_options.h"
+#include "printer/printer.h"
 #include "theory/arith/arith_msum.h"
 #include "theory/datatypes/datatypes_rewriter.h"
 #include "theory/quantifiers/quantifiers_attributes.h"
 #include "theory/quantifiers/term_database.h"
 #include "theory/quantifiers/term_util.h"
 #include "theory/quantifiers_engine.h"
-#include "options/base_options.h"
-#include "printer/printer.h"
 
 using namespace CVC4::kind;
 
index 286533570e30b71d327004d676ab17f68ebf16f4..44139cf0db7e16545ad84f4597aabbc3537100d6 100644 (file)
@@ -185,8 +185,10 @@ class TermDbSygus {
    * form of bn [ args / vars(tn) ], where vars(tn) is the sygus variable
    * list for type tn (see Datatype::getSygusVarList).
    */
-  Node evaluateBuiltin(TypeNode tn, Node bn, std::vector<Node>& args,
-bool tryEval = true);
+  Node evaluateBuiltin(TypeNode tn,
+                       Node bn,
+                       std::vector<Node>& args,
+                       bool tryEval = true);
   /** evaluate with unfolding
    *
    * n is any term that may involve sygus evaluation functions. This function
index c290c027ad3a00d356c15ebe9b1eeeba8d7eaf27..8da65e4cad74f6bb07681aeed3db3438a9ea3788 100644 (file)
@@ -32,7 +32,8 @@ SygusSampler::SygusSampler()
 
 void SygusSampler::initialize(TypeNode tn,
                               std::vector<Node>& vars,
-                              unsigned nsamples)
+                              unsigned nsamples,
+                              bool unique_type_ids)
 {
   d_tds = nullptr;
   d_use_sygus_type = false;
@@ -53,15 +54,23 @@ void SygusSampler::initialize(TypeNode tn,
   {
     TypeNode svt = sv.getType();
     unsigned tnid = 0;
-    std::map<TypeNode, unsigned>::iterator itt = type_to_type_id.find(svt);
-    if (itt == type_to_type_id.end())
+    if (unique_type_ids)
     {
-      type_to_type_id[svt] = type_id_counter;
+      tnid = type_id_counter;
       type_id_counter++;
     }
     else
     {
-      tnid = itt->second;
+      std::map<TypeNode, unsigned>::iterator itt = type_to_type_id.find(svt);
+      if (itt == type_to_type_id.end())
+      {
+        type_to_type_id[svt] = type_id_counter;
+        type_id_counter++;
+      }
+      else
+      {
+        tnid = itt->second;
+      }
     }
     Trace("sygus-sample-debug")
         << "Type id for " << sv << " is " << tnid << std::endl;
@@ -586,7 +595,7 @@ Node SygusSampler::getRandomValue(TypeNode tn)
     if (!s.isNull() && !r.isNull())
     {
       Rational sr = s.getConst<Rational>();
-      Rational rr = s.getConst<Rational>();
+      Rational rr = r.getConst<Rational>();
       if (rr.sgn() == 0)
       {
         return s;
@@ -597,7 +606,19 @@ Node SygusSampler::getRandomValue(TypeNode tn)
       }
     }
   }
-  return Node::null();
+  // default: use type enumerator
+  unsigned counter = 0;
+  while (Random::getRandom().pickWithProb(0.5))
+  {
+    counter++;
+  }
+  Node ret = d_tenum.getEnumerateTerm(tn, counter);
+  if (ret.isNull())
+  {
+    // beyond bounds, return the first
+    ret = d_tenum.getEnumerateTerm(tn, 0);
+  }
+  return ret;
 }
 
 Node SygusSampler::getSygusRandomValue(TypeNode tn,
@@ -719,28 +740,23 @@ void SygusSampler::registerSygusType(TypeNode tn)
   }
 }
 
-SygusSamplerExt::SygusSamplerExt() : d_ssenm(*this) {}
+SygusSamplerExt::SygusSamplerExt() : d_drewrite(nullptr), d_ssenm(*this) {}
 
-void SygusSamplerExt::initializeSygusExt(QuantifiersEngine* qe,
-                                         Node f,
-                                         unsigned nsamples,
-                                         bool useSygusType)
+void SygusSamplerExt::initializeSygus(TermDbSygus* tds,
+                                      Node f,
+                                      unsigned nsamples,
+                                      bool useSygusType)
 {
-  SygusSampler::initializeSygus(
-      qe->getTermDatabaseSygus(), f, nsamples, useSygusType);
-
-  // initialize the dynamic rewriter
-  std::stringstream ss;
-  ss << f;
-  if (options::sygusRewSynthFilterCong())
-  {
-    d_drewrite =
-        std::unique_ptr<DynamicRewriter>(new DynamicRewriter(ss.str(), qe));
-  }
+  SygusSampler::initializeSygus(tds, f, nsamples, useSygusType);
   d_pairs.clear();
   d_match_trie.clear();
 }
 
+void SygusSamplerExt::setDynamicRewriter(DynamicRewriter* dr)
+{
+  d_drewrite = dr;
+}
+
 Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
 {
   Node eq_n = SygusSampler::registerTerm(n, forceKeep);
@@ -896,6 +912,9 @@ bool SygusSamplerExt::notify(Node s,
     for (unsigned i = 0, size = vars.size(); i < size; i++)
     {
       Trace("sse-match") << "    " << vars[i] << " -> " << subs[i] << std::endl;
+      // TODO (#1923) use an internal representation so that polymorphism
+      // is handled correctly
+      Assert(vars[i].getType().isComparableTo(subs[i].getType()));
     }
   }
   Assert(it != d_pairs.end());
index fcd35613b790732dec35d82aad10c1bd03357d82..d323b36bd26f43b802e399a93058955eb14d4e62 100644 (file)
@@ -21,6 +21,7 @@
 #include "theory/quantifiers/dynamic_rewrite.h"
 #include "theory/quantifiers/lazy_trie.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
+#include "theory/quantifiers/term_enumeration.h"
 
 namespace CVC4 {
 namespace theory {
@@ -69,14 +70,20 @@ class SygusSampler : public LazyTrieEvaluator
 
   /** initialize
    *
-   * tn : the return type of terms we will be testing with this class
-   * vars : the variables we are testing substitutions for
-   * nsamples : number of sample points this class will test.
+   * tn : the return type of terms we will be testing with this class,
+   * vars : the variables we are testing substitutions for,
+   * nsamples : number of sample points this class will test,
+   * unique_type_ids : if this is set to true, then we consider each variable
+   * in vars to have a unique "type id". A type id is a finer-grained notion of
+   * type that is used to determine when a rewrite rule is redundant.
    */
-  void initialize(TypeNode tn, std::vector<Node>& vars, unsigned nsamples);
+  virtual void initialize(TypeNode tn,
+                          std::vector<Node>& vars,
+                          unsigned nsamples,
+                          bool unique_type_ids = false);
   /** initialize sygus
    *
-   * tds : pointer to sygus database,
+   * tds : pointer to sygus database,
    * f : a term of some SyGuS datatype type whose values we will be
    * testing under the free variables in the grammar of f,
    * nsamples : number of sample points this class will test,
@@ -85,10 +92,10 @@ class SygusSampler : public LazyTrieEvaluator
    * terms of the analog of the type of f, that is, the builtin type that
    * f's type encodes in the deep embedding.
    */
-  void initializeSygus(TermDbSygus* tds,
-                       Node f,
-                       unsigned nsamples,
-                       bool useSygusType);
+  virtual void initializeSygus(TermDbSygus* tds,
+                               Node f,
+                               unsigned nsamples,
+                               bool useSygusType);
   /** register term n with this sampler database
    *
    * forceKeep is whether we wish to force that n is chosen as a representative
@@ -145,6 +152,8 @@ class SygusSampler : public LazyTrieEvaluator
  protected:
   /** sygus term database of d_qe */
   TermDbSygus* d_tds;
+  /** term enumerator object (used for random sampling) */
+  TermEnumeration d_tenum;
   /** samples */
   std::vector<std::vector<Node> > d_samples;
   /** data structure to check duplication of sample points */
@@ -330,11 +339,19 @@ class SygusSamplerExt : public SygusSampler
 {
  public:
   SygusSamplerExt();
-  /** initialize extended */
-  void initializeSygusExt(QuantifiersEngine* qe,
-                          Node f,
-                          unsigned nsamples,
-                          bool useSygusType);
+  /** initialize sygus, see SygusSampler::initializeSygus */
+  void initializeSygus(TermDbSygus* tds,
+                       Node f,
+                       unsigned nsamples,
+                       bool useSygusType) override;
+  /** set dynamic rewriter
+   *
+   * This tells this class to use the dynamic rewriter object dr. This utility
+   * is used to query whether pairs of terms are already entailed to be
+   * equal based on previous rewrite rules.
+   */
+  void setDynamicRewriter(DynamicRewriter* dr);
+
   /** register term n with this sampler database
    *
    *  For each call to registerTerm( t, ... ) that returns s, we say that
@@ -366,7 +383,6 @@ class SygusSamplerExt : public SygusSampler
    * d_drewrite utility, or is an instance of a previous pair
    */
   Node registerTerm(Node n, bool forceKeep = false) override;
-
   /** register relevant pair
    *
    * This should be called after registerTerm( n ) returns eq_n.
@@ -375,8 +391,8 @@ class SygusSamplerExt : public SygusSampler
   void registerRelevantPair(Node n, Node eq_n);
 
  private:
-  /** dynamic rewriter class */
-  std::unique_ptr<DynamicRewriter> d_drewrite;
+  /** pointer to the dynamic rewriter class */
+  DynamicRewriter* d_drewrite;
 
   //----------------------------match filtering
   /**