Use expression mining independent of whether sygus stream is enabled (#8250)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 9 Mar 2022 11:42:01 +0000 (05:42 -0600)
committerGitHub <noreply@github.com>
Wed, 9 Mar 2022 11:42:01 +0000 (11:42 +0000)
Since it is possible to use SyGuS in incremental mode, e.g. for `get-abduct-next`, our expression mining, e.g. for filtering weak solutions, should apply independent of whether sygusStream is enabled.  This refactors the SyGuS solver so that we do so.

src/theory/quantifiers/sygus/synth_conjecture.cpp
src/theory/quantifiers/sygus/synth_conjecture.h

index 3760cffb6fd08649aee59e2e53ed1a1f62a48382..55611ad745557643ff4752eb8d775c101ce00836 100644 (file)
@@ -60,6 +60,11 @@ SynthConjecture::SynthConjecture(Env& env,
       d_verify(env, d_tds),
       d_hasSolution(false),
       d_computedSolution(false),
+      d_runExprMiner(options().quantifiers.sygusRewSynth
+                     || options().quantifiers.sygusQueryGen
+                            != options::SygusQueryGenMode::NONE
+                     || options().quantifiers.sygusFilterSolMode
+                            != options::SygusFilterSolMode::NONE),
       d_ceg_si(new CegSingleInv(env, tr, s)),
       d_templInfer(new SygusTemplateInfer(env)),
       d_ceg_proc(new SynthConjectureProcess(env)),
@@ -576,10 +581,13 @@ bool SynthConjecture::doCheck()
 
   // now mark that we have a solution
   d_hasSolution = true;
-  if (options().quantifiers.sygusStream)
+  ++(d_stats.d_solutions);
+  // determine if we should filter this solution, e.g. based on expression
+  // mining or sygus stream
+  if (runExprMiner())
   {
-    // immediately print the current solution
-    printAndContinueStream(candidate_values);
+    // excluded due to expression mining
+    excludeCurrentSolution(candidate_values);
     // streaming means now we immediately are looking for a new solution
     d_hasSolution = false;
     d_computedSolution = false;
@@ -725,6 +733,25 @@ EnumValueManager* SynthConjecture::getEnumValueManagerFor(Node e)
   return eman;
 }
 
+ExpressionMinerManager* SynthConjecture::getExprMinerManagerFor(Node e)
+{
+  if (!d_runExprMiner)
+  {
+    return nullptr;
+  }
+  std::map<Node, std::unique_ptr<ExpressionMinerManager>>::iterator its =
+      d_exprm.find(e);
+  if (its != d_exprm.end())
+  {
+    return its->second.get();
+  }
+  d_exprm[e].reset(new ExpressionMinerManager(d_env));
+  ExpressionMinerManager* emm = d_exprm[e].get();
+  emm->initializeSygus(d_tds, e, options().quantifiers.sygusSamples, true);
+  emm->initializeMinersForOptions();
+  return emm;
+}
+
 Node SynthConjecture::getModelValue(Node n)
 {
   Trace("cegqi-mv") << "getModelValue for : " << n << std::endl;
@@ -738,15 +765,6 @@ void SynthConjecture::debugPrint(const char* c)
   Trace(c) << "  * Counterexample skolems : " << d_innerSks << std::endl;
 }
 
-void SynthConjecture::printAndContinueStream(const std::vector<Node>& values)
-{
-  Assert(d_master != nullptr);
-  // we have generated a solution, print it
-  // get the current output stream
-  printSynthSolutionInternal(*options().base.out);
-  excludeCurrentSolution(values);
-}
-
 void SynthConjecture::excludeCurrentSolution(const std::vector<Node>& values)
 {
   Assert(values.size() == d_candidates.size());
@@ -783,111 +801,113 @@ void SynthConjecture::excludeCurrentSolution(const std::vector<Node>& values)
   }
 }
 
-void SynthConjecture::printSynthSolutionInternal(std::ostream& out)
+bool SynthConjecture::runExprMiner()
 {
-  Trace("cegqi-sol-debug") << "Printing synth solution..." << std::endl;
+  // if not using expression mining and sygus stream
+  if (!d_runExprMiner && !options().quantifiers.sygusStream)
+  {
+    return false;
+  }
+  Trace("cegqi-sol-debug") << "Run expression mining..." << std::endl;
   Assert(d_quant[0].getNumChildren() == d_embed_quant[0].getNumChildren());
   std::vector<Node> sols;
   std::vector<int8_t> statuses;
   if (!getSynthSolutionsInternal(sols, statuses))
   {
-    return;
+    return false;
   }
+  // always exclude if sygus stream is enabled
+  bool doExclude = options().quantifiers.sygusStream;
   NodeManager* nm = NodeManager::currentNM();
-  for (unsigned i = 0, size = d_embed_quant[0].getNumChildren(); i < size; i++)
+  std::ostream& out = options().base.out;
+  for (size_t i = 0, size = d_embed_quant[0].getNumChildren(); i < size; i++)
   {
     Node sol = sols[i];
-    if (!sol.isNull())
+    if (sol.isNull())
     {
-      Node prog = d_embed_quant[0][i];
-      int8_t status = statuses[i];
-      TypeNode tn = prog.getType();
-      const DType& dt = tn.getDType();
-      std::stringstream ss;
-      ss << prog;
-      std::string f(ss.str());
-      f.erase(f.begin());
-      ++(d_stats.d_solutions);
-
-      bool is_unique_term = true;
-
-      if (status != 0
-          && (options().quantifiers.sygusRewSynth
-              || options().quantifiers.sygusQueryGen
-                     != options::SygusQueryGenMode::NONE
-              || options().quantifiers.sygusFilterSolMode
-                     != options::SygusFilterSolMode::NONE))
+      // failed to reconstruct to syntax, skip
+      continue;
+    }
+    Node e = d_embed_quant[0][i];
+    int8_t status = statuses[i];
+    // run expression mining
+    if (status != 0)
+    {
+      ExpressionMinerManager* emm = getExprMinerManagerFor(e);
+      if (emm != nullptr)
       {
-        Trace("cegqi-sol-debug") << "Run expression mining..." << std::endl;
-        std::map<Node, std::unique_ptr<ExpressionMinerManager>>::iterator its =
-            d_exprm.find(prog);
-        if (its == d_exprm.end())
-        {
-          d_exprm[prog].reset(new ExpressionMinerManager(d_env));
-          ExpressionMinerManager* emm = d_exprm[prog].get();
-          emm->initializeSygus(
-              d_tds, d_candidates[i], options().quantifiers.sygusSamples, true);
-          emm->initializeMinersForOptions();
-          its = d_exprm.find(prog);
-        }
         bool rew_print = false;
-        is_unique_term = its->second->addTerm(sol, out, rew_print);
+        bool ret = emm->addTerm(sol, out, rew_print);
         if (rew_print)
         {
+          // count the number of rewrites we printed
           ++(d_stats.d_candidate_rewrites_print);
         }
-        if (!is_unique_term)
+        if (!ret)
         {
+          // count the number of filtered solutions
           ++(d_stats.d_filtered_solutions);
+          // if any term is excluded due to mining, its output is excluded
+          // from sygus stream, and the entire solution is excluded.
+          doExclude = true;
+          continue;
         }
       }
-      if (is_unique_term)
+    }
+    // print to stream
+    if (options().quantifiers.sygusStream)
+    {
+      TypeNode tn = e.getType();
+      const DType& dt = tn.getDType();
+      std::stringstream ss;
+      ss << e;
+      std::string f(ss.str());
+      f.erase(f.begin());
+      out << "(define-fun " << f << " ";
+      // Only include variables that are truly bound variables of the
+      // function-to-synthesize. This means we exclude variables that encode
+      // external terms. This ensures that --sygus-stream prints
+      // solutions with no arguments on the predicate for responses to
+      // the get-abduct command.
+      // pvs stores the variables that will be printed in the argument list
+      // below.
+      std::vector<Node> pvs;
+      Node vl = dt.getSygusVarList();
+      if (!vl.isNull())
       {
-        out << "(define-fun " << f << " ";
-        // Only include variables that are truly bound variables of the
-        // function-to-synthesize. This means we exclude variables that encode
-        // external terms. This ensures that --sygus-stream prints
-        // solutions with no arguments on the predicate for responses to
-        // the get-abduct command.
-        // pvs stores the variables that will be printed in the argument list
-        // below.
-        std::vector<Node> pvs;
-        Node vl = dt.getSygusVarList();
-        if (!vl.isNull())
+        Assert(vl.getKind() == BOUND_VAR_LIST);
+        SygusVarToTermAttribute sta;
+        for (const Node& v : vl)
         {
-          Assert(vl.getKind() == BOUND_VAR_LIST);
-          SygusVarToTermAttribute sta;
-          for (const Node& v : vl)
+          if (!v.hasAttribute(sta))
           {
-            if (!v.hasAttribute(sta))
-            {
-              pvs.push_back(v);
-            }
+            pvs.push_back(v);
           }
         }
-        if (pvs.empty())
-        {
-          out << "() ";
-        }
-        else
-        {
-          vl = nm->mkNode(BOUND_VAR_LIST, pvs);
-          out << vl << " ";
-        }
-        out << dt.getSygusType() << " ";
-        if (status == 0)
-        {
-          out << sol;
-        }
-        else
-        {
-          Node bsol = datatypes::utils::sygusToBuiltin(sol, true);
-          out << bsol;
-        }
-        out << ")" << std::endl;
       }
+      if (pvs.empty())
+      {
+        out << "() ";
+      }
+      else
+      {
+        vl = nm->mkNode(BOUND_VAR_LIST, pvs);
+        out << vl << " ";
+      }
+      out << dt.getSygusType() << " ";
+      if (status == 0)
+      {
+        out << sol;
+      }
+      else
+      {
+        Node bsol = datatypes::utils::sygusToBuiltin(sol, true);
+        out << bsol;
+      }
+      out << ")" << std::endl;
     }
   }
+  return doExclude;
 }
 
 bool SynthConjecture::getSynthSolutions(
index ce9788c5090cf25f320d47ddd90ce81d5fbab21e..568fa6ea863c0c24844c09161287b509b072943c 100644 (file)
@@ -88,12 +88,6 @@ class SynthConjecture : protected EnvObj
    */
   bool doCheck();
   //-------------------------------end for counterexample-guided check/refine
-  /**
-   * Prints the current synthesis solution to output stream out. This is
-   * currently used for printing solutions for sygusStream only. We do not
-   * enclose solutions in parentheses.
-   */
-  void printSynthSolutionInternal(std::ostream& out);
   /** get synth solutions
    *
    * This method returns true if this class has a solution available to the
@@ -201,6 +195,8 @@ class SynthConjecture : protected EnvObj
   bool d_hasSolution;
   /** Whether we have computed a solution */
   bool d_computedSolution;
+  /** whether we are running expression mining */
+  bool d_runExprMiner;
   /**
    * The final solution and status, caches getSynthSolutionsInternal, valid
    * if d_computedSolution is true.
@@ -269,6 +265,10 @@ class SynthConjecture : protected EnvObj
    * Get or make enumerator manager for the enumerator e.
    */
   EnumValueManager* getEnumValueManagerFor(Node e);
+  /**
+   * Get or make the expression miner manager for enumerator e.
+   */
+  ExpressionMinerManager* getExprMinerManagerFor(Node e);
   //------------------------end enumerators
 
   /** list of constants for quantified formula
@@ -325,16 +325,17 @@ class SynthConjecture : protected EnvObj
    */
   bool getSynthSolutionsInternal(std::vector<Node>& sols,
                                  std::vector<int8_t>& status);
-  //-------------------------------- sygus stream
   /**
-   * Prints the current synthesis solution to the output stream indicated by
-   * the Options object, send a lemma blocking the current solution to the
-   * output channel, which we refer to as a "stream exclusion lemma".
+   * Run expression mining on the last synthesis solution. Return true
+   * if we should skip it.
    *
-   * The argument enums is the set of enumerators that comprise the current
-   * solution, and values is their current values.
+   * This method also prints the current synthesis solution to output stream out
+   * when sygusStream is enabled, which does not enclose solutions in
+   * parentheses. If sygusStream is enabled, this always returns true, as the
+   * current solution should be printed and then immediately excluded.
    */
-  void printAndContinueStream(const std::vector<Node>& values);
+  bool runExprMiner();
+  //-------------------------------- sygus stream
   /** exclude the current solution { enums -> values } */
   void excludeCurrentSolution(const std::vector<Node>& values);
   /**