Make all dependencies in the fast enumerator optional (#6918)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 27 Jul 2021 13:50:38 +0000 (08:50 -0500)
committerGitHub <noreply@github.com>
Tue, 27 Jul 2021 13:50:38 +0000 (13:50 +0000)
This allows one to use a fast enumerator without having access to sygus term database, statistics, etc.

16 files changed:
src/theory/datatypes/sygus_datatype_utils.cpp
src/theory/datatypes/sygus_datatype_utils.h
src/theory/datatypes/sygus_extension.cpp
src/theory/datatypes/theory_datatypes_utils.h
src/theory/quantifiers/candidate_rewrite_database.cpp
src/theory/quantifiers/sygus/cegis_unif.cpp
src/theory/quantifiers/sygus/rcons_type_info.cpp
src/theory/quantifiers/sygus/sygus_enumerator.cpp
src/theory/quantifiers/sygus/sygus_enumerator.h
src/theory/quantifiers/sygus/sygus_explain.cpp
src/theory/quantifiers/sygus/sygus_pbe.cpp
src/theory/quantifiers/sygus/sygus_unif.cpp
src/theory/quantifiers/sygus/sygus_unif_io.cpp
src/theory/quantifiers/sygus/synth_conjecture.cpp
src/theory/quantifiers/sygus/term_database_sygus.cpp
src/theory/quantifiers/sygus/term_database_sygus.h

index 7e5099d5538cc7593be301cd21802f2c6ab98e02..72ddd7b0e79e1c2dea7a571290c74b19d7474694 100644 (file)
@@ -718,6 +718,24 @@ TypeNode substituteAndGeneralizeSygusType(TypeNode sdt,
   return sdtS;
 }
 
+unsigned getSygusTermSize(Node n)
+{
+  if (n.getKind() != APPLY_CONSTRUCTOR)
+  {
+    return 0;
+  }
+  unsigned sum = 0;
+  for (const Node& nc : n)
+  {
+    sum += getSygusTermSize(nc);
+  }
+  const DType& dt = datatypeOf(n.getOperator());
+  int cindex = indexOf(n.getOperator());
+  Assert(cindex >= 0 && static_cast<size_t>(cindex) < dt.getNumConstructors());
+  unsigned weight = dt[cindex].getWeight();
+  return weight + sum;
+}
+
 }  // namespace utils
 }  // namespace datatypes
 }  // namespace theory
index 6f3791a4de766079bcdbfdc75980147ccfee64ad..35672434c77f050b2ebdf4ae02cb769b01caafb2 100644 (file)
@@ -227,6 +227,11 @@ TypeNode substituteAndGeneralizeSygusType(TypeNode sdt,
                                           const std::vector<Node>& syms,
                                           const std::vector<Node>& vars);
 
+/**
+ * Get SyGuS term size, which is based on the weight of constructor applications
+ * in n.
+ */
+unsigned getSygusTermSize(Node n);
 // ------------------------ end sygus utils
 
 }  // namespace utils
