Infrastructure for using active enumerators in sygus modules (#2547)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 27 Sep 2018 18:57:17 +0000 (13:57 -0500)
committerGitHub <noreply@github.com>
Thu, 27 Sep 2018 18:57:17 +0000 (13:57 -0500)
src/theory/quantifiers/sygus/cegis.cpp
src/theory/quantifiers/sygus/cegis.h
src/theory/quantifiers/sygus/cegis_unif.cpp
src/theory/quantifiers/sygus/sygus_pbe.cpp
src/theory/quantifiers/sygus/sygus_unif_io.cpp
src/theory/quantifiers/sygus/synth_conjecture.cpp
src/theory/quantifiers/sygus/synth_conjecture.h
src/theory/quantifiers/sygus/term_database_sygus.cpp
src/theory/quantifiers/sygus/term_database_sygus.h

index fbe0da7a81cb22b271899881552de86ba1a75414..db9af10b4767ab87c1a2c4709c46f1074e223c32 100644 (file)
@@ -13,6 +13,7 @@
  **/
 
 #include "theory/quantifiers/sygus/cegis.h"
+#include "expr/node_algorithm.h"
 #include "options/base_options.h"
 #include "options/quantifiers_options.h"
 #include "printer/printer.h"
@@ -100,15 +101,47 @@ void Cegis::getTermList(const std::vector<Node>& candidates,
 bool Cegis::addEvalLemmas(const std::vector<Node>& candidates,
                           const std::vector<Node>& candidate_values)
 {
+  // First, decide if this call will apply "conjecture-specific refinement".
+  // In other words, in some settings, the following method will identify and
+  // block a class of solutions {candidates -> S} that generalizes the current
+  // one (given by {candidates -> candidate_values}), such that for each
+  // candidate_values' in S, we have that {candidates -> candidate_values'} is
+  // also not a solution for the given conjecture. We may not
+  // apply this form of refinement if any (relevant) enumerator in candidates is
+  // "actively generated" (see TermDbSygs::isPassiveEnumerator), since its
+  // model values are themselves interpreted as classes of solutions.
+  bool doGen = true;
+  for (const Node& v : candidates)
+  {
+    // if it is relevant to refinement
+    if (d_refinement_lemma_vars.find(v) != d_refinement_lemma_vars.end())
+    {
+      if (!d_tds->isPassiveEnumerator(v))
+      {
+        doGen = false;
+        break;
+      }
+    }
+  }
   NodeManager* nm = NodeManager::currentNM();
   bool addedEvalLemmas = false;
   if (options::sygusRefEval())
   {
-    Trace("cegqi-engine") << "  *** Do refinement lemma evaluation..."
-                          << std::endl;
+    Trace("cegqi-engine") << "  *** Do refinement lemma evaluation"
+                          << (doGen ? " with conjecture-specific refinement"
+                                    : "")
+                          << "..." << std::endl;
     // see if any refinement lemma is refuted by evaluation
     std::vector<Node> cre_lems;
-    getRefinementEvalLemmas(candidates, candidate_values, cre_lems);
+    bool ret =
+        getRefinementEvalLemmas(candidates, candidate_values, cre_lems, doGen);
+    if (ret && !doGen)
+    {
+      Trace("cegqi-engine") << "...(actively enumerated) candidate failed "
+                               "refinement lemma evaluation."
+                            << std::endl;
+      return true;
+    }
     if (!cre_lems.empty())
     {
       for (const Node& lem : cre_lems)
@@ -124,7 +157,8 @@ bool Cegis::addEvalLemmas(const std::vector<Node>& candidates,
          add the lemmas below as well, in parallel. */
     }
   }
-  if (d_eval_unfold != nullptr)
+  // we only do evaluation unfolding for passive enumerators
+  if (doGen && d_eval_unfold != nullptr)
   {
     Trace("cegqi-engine") << "  *** Do evaluation unfolding..." << std::endl;
     std::vector<Node> eager_terms, eager_vals, eager_exps;
@@ -281,6 +315,8 @@ void Cegis::addRefinementLemma(Node lem)
   }
   // rewrite with extended rewriter
   slem = d_tds->getExtRewriter()->extendedRewrite(slem);
+  // collect all variables in slem
+  expr::getSymbols(slem, d_refinement_lemma_vars);
   std::vector<Node> waiting;
   waiting.push_back(lem);
   unsigned wcounter = 0;
@@ -408,10 +444,10 @@ void Cegis::registerRefinementLemma(const std::vector<Node>& vars,
 }
 
 bool Cegis::usingRepairConst() { return true; }
-
-void Cegis::getRefinementEvalLemmas(const std::vector<Node>& vs,
+bool Cegis::getRefinementEvalLemmas(const std::vector<Node>& vs,
                                     const std::vector<Node>& ms,
-                                    std::vector<Node>& lems)
+                                    std::vector<Node>& lems,
+                                    bool doGen)
 {
   Trace("sygus-cref-eval") << "Cref eval : conjecture has "
                            << d_refinement_lemma_unit.size() << " unit and "
@@ -424,6 +460,7 @@ void Cegis::getRefinementEvalLemmas(const std::vector<Node>& vs,
 
   Node nfalse = nm->mkConst(false);
   Node neg_guard = d_parent->getGuard().negate();
+  bool ret = false;
   for (unsigned r = 0; r < 2; r++)
   {
     std::unordered_set<Node, NodeHashFunction>& rlemmas =
@@ -447,6 +484,12 @@ void Cegis::getRefinementEvalLemmas(const std::vector<Node>& vs,
           << "...after unfolding is : " << lemcsu << std::endl;
       if (lemcsu.isConst() && !lemcsu.getConst<bool>())
       {
+        if (!doGen)
+        {
+          // we are not generating the lemmas, instead just return
+          return true;
+        }
+        ret = true;
         std::vector<Node> msu;
         std::vector<Node> mexp;
         msu.insert(msu.end(), ms.begin(), ms.end());
@@ -480,13 +523,12 @@ void Cegis::getRefinementEvalLemmas(const std::vector<Node>& vs,
         {
           cre_lem = neg_guard;
         }
-      }
-      if (!cre_lem.isNull()
-          && std::find(lems.begin(), lems.end(), cre_lem) == lems.end())
-      {
-        Trace("sygus-cref-eval")
-            << "...produced lemma : " << cre_lem << std::endl;
-        lems.push_back(cre_lem);
+        if (std::find(lems.begin(), lems.end(), cre_lem) == lems.end())
+        {
+          Trace("sygus-cref-eval") << "...produced lemma : " << cre_lem
+                                   << std::endl;
+          lems.push_back(cre_lem);
+        }
       }
     }
     if (!lems.empty())
@@ -494,6 +536,7 @@ void Cegis::getRefinementEvalLemmas(const std::vector<Node>& vs,
       break;
     }
   }
+  return ret;
 }
 
 bool Cegis::sampleAddRefinementLemma(const std::vector<Node>& candidates,
index c7392b378c00332702cadb57a817508a0368fd0f..7387bd4684e5700a0ef70925e104462c21f815a0 100644 (file)
@@ -104,6 +104,8 @@ class Cegis : public SygusModule
   /** substitution entailed by d_refinement_lemma_unit */
   std::vector<Node> d_rl_eval_hds;
   std::vector<Node> d_rl_vals;
+  /** all variables appearing in refinement lemmas */
+  std::unordered_set<Node, NodeHashFunction> d_refinement_lemma_vars;
   /** adds lem as a refinement lemma */
   void addRefinementLemma(Node lem);
   /** add refinement lemma conjunct
@@ -150,10 +152,14 @@ class Cegis : public SygusModule
    * Given a candidate solution ms for candidates vs, this function adds lemmas
    * to lems based on evaluating the conjecture, instantiated for ms, on lemmas
    * for previous refinements (d_refinement_lemmas).
+   *
+   * Returns true if any such lemma exists. If doGen is false, then the
+   * lemmas are not generated or added to lems.
    */
-  void getRefinementEvalLemmas(const std::vector<Node>& vs,
+  bool getRefinementEvalLemmas(const std::vector<Node>& vs,
                                const std::vector<Node>& ms,
-                               std::vector<Node>& lems);
+                               std::vector<Node>& lems,
+                               bool doGen);
   /** sampler object for the option cegisSample()
    *
    * This samples points of the type of the inner variables of the synthesis
index 6497bee0b3581773cee38581e4646a8799f750e3..56edc55c6465dd0b862e869df18b07cd01f04a1b 100644 (file)
@@ -260,11 +260,16 @@ bool CegisUnif::processConstructCandidates(const std::vector<Node>& enums,
       if (options::sygusUnifCondIndependent() && !unif_enums[1][e].empty())
       {
         Node eu = unif_enums[1][e][0];
-        Node g = d_u_enum_manager.getActiveGuardForEnumerator(eu);
-        Node exp_exc = d_tds->getExplain()
-                           ->getExplanationForEquality(eu, unif_values[1][e][0])
-                           .negate();
-        lems.push_back(nm->mkNode(OR, g.negate(), exp_exc));
+        Assert(d_tds->isEnumerator(eu));
+        if (d_tds->isPassiveEnumerator(eu))
+        {
+          Node g = d_u_enum_manager.getActiveGuardForEnumerator(eu);
+          Node exp_exc =
+              d_tds->getExplain()
+                  ->getExplanationForEquality(eu, unif_values[1][e][0])
+                  .negate();
+          lems.push_back(nm->mkNode(OR, g.negate(), exp_exc));
+        }
       }
     }
   }
index 647b16637c2d60472a44372964f773188a921bfe..b7e6e0c6570de5537799e975e37761e170ef64c5 100644 (file)
@@ -481,14 +481,17 @@ bool SygusPbe::constructCandidates(const std::vector<Node>& enums,
       Node c = d_enum_to_candidate[e];
       std::vector<Node> enum_lems;
       d_sygus_unif[c].notifyEnumeration(e, v, enum_lems);
-      // the lemmas must be guarded by the active guard of the enumerator
-      Assert(d_enum_to_active_guard.find(e) != d_enum_to_active_guard.end());
-      Node g = d_enum_to_active_guard[e];
-      for (unsigned j = 0, size = enum_lems.size(); j < size; j++)
+      if (!enum_lems.empty())
       {
-        enum_lems[j] = nm->mkNode(OR, g.negate(), enum_lems[j]);
+        // the lemmas must be guarded by the active guard of the enumerator
+        Assert(d_enum_to_active_guard.find(e) != d_enum_to_active_guard.end());
+        Node g = d_enum_to_active_guard[e];
+        for (unsigned j = 0, size = enum_lems.size(); j < size; j++)
+        {
+          enum_lems[j] = nm->mkNode(OR, g.negate(), enum_lems[j]);
+        }
+        lems.insert(lems.end(), enum_lems.begin(), enum_lems.end());
       }
-      lems.insert(lems.end(), enum_lems.begin(), enum_lems.end());
     }
   }
   for( unsigned i=0; i<candidates.size(); i++ ){
index eca88cab86f870ddaf46657e5afaceed554a9f06..4fe3cfbed0a938797aa8de2e3f12c5fa04eb4fb5 100644 (file)
@@ -562,7 +562,10 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
 
   // is it excluded for domain-specific reason?
   std::vector<Node> exp_exc_vec;
-  if (getExplanationForEnumeratorExclude(e, v, base_results, exp_exc_vec))
+  Assert(d_tds->isEnumerator(e));
+  bool isPassive = d_tds->isPassiveEnumerator(e);
+  if (isPassive
+      && getExplanationForEnumeratorExclude(e, v, base_results, exp_exc_vec))
   {
     Assert(!exp_exc_vec.empty());
     exp_exc = exp_exc_vec.size() == 1
@@ -707,16 +710,20 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
     }
   }
 
-  // exclude this value on subsequent iterations
-  if (exp_exc.isNull())
+  if (isPassive)
   {
-    // if we did not already explain why this should be excluded, use default
-    exp_exc = d_tds->getExplain()->getExplanationForEquality(e, v);
+    // exclude this value on subsequent iterations
+    if (exp_exc.isNull())
+    {
+      Trace("sygus-sui-enum-lemma") << "Use basic exclusion." << std::endl;
+      // if we did not already explain why this should be excluded, use default
+      exp_exc = d_tds->getExplain()->getExplanationForEquality(e, v);
+    }
+    exp_exc = exp_exc.negate();
+    Trace("sygus-sui-enum-lemma")
+        << "SygusUnifIo : enumeration exclude lemma : " << exp_exc << std::endl;
+    lemmas.push_back(exp_exc);
   }
-  exp_exc = exp_exc.negate();
-  Trace("sygus-sui-enum-lemma")
-      << "SygusUnifIo : enumeration exclude lemma : " << exp_exc << std::endl;
-  lemmas.push_back(exp_exc);
 }
 
 bool SygusUnifIo::constructSolution(std::vector<Node>& sols,
index a29fdcc9fef464544a537a49f33ef6ee767dac3b..dea67b7c38e3d189d0ed04231c48d1577163ce53 100644 (file)
@@ -41,6 +41,7 @@ namespace quantifiers {
 
 SynthConjecture::SynthConjecture(QuantifiersEngine* qe)
     : d_qe(qe),
+      d_tds(qe->getTermDatabaseSygus()),
       d_ceg_si(new CegSingleInv(qe, this)),
       d_ceg_proc(new SynthConjectureProcess(qe)),
       d_ceg_gc(new CegGrammarConstructor(qe, this)),
@@ -339,7 +340,7 @@ void SynthConjecture::doCheck(std::vector<Node>& lems)
 
   // get the model value of the relevant terms from the master module
   std::vector<Node> enum_values;
-  bool fullModel = getModelValues(terms, enum_values);
+  bool fullModel = getEnumeratedValues(terms, enum_values);
 
   // if the master requires a full model and the model is partial, we fail
   if (!d_master->allowPartialModel() && !fullModel)
@@ -450,7 +451,7 @@ void SynthConjecture::doCheck(std::vector<Node>& lems)
   // eagerly unfold applications of evaluation function
   Trace("cegqi-debug") << "pre-unfold counterexample : " << lem << std::endl;
   std::map<Node, Node> visited_n;
-  lem = d_qe->getTermDatabaseSygus()->getEagerUnfold(lem, visited_n);
+  lem = d_tds->getEagerUnfold(lem, visited_n);
   // record the instantiation
   // this is used for remembering the solution
   recordInstantiation(candidate_values);
@@ -542,7 +543,12 @@ void SynthConjecture::doRefine(std::vector<Node>& lems)
     if (d_ce_sk_var_mvs.empty())
     {
       std::vector<Node> model_values;
-      getModelValues(d_ce_sk_vars, model_values);
+      for (const Node& v : d_ce_sk_vars)
+      {
+        Node mv = getModelValue(v);
+        Trace("cegqi-refine") << "  " << v << " -> " << mv << std::endl;
+        model_values.push_back(mv);
+      }
       sk_subs.insert(sk_subs.end(), model_values.begin(), model_values.end());
     }
     else
@@ -594,13 +600,14 @@ void SynthConjecture::preregisterConjecture(Node q)
   d_ceg_si->preregisterConjecture(q);
 }
 
-bool SynthConjecture::getModelValues(std::vector<Node>& n, std::vector<Node>& v)
+bool SynthConjecture::getEnumeratedValues(std::vector<Node>& n,
+                                          std::vector<Node>& v)
 {
   bool ret = true;
   Trace("cegqi-engine") << "  * Value is : ";
   for (unsigned i = 0; i < n.size(); i++)
   {
-    Node nv = getModelValue(n[i]);
+    Node nv = getEnumeratedValue(n[i]);
     v.push_back(nv);
     ret = ret && !nv.isNull();
     if (Trace.isOn("cegqi-engine"))
@@ -619,7 +626,7 @@ bool SynthConjecture::getModelValues(std::vector<Node>& n, std::vector<Node>& v)
         Trace("cegqi-engine") << ss.str() << " ";
         if (Trace.isOn("cegqi-engine-rr"))
         {
-          Node bv = d_qe->getTermDatabaseSygus()->sygusToBuiltin(nv, tn);
+          Node bv = d_tds->sygusToBuiltin(nv, tn);
           bv = Rewriter::rewrite(bv);
           Trace("cegqi-engine-rr") << " -> " << bv << std::endl;
         }
@@ -630,18 +637,29 @@ bool SynthConjecture::getModelValues(std::vector<Node>& n, std::vector<Node>& v)
   return ret;
 }
 
-Node SynthConjecture::getModelValue(Node n)
+Node SynthConjecture::getEnumeratedValue(Node e)
 {
-  Trace("cegqi-mv") << "getModelValue for : " << n << std::endl;
-  if (n.getAttribute(SygusSymBreakExcAttribute()))
+  Assert(d_tds->isEnumerator(e));
+  if (e.getAttribute(SygusSymBreakExcAttribute()))
   {
-    // if the current model value of n was excluded by symmetry breaking, then
+    // if the current model value of e was excluded by symmetry breaking, then
     // it does not have a proper model value that we should consider, thus we
     // return null.
     return Node::null();
   }
-  Node mv = d_qe->getModel()->getValue(n);
-  return mv;
+  if (d_tds->isPassiveEnumerator(e))
+  {
+    return getModelValue(e);
+  }
+  Assert(false);
+  // management of actively generated enumerators goes here
+  return getModelValue(e);
+}
+
+Node SynthConjecture::getModelValue(Node n)
+{
+  Trace("cegqi-mv") << "getModelValue for : " << n << std::endl;
+  return d_qe->getModel()->getValue(n);
 }
 
 void SynthConjecture::debugPrint(const char* c)
@@ -718,8 +736,7 @@ void SynthConjecture::printAndContinueStream()
     {
       sol = d_cinfo[cprog].d_inst.back();
       // add to explanation of exclusion
-      d_qe->getTermDatabaseSygus()->getExplain()->getExplanationForEquality(
-          cprog, sol, exp);
+      d_tds->getExplain()->getExplanationForEquality(cprog, sol, exp);
     }
   }
   Assert(!exp.empty());
@@ -817,7 +834,6 @@ void SynthConjecture::getSynthSolutions(std::map<Node, Node>& sol_map,
                                         bool singleInvocation)
 {
   NodeManager* nm = NodeManager::currentNM();
-  TermDbSygus* sygusDb = d_qe->getTermDatabaseSygus();
   std::vector<Node> sols;
   std::vector<int> statuses;
   if (!getSynthSolutionsInternal(sols, statuses, singleInvocation))
@@ -833,7 +849,7 @@ void SynthConjecture::getSynthSolutions(std::map<Node, Node>& sol_map,
     if (status != 0)
     {
       // convert sygus to builtin here
-      bsol = sygusDb->sygusToBuiltin(sol, sol.getType());
+      bsol = d_tds->sygusToBuiltin(sol, sol.getType());
     }
     // convert to lambda
     TypeNode tn = d_embed_quant[0][i].getType();
@@ -894,8 +910,7 @@ bool SynthConjecture::getSynthSolutionsInternal(std::vector<Node>& sols,
           {
             TNode templa = d_ceg_si->getTemplateArg(sf);
             // make the builtin version of the full solution
-            TermDbSygus* sygusDb = d_qe->getTermDatabaseSygus();
-            sol = sygusDb->sygusToBuiltin(sol, sol.getType());
+            sol = d_tds->sygusToBuiltin(sol, sol.getType());
             Trace("cegqi-inv") << "Builtin version of solution is : " << sol
                                << ", type : " << sol.getType() << std::endl;
             TNode tsol = sol;
index 53bc829cfe54be7ede0757339ea8fd07e4746f28..694e4a0cbe55c30ee7920dc39bce141e26c54980 100644 (file)
@@ -116,11 +116,15 @@ class SynthConjecture
    * Get model values for terms n, store in vector v. This method returns true
    * if and only if all values added to v are non-null.
    */
-  bool getModelValues(std::vector<Node>& n, std::vector<Node>& v);
+  bool getEnumeratedValues(std::vector<Node>& n, std::vector<Node>& v);
   /**
    * Get model value for term n. If n has a value that was excluded by
    * datatypes sygus symmetry breaking, this method returns null.
    */
+  Node getEnumeratedValue(Node n);
+  /**
+   * Get model value for term n.
+   */
   Node getModelValue(Node n);
 
   /** get utility for static preprocessing and analysis of conjectures */
@@ -138,6 +142,8 @@ class SynthConjecture
  private:
   /** reference to quantifier engine */
   QuantifiersEngine* d_qe;
+  /** term database sygus of d_qe */
+  TermDbSygus* d_tds;
   /** The feasible guard. */
   Node d_feasible_guard;
   /** the decision strategy for the feasible guard */
index 18e9619cbb88e712d5e5423cc30f9334f33e1d91..23b35cfed20002ea2176a99ea6f060037a560f4c 100644 (file)
@@ -646,6 +646,16 @@ bool TermDbSygus::isVariableAgnosticEnumerator(Node e) const
   return false;
 }
 
+bool TermDbSygus::isPassiveEnumerator(Node e) const
+{
+  if (isVariableAgnosticEnumerator(e))
+  {
+    return false;
+  }
+  // other criteria go here
+  return true;
+}
+
 void TermDbSygus::getEnumerators(std::vector<Node>& mts)
 {
   for (std::map<Node, SynthConjecture*>::iterator itm =
index 361c6bae04ede809bc56df427c86fa5bb2906aa6..785e8731ca7bd98c734be687c0cc037ceeaa4b12 100644 (file)
@@ -96,6 +96,24 @@ class TermDbSygus {
   bool usingSymbolicConsForEnumerator(Node e) const;
   /** is this enumerator agnostic to variables? */
   bool isVariableAgnosticEnumerator(Node e) const;
+  /** is this a "passively-generated" enumerator?
+   *
+   * A "passively-generated" enumerator is one for which the terms it enumerates
+   * are obtained by looking at its model value only. For passively-generated
+   * enumerators, it is the responsibility of the user of that enumerator (say
+   * a SygusModule) to block the current model value of it before asking for
+   * another value. By default, the Cegis module uses passively-generated
+   * enumerators and "conjecture-specific refinement" to rule out models
+   * for passively-generated enumerators.
+   *
+   * On the other hand, an "actively-generated" enumerator is one for which the
+   * terms it enumerates are not necessarily a subset of the model values the
+   * enumerator takes. Actively-generated enumerators are centrally managed by
+   * SynthConjecture. The user of actively-generated enumerators are prohibited
+   * from influencing its model value. For example, conjecture-specific
+   * refinement in Cegis is not applied to actively-generated enumerators.
+   */
+  bool isPassiveEnumerator(Node e) const;
   /** get all registered enumerators */
   void getEnumerators(std::vector<Node>& mts);
   /** Register symmetry breaking lemma