index ee96b95d8c1c238c08b4e1c3770c16192dedd01c..63af575921ada6cf1ca4063bee6fa5d97c83482a 100644 (file)
@@ -1040,7 +1040,7 @@ Node SygusExtension::registerSearchValue(Node a,
     Node bvr = d_tds->getExtRewriter()->extendedRewrite(bv);
     Trace("sygus-sb-debug") << "  ......search value rewrites to " << bvr << std::endl;
     Trace("dt-sygus") << "  * DT builtin : " << n << " -> " << bvr << std::endl;
-    unsigned sz = d_tds->getSygusTermSize( nv );      
+    unsigned sz = utils::getSygusTermSize(nv);
     if( d_tds->involvesDivByZero( bvr ) ){
       quantifiers::DivByZeroSygusInvarianceTest dbzet;
       Trace("sygus-sb-mexp-debug") << "Minimize explanation for div-by-zero in "
@@ -1143,7 +1143,7 @@ Node SygusExtension::registerSearchValue(Node a,
           }
           Trace("sygus-sb-exc") << std::endl;
         }
-        Assert(d_tds->getSygusTermSize(bad_val) == sz);
+        Assert(utils::getSygusTermSize(bad_val) == sz);
 
         // generalize the explanation for why the analog of bad_val
         // is equivalent to bvr
@@ -1177,7 +1177,7 @@ void SygusExtension::registerSymBreakLemmaForValue(
 {
   TypeNode tn = val.getType();
   Node x = getFreeVar(tn);
-  unsigned sz = d_tds->getSygusTermSize(val);
+  unsigned sz = utils::getSygusTermSize(val);
   std::vector<Node> exp;
   d_tds->getExplain()->getExplanationFor(x, val, exp, et, valr, var_count, sz);
   Node lem =
index 898ee3491d46e63bb208856eea47690bd59db69a..705e867b9eb0c171f18e8fa69ddf1e1bc64fc64a 100644 (file)
@@ -15,8 +15,8 @@
 
 #include "cvc5_private.h"
 
-#ifndef CVC5__THEORY__STRINGS__THEORY_DATATYPES_UTILS_H
-#define CVC5__THEORY__STRINGS__THEORY_DATATYPES_UTILS_H
+#ifndef CVC5__THEORY__DATATYPES__THEORY_DATATYPES_UTILS_H
+#define CVC5__THEORY__DATATYPES__THEORY_DATATYPES_UTILS_H
 
 #include <vector>
 
index 789a723b92b54ac9bcee8967d230e36da3595d0c..c2ee563e3f2e3f9c283c1de09d34b2ea6c7ed815 100644 (file)
@@ -20,6 +20,7 @@
 #include "smt/smt_engine.h"
 #include "smt/smt_engine_scope.h"
 #include "smt/smt_statistics_registry.h"
+#include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
 #include "theory/quantifiers/term_util.h"
 #include "theory/rewriter.h"
@@ -244,8 +245,8 @@ Node CandidateRewriteDatabase::addTerm(Node sol,
           // wish to enumerate any term that contains sol (resp. eq_sol)
           // as a subterm.
           Node exc_sol = sol;
-          unsigned sz = d_tds->getSygusTermSize(sol);
-          unsigned eqsz = d_tds->getSygusTermSize(eq_sol);
+          unsigned sz = datatypes::utils::getSygusTermSize(sol);
+          unsigned eqsz = datatypes::utils::getSygusTermSize(eq_sol);
           if (eqsz > sz)
           {
             sz = eqsz;
index 28788a5ea62bcc04dfeb873590b2245160ae41c0..544bdcc5cddd01226c8cbd1b97aca87a14f04652 100644 (file)
@@ -19,6 +19,7 @@
 #include "expr/sygus_datatype.h"
 #include "options/quantifiers_options.h"
 #include "printer/printer.h"
+#include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/quantifiers/sygus/sygus_unif_rl.h"
 #include "theory/quantifiers/sygus/synth_conjecture.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
@@ -205,8 +206,8 @@ bool CegisUnif::getEnumValues(const std::vector<Node>& enums,
             if (curr_val < prev_val)
             {
               // must have the same size
-              unsigned prev_size = d_tds->getSygusTermSize(prev_val);
-              unsigned curr_size = d_tds->getSygusTermSize(curr_val);
+              unsigned prev_size = datatypes::utils::getSygusTermSize(prev_val);
+              unsigned curr_size = datatypes::utils::getSygusTermSize(curr_val);
               Assert(prev_size <= curr_size);
               if (curr_size == prev_size)
               {
index 1c62f030da1233bfd4b1462a584f94f8111e65f2..a1ae53ad1efe2b452da79106b9471298efc9322c 100644 (file)
@@ -31,7 +31,7 @@ void RConsTypeInfo::initialize(TermDbSygus* tds,
   NodeManager* nm = NodeManager::currentNM();
   SkolemManager* sm = nm->getSkolemManager();
 
-  d_enumerator.reset(new SygusEnumerator(tds, nullptr, s, true));
+  d_enumerator.reset(new SygusEnumerator(tds, nullptr, &s, true));
   d_enumerator->initialize(sm->mkDummySkolem("sygus_rcons", stn));
   d_crd.reset(new CandidateRewriteDatabase(true, false, true, false));
   // since initial samples are not always useful for equivalence checks, set
index 0cf92b37327173036da8b7dcda6b440e9ccd36bc..2dfd41fb448eadc61c80e69ff45420de0fcececa 100644 (file)
@@ -20,6 +20,7 @@
 #include "options/datatypes_options.h"
 #include "options/quantifiers_options.h"
 #include "smt/logic_exception.h"
+#include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/datatypes/theory_datatypes_utils.h"
 #include "theory/quantifiers/sygus/synth_engine.h"
 #include "theory/quantifiers/sygus/type_node_id_trie.h"
@@ -33,12 +34,14 @@ namespace quantifiers {
 
 SygusEnumerator::SygusEnumerator(TermDbSygus* tds,
                                  SynthConjecture* p,
-                                 SygusStatistics& s,
-                                 bool enumShapes)
+                                 SygusStatistics* s,
+                                 bool enumShapes,
+                                 bool enumAnyConstHoles)
     : d_tds(tds),
       d_parent(p),
       d_stats(s),
       d_enumShapes(enumShapes),
+      d_enumAnyConstHoles(enumAnyConstHoles),
       d_tlEnum(nullptr),
       d_abortSize(-1)
 {
@@ -54,6 +57,12 @@ void SygusEnumerator::initialize(Node e)
   d_tlEnum = getMasterEnumForType(d_etype);
   d_abortSize = options::sygusAbortSize();
 
+  // if we don't have a term database, we don't register symmetry breaking
+  // lemmas
+  if (!d_tds)
+  {
+    return;
+  }
   // Get the statically registered symmetry breaking clauses for e, see if they
   // can be used for speeding up the enumeration.
   NodeManager* nm = NodeManager::currentNM();
@@ -141,7 +150,8 @@ Node SygusEnumerator::getCurrent()
     if (d_sbExcTlCons.find(ret.getOperator()) != d_sbExcTlCons.end())
     {
       Trace("sygus-enum-exc")
-          << "Exclude (external) : " << d_tds->sygusToBuiltin(ret) << std::endl;
+          << "Exclude (external) : " << datatypes::utils::sygusToBuiltin(ret)
+          << std::endl;
       ret = Node::null();
     }
   }
@@ -330,9 +340,12 @@ bool SygusEnumerator::TermCache::addTerm(Node n)
   Assert(!n.isNull());
   if (options::sygusSymBreakDynamic())
   {
-    Node bn = d_tds->sygusToBuiltin(n);
-    Node bnr = d_tds->getExtRewriter()->extendedRewrite(bn);
-    ++(d_stats->d_enumTermsRewrite);
+    Node bn = datatypes::utils::sygusToBuiltin(n);
+    Node bnr = d_extr.extendedRewrite(bn);
+    if (d_stats != nullptr)
+    {
+      ++(d_stats->d_enumTermsRewrite);
+    }
     if (options::sygusRewVerify())
     {
       if (bn != bnr)
@@ -358,7 +371,10 @@ bool SygusEnumerator::TermCache::addTerm(Node n)
     // if we are doing PBE symmetry breaking
     if (d_eec != nullptr)
     {
-      ++(d_stats->d_enumTermsExampleEval);
+      if (d_stats != nullptr)
+      {
+        ++(d_stats->d_enumTermsExampleEval);
+      }
       // Is it equivalent under examples?
       Node bne = d_eec->addSearchVal(d_tn, bnr);
       if (!bne.isNull())
@@ -374,7 +390,10 @@ bool SygusEnumerator::TermCache::addTerm(Node n)
     }
     Trace("sygus-enum-terms") << "tc(" << d_tn << "): term " << bn << std::endl;
   }
-  ++(d_stats->d_enumTerms);
+  if (d_stats != nullptr)
+  {
+    ++(d_stats->d_enumTerms);
+  }
   d_terms.push_back(n);
   return true;
 }
@@ -474,8 +493,8 @@ Node SygusEnumerator::TermEnumSlave::getCurrent()
   Node curr = tc.getTerm(d_index);
   Trace("sygus-enum-debug2")
       << "slave(" << d_tn
-      << "): current : " << d_se->d_tds->sygusToBuiltin(curr)
-      << ", sizes = " << d_se->d_tds->getSygusTermSize(curr) << " "
+      << "): current : " << datatypes::utils::sygusToBuiltin(curr)
+      << ", sizes = " << datatypes::utils::getSygusTermSize(curr) << " "
       << getCurrentSize() << std::endl;
   Trace("sygus-enum-debug2") << "slave(" << d_tn
                              << "): indices : " << d_hasIndexNextEnd << " "
@@ -560,7 +579,7 @@ void SygusEnumerator::initializeTermCache(TypeNode tn)
   {
     eec = d_parent->getExampleEvalCache(d_enum);
   }
-  d_tcache[tn].initialize(&d_stats, d_enum, tn, d_tds, eec);
+  d_tcache[tn].initialize(d_stats, d_enum, tn, d_tds, eec);
 }
 
 SygusEnumerator::TermEnum* SygusEnumerator::getMasterEnumForType(TypeNode tn)
@@ -578,7 +597,7 @@ SygusEnumerator::TermEnum* SygusEnumerator::getMasterEnumForType(TypeNode tn)
     AlwaysAssert(ret);
     return &d_masterEnum[tn];
   }
-  if (options::sygusRepairConst())
+  if (d_enumAnyConstHoles)
   {
     std::map<TypeNode, TermEnumMasterFv>::iterator it = d_masterEnumFv.find(tn);
     if (it != d_masterEnumFv.end())
@@ -720,6 +739,7 @@ bool SygusEnumerator::TermEnumMaster::incrementInternal()
   // If we are enumerating shapes, the first enumerated term is a free variable.
   if (d_enumShapes && !d_enumShapesInit)
   {
+    Assert(d_tds != nullptr);
     Node fv = d_tds->getFreeVar(d_tn, 0);
     d_enumShapesInit = true;
     d_currTermSet = true;
@@ -1083,6 +1103,7 @@ void SygusEnumerator::TermEnumMaster::childrenToShape(
 Node SygusEnumerator::TermEnumMaster::convertShape(
     Node n, std::map<TypeNode, int>& vcounter)
 {
+  Assert(d_tds != nullptr);
   NodeManager* nm = NodeManager::currentNM();
   std::unordered_map<TNode, Node> visited;
   std::unordered_map<TNode, Node>::iterator it;
@@ -1195,6 +1216,7 @@ bool SygusEnumerator::TermEnumMasterFv::initialize(SygusEnumerator* se,
 
 Node SygusEnumerator::TermEnumMasterFv::getCurrent()
 {
+  Assert(d_se->d_tds != nullptr);
   Node ret = d_se->d_tds->getFreeVar(d_tn, d_currSize);
   Trace("sygus-enum-debug2") << "master_fv(" << d_tn << "): mk " << ret
                              << std::endl;
index 35510895771e16cd92bbd4abd62e889ae9abdc03..88133eb6e5d0e5d1dd01fbf1f2595b957d2bf284 100644 (file)
@@ -56,10 +56,23 @@ class SygusPbe;
 class SygusEnumerator : public EnumValGenerator
 {
  public:
-  SygusEnumerator(TermDbSygus* tds,
-                  SynthConjecture* p,
-                  SygusStatistics& s,
-                  bool enumShapes = false);
+  /**
+   * @param tds Pointer to the term database, required if enumShapes or
+   * enumAnyConstHoles is true, or if we want to include symmetry breaking from
+   * lemmas stored in the sygus term database,
+   * @param p Pointer to the conjecture, required if we wish to do
+   * conjecture-specific symmetry breaking
+   * @param s Pointer to the statistics
+   * @param enumShapes If true, this enumerator will generate terms having any
+   * number of free variables
+   * @param enumAnyConstHoles If true, this enumerator will generate terms where
+   * free variables are the arguments to any-constant constructors.
+   */
+  SygusEnumerator(TermDbSygus* tds = nullptr,
+                  SynthConjecture* p = nullptr,
+                  SygusStatistics* s = nullptr,
+                  bool enumShapes = false,
+                  bool enumAnyConstHoles = false);
   ~SygusEnumerator() {}
   /** initialize this class with enumerator e */
   void initialize(Node e) override;
@@ -77,10 +90,13 @@ class SygusEnumerator : public EnumValGenerator
   TermDbSygus* d_tds;
   /** pointer to the synth conjecture that owns this enumerator */
   SynthConjecture* d_parent;
-  /** reference to the statistics of parent */
-  SygusStatistics& d_stats;
+  /** pointer to the statistics */
+  SygusStatistics* d_stats;
   /** Whether we are enumerating shapes */
   bool d_enumShapes;
+  /** Whether we are enumerating free variables as arguments to any-constant
+   * constructors */
+  bool d_enumAnyConstHoles;
   /** Term cache
    *
    * This stores a list of terms for a given sygus type. The key features of
@@ -171,6 +187,8 @@ class SygusEnumerator : public EnumValGenerator
     TypeNode d_tn;
     /** pointer to term database sygus */
     TermDbSygus* d_tds;
+    /** extended rewriter */
+    ExtendedRewriter d_extr;
     /**
      * Pointer to the example evaluation cache utility (used for symmetry
      * breaking).
index 395f16beb0e732a56fb7dce77dbe34029ad914e6..23c315f425fe59808a3a6ee61d0aeebb0a28595c 100644 (file)
@@ -18,6 +18,7 @@
 #include "expr/dtype.h"
 #include "expr/dtype_cons.h"
 #include "smt/logic_exception.h"
+#include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/datatypes/theory_datatypes_utils.h"
 #include "theory/quantifiers/sygus/sygus_invariance.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
@@ -220,7 +221,7 @@ void SygusExplain::getExplanationFor(TermRecBuild& trb,
       // we are tracking term size if positive
       if (sz >= 0)
       {
-        int s = d_tdb->getSygusTermSize(vn[i]);
+        int s = datatypes::utils::getSygusTermSize(vn[i]);
         sz = sz - s;
       }
     }
index 86d0bbc8eddbaec4b0a159620d34646dd6dc13aa..892ee6dd411af3048cd0e276cf366a4d8d6bb9fe 100644 (file)
@@ -15,6 +15,7 @@
 #include "theory/quantifiers/sygus/sygus_pbe.h"
 
 #include "options/quantifiers_options.h"
+#include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/quantifiers/sygus/example_infer.h"
 #include "theory/quantifiers/sygus/sygus_unif_io.h"
 #include "theory/quantifiers/sygus/synth_conjecture.h"
@@ -180,7 +181,7 @@ bool SygusPbe::constructCandidates(const std::vector<Node>& enums,
       Trace("sygus-pbe-enum") << std::endl;
       if (!enum_values[i].isNull())
       {
-        unsigned sz = d_tds->getSygusTermSize(enum_values[i]);
+        unsigned sz = datatypes::utils::getSygusTermSize(enum_values[i]);
         szs.push_back(sz);
         if (i == 0 || sz < min_term_size)
         {
index 16ca1f4e6ce4a1dd95b0b37d397030b50ebd8f5a..00370ffa24340bf4582bb41f54e6d2bb6495c92e 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "theory/quantifiers/sygus/sygus_unif.h"
 
+#include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
 #include "theory/quantifiers/term_util.h"
 #include "util/random.h"
@@ -52,7 +53,7 @@ Node SygusUnif::getMinimalTerm(const std::vector<Node>& terms)
     unsigned ssize = 0;
     if (it == d_termToSize.end())
     {
-      ssize = d_tds->getSygusTermSize(n);
+      ssize = datatypes::utils::getSygusTermSize(n);
       d_termToSize[n] = ssize;
     }
     else
index 8c8f5ccd438f0542b296aa1eef2b05f79aebe817..8207a07f25cedd07033d7cece5ea5c2f5ca88490 100644 (file)
@@ -16,6 +16,7 @@
 #include "theory/quantifiers/sygus/sygus_unif_io.h"
 
 #include "options/quantifiers_options.h"
+#include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/evaluator.h"
 #include "theory/quantifiers/sygus/example_infer.h"
 #include "theory/quantifiers/sygus/synth_conjecture.h"
@@ -835,7 +836,8 @@ Node SygusUnifIo::constructSolutionNode(std::vector<Node>& lemmas)
       if (!vcc.isNull()
           && (d_solution.isNull()
               || (!d_solution.isNull()
-                  && d_tds->getSygusTermSize(vcc) < d_sol_term_size)))
+                  && datatypes::utils::getSygusTermSize(vcc)
+                         < d_sol_term_size)))
       {
         if (Trace.isOn("sygus-pbe"))
         {
@@ -846,7 +848,7 @@ Node SygusUnifIo::constructSolutionNode(std::vector<Node>& lemmas)
         }
         d_solution = vcc;
         newSolution = vcc;
-        d_sol_term_size = d_tds->getSygusTermSize(vcc);
+        d_sol_term_size = datatypes::utils::getSygusTermSize(vcc);
         Trace("sygus-pbe-sol")
             << "PBE solution size: " << d_sol_term_size << std::endl;
         // We've determined its feasible, now, enable information gain and
index 1ddc2fa223a93c6c173845cb56f67ca764ebda37..73bd6b8a49e3327b9953e0cf98a033d9687d7715 100644 (file)
@@ -827,7 +827,10 @@ Node SynthConjecture::getEnumeratedValue(Node e, bool& activeIncomplete)
                    == options::SygusActiveGenMode::ENUM
                || options::sygusActiveGenMode()
                       == options::SygusActiveGenMode::AUTO);
-        d_evg[e].reset(new SygusEnumerator(d_tds, this, d_stats));
+        // if sygus repair const is enabled, we enumerate terms with free
+        // variables as arguments to any-constant constructors
+        d_evg[e].reset(new SygusEnumerator(
+            d_tds, this, &d_stats, false, options::sygusRepairConst()));
       }
     }
     Trace("sygus-active-gen")
index 8265634012633be0765257a23f2e908e3338b114..3b0ea33122affbf62b03b5368b8d490de0982ff6 100644 (file)
@@ -359,23 +359,6 @@ Node TermDbSygus::sygusToBuiltin(Node n, TypeNode tn)
   return ret;
 }
 
-unsigned TermDbSygus::getSygusTermSize( Node n ){
-  if (n.getKind() != APPLY_CONSTRUCTOR)
-  {
-    return 0;
-  }
-  unsigned sum = 0;
-  for (unsigned i = 0; i < n.getNumChildren(); i++)
-  {
-    sum += getSygusTermSize(n[i]);
-  }
-  const DType& dt = datatypes::utils::datatypeOf(n.getOperator());
-  int cindex = datatypes::utils::indexOf(n.getOperator());
-  Assert(cindex >= 0 && cindex < (int)dt.getNumConstructors());
-  unsigned weight = dt[cindex].getWeight();
-  return weight + sum;
-}
-
 bool TermDbSygus::registerSygusType(TypeNode tn)
 {
   std::map<TypeNode, bool>::iterator it = d_registerStatus.find(tn);
index e0a812069c3eaeeca7bb32052cf4acca2ad33ad0..80411b2581ec788e0bd9717fff6081bc21b8786c 100644 (file)
@@ -456,7 +456,6 @@ class TermDbSygus {
 
   Node getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs );
   Node getNormalized(TypeNode t, Node prog);
-  unsigned getSygusTermSize( Node n );
   /** involves div-by-zero */
   bool involvesDivByZero( Node n );
   /** get anchor */