Make CegisUnif default to Cegis when no unif used (#1836)
authorHaniel Barbosa <hanielbbarbosa@gmail.com>
Thu, 3 May 2018 12:54:27 +0000 (07:54 -0500)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 3 May 2018 12:54:27 +0000 (07:54 -0500)
src/theory/quantifiers/sygus/cegis.cpp
src/theory/quantifiers/sygus/cegis.h
src/theory/quantifiers/sygus/cegis_unif.cpp
src/theory/quantifiers/sygus/cegis_unif.h
src/theory/quantifiers/sygus/sygus_unif_rl.cpp
src/theory/quantifiers/sygus/sygus_unif_rl.h

index ab448a2b877329519ae0061e0e0c8f79daaf6e20..f48955c9f37ce5d2b37b5204162e921da9b90d9c 100644 (file)
@@ -65,80 +65,79 @@ void Cegis::getTermList(const std::vector<Node>& candidates,
   enums.insert(enums.end(), candidates.begin(), candidates.end());
 }
 
-/** construct candidate */
-bool Cegis::constructCandidates(const std::vector<Node>& enums,
-                                const std::vector<Node>& enum_values,
-                                const std::vector<Node>& candidates,
-                                std::vector<Node>& candidate_values,
-                                std::vector<Node>& lems)
+bool Cegis::addEvalLemmas(const std::vector<Node>& candidates,
+                          const std::vector<Node>& candidate_values)
 {
-  candidate_values.insert(
-      candidate_values.end(), enum_values.begin(), enum_values.end());
-
-  if (options::sygusDirectEval())
+  if (!options::sygusDirectEval())
   {
-    NodeManager* nm = NodeManager::currentNM();
-    bool addedEvalLemmas = false;
-    if (options::sygusCRefEval())
+    return false;
+  }
+  NodeManager* nm = NodeManager::currentNM();
+  bool addedEvalLemmas = false;
+  if (options::sygusCRefEval())
+  {
+    Trace("cegqi-engine") << "  *** Do conjecture refinement evaluation..."
+                          << std::endl;
+    // see if any refinement lemma is refuted by evaluation
+    std::vector<Node> cre_lems;
+    getRefinementEvalLemmas(candidates, candidate_values, cre_lems);
+    if (!cre_lems.empty())
     {
-      Trace("cegqi-engine") << "  *** Do conjecture refinement evaluation..."
-                            << std::endl;
-      // see if any refinement lemma is refuted by evaluation
-      std::vector<Node> cre_lems;
-      getRefinementEvalLemmas(candidates, candidate_values, cre_lems);
-      if (!cre_lems.empty())
+      for (const Node& lem : cre_lems)
       {
-        for (unsigned j = 0; j < cre_lems.size(); j++)
+        if (d_qe->addLemma(lem))
         {
-          Node lem = cre_lems[j];
-          if (d_qe->addLemma(lem))
-          {
-            Trace("cegqi-lemma") << "Cegqi::Lemma : cref evaluation : " << lem
-                                 << std::endl;
-            addedEvalLemmas = true;
-          }
+          Trace("cegqi-lemma") << "Cegqi::Lemma : cref evaluation : " << lem
+                               << std::endl;
+          addedEvalLemmas = true;
         }
-        // we could, but do not return here.
-        // experimentally, it is better to add the lemmas below as well,
-        // in parallel.
       }
+      /* we could, but do not return here. experimentally, it is better to
+         add the lemmas below as well, in parallel. */
     }
-    Trace("cegqi-engine") << "  *** Do direct evaluation..." << std::endl;
-    std::vector<Node> eager_terms;
-    std::vector<Node> eager_vals;
-    std::vector<Node> eager_exps;
-    TermDbSygus* tds = d_qe->getTermDatabaseSygus();
-    for (unsigned j = 0, size = candidates.size(); j < size; j++)
-    {
-      Trace("cegqi-debug") << "  register " << candidates[j] << " -> "
-                           << candidate_values[j] << std::endl;
-      tds->registerModelValue(candidates[j],
-                              candidate_values[j],
-                              eager_terms,
-                              eager_vals,
-                              eager_exps);
-    }
-    Trace("cegqi-debug") << "...produced " << eager_terms.size()
-                         << " eager evaluation lemmas." << std::endl;
-
-    for (unsigned j = 0, size = eager_terms.size(); j < size; j++)
-    {
-      Node lem = nm->mkNode(kind::OR,
-                            eager_exps[j].negate(),
-                            eager_terms[j].eqNode(eager_vals[j]));
-      if (d_qe->addLemma(lem))
-      {
-        Trace("cegqi-lemma") << "Cegqi::Lemma : evaluation : " << lem
-                             << std::endl;
-        addedEvalLemmas = true;
-      }
-    }
-    if (addedEvalLemmas)
+  }
+  Trace("cegqi-engine") << "  *** Do direct evaluation..." << std::endl;
+  std::vector<Node> eager_terms, eager_vals, eager_exps;
+  TermDbSygus* tds = d_qe->getTermDatabaseSygus();
+  for (unsigned i = 0, size = candidates.size(); i < size; ++i)
+  {
+    Trace("cegqi-debug") << "  register " << candidates[i] << " -> "
+                         << candidate_values[i] << std::endl;
+    tds->registerModelValue(candidates[i],
+                            candidate_values[i],
+                            eager_terms,
+                            eager_vals,
+                            eager_exps);
+  }
+  Trace("cegqi-debug") << "...produced " << eager_terms.size()
+                       << " eager evaluation lemmas.\n";
+  for (unsigned i = 0, size = eager_terms.size(); i < size; ++i)
+  {
+    Node lem = nm->mkNode(
+        OR, eager_exps[i].negate(), eager_terms[i].eqNode(eager_vals[i]));
+    if (d_qe->addLemma(lem))
     {
-      return false;
+      Trace("cegqi-lemma") << "Cegqi::Lemma : evaluation : " << lem
+                           << std::endl;
+      addedEvalLemmas = true;
     }
   }
+  return addedEvalLemmas;
+}
 
+/** construct candidate */
+bool Cegis::constructCandidates(const std::vector<Node>& enums,
+                                const std::vector<Node>& enum_values,
+                                const std::vector<Node>& candidates,
+                                std::vector<Node>& candidate_values,
+                                std::vector<Node>& lems)
+{
+  if (addEvalLemmas(enums, enum_values))
+  {
+    return false;
+  }
+  candidate_values.insert(
+      candidate_values.end(), enum_values.begin(), enum_values.end());
   return true;
 }
 
index 358b505366ba786714a9271a0d704aee5e432b68..7500abb781c8872bfcff8f9ec4a3169f41ae4b67 100644 (file)
@@ -64,7 +64,7 @@ class Cegis : public SygusModule
                                        Node lem,
                                        std::vector<Node>& lems) override;
 
- private:
+ protected:
   /** If CegConjecture::d_base_inst is exists y. P( d, y ), then this is y. */
   std::vector<Node> d_base_vars;
   /**
@@ -94,6 +94,17 @@ class Cegis : public SygusModule
   bool sampleAddRefinementLemma(const std::vector<Node>& candidates,
                                 const std::vector<Node>& vals,
                                 std::vector<Node>& lems);
+
+  /** evaluates candidate values on current refinement lemmas
+   *
+   * Returns true if refinement lemmas are added after evaluation, false
+   * otherwise.
+   *
+   * Also eagerly unfolds evaluation functions in a heuristic manner, which is
+   * useful e.g. for boolean connectives
+   */
+  bool addEvalLemmas(const std::vector<Node>& candidates,
+                     const std::vector<Node>& candidate_values);
   //-----------------------------------end refinement lemmas
 
   /** Get refinement evaluation lemmas
index cbd9358e068e3d9af88b2b3e3fdf67c57c2ed82f..14a5bedf516a4dd64c9fb4c7a128eeb320935971 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "theory/quantifiers/sygus/cegis_unif.h"
 
+#include "options/quantifiers_options.h"
 #include "theory/quantifiers/sygus/ce_guided_conjecture.h"
 #include "theory/quantifiers/sygus/sygus_unif_rl.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
@@ -25,7 +26,7 @@ namespace theory {
 namespace quantifiers {
 
 CegisUnif::CegisUnif(QuantifiersEngine* qe, CegConjecture* p)
-    : SygusModule(qe, p)
+    : Cegis(qe, p), d_sygus_unif(p)
 {
   d_tds = d_qe->getTermDatabaseSygus();
 }
@@ -36,43 +37,43 @@ bool CegisUnif::initialize(Node n,
                            std::vector<Node>& lemmas)
 {
   Trace("cegis-unif") << "Initialize CegisUnif : " << n << std::endl;
-  Assert(candidates.size() > 0);
-  if (candidates.size() > 1)
+  /* Init UNIF util */
+  d_sygus_unif.initialize(d_qe, candidates, d_cond_enums, lemmas);
+  /* TODO initialize unif enumerators */
+  Trace("cegis-unif") << "Initializing enums for pure Cegis case\n";
+  /* Initialize enumerators for non-unif functions-to-synhesize */
+  for (const Node& c : candidates)
   {
-    return false;
-  }
-  d_candidate = candidates[0];
-  Trace("cegis-unif") << "Initialize unif utility for " << d_candidate
-                      << "...\n";
-  d_sygus_unif.initialize(d_qe, candidates, d_enums, lemmas);
-  Assert(!d_enums.empty());
-  Trace("cegis-unif") << "Initialize " << d_enums.size() << " enumerators for "
-                      << d_candidate << "...\n";
-  /* initialize the enumerators */
-  for (const Node& e : d_enums)
-  {
-    d_tds->registerEnumerator(e, d_candidate, d_parent, true);
-    Node g = d_tds->getActiveGuardForEnumerator(e);
-    d_enum_to_active_guard[e] = g;
+    if (!d_sygus_unif.usingUnif(c))
+    {
+      d_tds->registerEnumerator(c, c, d_parent);
+    }
   }
   return true;
 }
 
 void CegisUnif::getTermList(const std::vector<Node>& candidates,
-                            std::vector<Node>& terms)
+                            std::vector<Node>& enums)
 {
-  Assert(candidates.size() == 1);
-  Valuation& valuation = d_qe->getValuation();
-  for (const Node& e : d_enums)
+  for (const Node& c : candidates)
   {
-    Assert(d_enum_to_active_guard.find(e) != d_enum_to_active_guard.end());
-    Node g = d_enum_to_active_guard[e];
-    /* Get whether the active guard for this enumerator is if so, then there may
-       exist more values for it, and hence we add it to terms. */
-    Node gstatus = valuation.getSatValue(g);
-    if (!gstatus.isNull() && gstatus.getConst<bool>())
+    if (!d_sygus_unif.usingUnif(c))
+    {
+      enums.push_back(c);
+      continue;
+    }
+    Valuation& valuation = d_qe->getValuation();
+    for (const Node& e : d_cond_enums)
     {
-      terms.push_back(e);
+      Assert(d_enum_to_active_guard.find(e) != d_enum_to_active_guard.end());
+      Node g = d_enum_to_active_guard[e];
+      /* Get whether the active guard for this enumerator is set. If so, then
+         there may exist more values for it, and hence we add it to enums. */
+      Node gstatus = valuation.getSatValue(g);
+      if (!gstatus.isNull() && gstatus.getConst<bool>())
+      {
+        enums.push_back(e);
+      }
     }
   }
 }
@@ -83,16 +84,23 @@ bool CegisUnif::constructCandidates(const std::vector<Node>& enums,
                                     std::vector<Node>& candidate_values,
                                     std::vector<Node>& lems)
 {
-  Assert(enums.size() == enum_values.size());
-  if (enums.empty())
+  if (addEvalLemmas(enums, enum_values))
   {
+    Trace("cegis-unif-lemma") << "Added eval lemmas\n";
     return false;
   }
   unsigned min_term_size = 0;
   std::vector<unsigned> enum_consider;
+  NodeManager* nm = NodeManager::currentNM();
   Trace("cegis-unif-enum") << "Register new enumerated values :\n";
   for (unsigned i = 0, size = enums.size(); i < size; ++i)
   {
+    /* Only conditional enumerators will be notified to unif utility */
+    if (std::find(d_cond_enums.begin(), d_cond_enums.end(), enums[i])
+        == d_cond_enums.end())
+    {
+      continue;
+    }
     Trace("cegis-unif-enum") << "  " << enums[i] << " -> " << enum_values[i]
                              << std::endl;
     unsigned sz = d_tds->getSygusTermSize(enum_values[i]);
@@ -110,12 +118,10 @@ bool CegisUnif::constructCandidates(const std::vector<Node>& enums,
   /* only consider the enumerators that are at minimum size (for fairness) */
   Trace("cegis-unif-enum") << "...register " << enum_consider.size() << " / "
                            << enums.size() << std::endl;
-  NodeManager* nm = NodeManager::currentNM();
   for (unsigned i = 0, ecsize = enum_consider.size(); i < ecsize; ++i)
   {
     unsigned j = enum_consider[i];
-    Node e = enums[j];
-    Node v = enum_values[j];
+    Node e = enums[j], v = enum_values[j];
     std::vector<Node> enum_lems;
     d_sygus_unif.notifyEnumeration(e, v, enum_lems);
     /* the lemmas must be guarded by the active guard of the enumerator */
@@ -127,112 +133,29 @@ bool CegisUnif::constructCandidates(const std::vector<Node>& enums,
     }
     lems.insert(lems.end(), enum_lems.begin(), enum_lems.end());
   }
-  /* build candidate solution */
-  Assert(candidates.size() == 1);
-  if (d_sygus_unif.constructSolution(candidate_values))
+  /* divide-and-conquer solution bulding for candidates using unif util */
+  std::vector<Node> sols;
+  if (d_sygus_unif.constructSolution(sols))
   {
-    Node vc = candidate_values[0];
-    Trace("cegis-unif-enum") << "... candidate solution :" << vc << "\n";
+    candidate_values.insert(candidate_values.end(), sols.begin(), sols.end());
     return true;
   }
   return false;
 }
 
-Node CegisUnif::purifyLemma(Node n,
-                            bool ensureConst,
-                            std::vector<Node>& model_guards,
-                            BoolNodePairMap& cache)
-{
-  Trace("cegis-unif-purify") << "PurifyLemma : " << n << "\n";
-  BoolNodePairMap::const_iterator it = cache.find(BoolNodePair(ensureConst, n));
-  if (it != cache.end())
-  {
-    Trace("cegis-unif-purify") << "... already visited " << n << "\n";
-    return it->second;
-  }
-  /* Recurse */
-  unsigned size = n.getNumChildren();
-  Kind k = n.getKind();
-  bool fapp = k == APPLY_UF && size > 0 && n[0] == d_candidate;
-  bool childChanged = false;
-  std::vector<Node> children;
-  for (unsigned i = 0; i < size; ++i)
-  {
-    if (i == 0 && fapp)
-    {
-      children.push_back(n[0]);
-      continue;
-    }
-    Node child = purifyLemma(n[i], ensureConst || fapp, model_guards, cache);
-    children.push_back(child);
-    childChanged = childChanged || child != n[i];
-  }
-  Node nb;
-  if (childChanged)
-  {
-    if (n.hasOperator())
-    {
-      children.insert(children.begin(), n.getOperator());
-    }
-    nb = NodeManager::currentNM()->mkNode(k, children);
-    Trace("cegis-unif-purify") << "PurifyLemma : transformed " << n << " into "
-                               << nb << "\n";
-  }
-  else
-  {
-    nb = n;
-  }
-  /* get model value of non-top level applications of d_candidate */
-  if (ensureConst && fapp)
-  {
-    Node nv = d_parent->getModelValue(nb);
-    Trace("cegis-unif-purify") << "PurifyLemma : model value for " << nb
-                               << " is " << nv << "\n";
-    /* test if different, then update model_guards */
-    if (nv != nb)
-    {
-      Trace("cegis-unif-purify") << "PurifyLemma : adding model eq\n";
-      model_guards.push_back(
-          NodeManager::currentNM()->mkNode(EQUAL, nv, nb).negate());
-      nb = nv;
-    }
-  }
-  nb = Rewriter::rewrite(nb);
-  /* every non-top level application of function-to-synthesize must be reduced
-     to a concrete constant */
-  Assert(!ensureConst || nb.isConst());
-  Trace("cegis-unif-purify") << "... caching [" << n << "] = " << nb << "\n";
-  cache[BoolNodePair(ensureConst, n)] = nb;
-  return nb;
-}
-
 void CegisUnif::registerRefinementLemma(const std::vector<Node>& vars,
                                         Node lem,
                                         std::vector<Node>& lems)
 {
-  Node plem;
-  std::vector<Node> model_guards;
-  BoolNodePairMap cache;
-  Trace("cegis-unif") << "Registering lemma at CegisUnif : " << lem << "\n";
-  /* Make the purified lemma which will guide the unification utility. */
-  plem = purifyLemma(lem, false, model_guards, cache);
-  if (!model_guards.empty())
-  {
-    model_guards.push_back(plem);
-    plem = NodeManager::currentNM()->mkNode(OR, model_guards);
-  }
-  plem = Rewriter::rewrite(plem);
-  Trace("cegis-unif") << "Purified lemma : " << plem << "\n";
+  /* Notify lemma to unification utility and get its purified form */
+  Node plem = d_sygus_unif.addRefLemma(lem);
   d_refinement_lemmas.push_back(plem);
-  /* Notify lemma to unification utility */
-  d_sygus_unif.addRefLemma(plem);
   /* Make the refinement lemma and add it to lems. This lemma is guarded by the
      parent's guard, which has the semantics "this conjecture has a solution",
      hence this lemma states: if the parent conjecture has a solution, it
      satisfies the specification for the given concrete point. */
-  /* Store lemma for external modules */
-  lems.push_back(
-      NodeManager::currentNM()->mkNode(OR, d_parent->getGuard().negate(), lem));
+  lems.push_back(NodeManager::currentNM()->mkNode(
+      OR, d_parent->getGuard().negate(), plem));
 }
 
 CegisUnifEnumManager::CegisUnifEnumManager(QuantifiersEngine* qe,
index 3100d7d0d9b53a1a6e11c4b5976496044ea56782..ab2192ff85a58686091f0065e6f88a81f0981d39 100644 (file)
 #include <map>
 #include <vector>
 
-#include "theory/quantifiers/sygus/sygus_module.h"
+#include "theory/quantifiers/sygus/cegis.h"
 #include "theory/quantifiers/sygus/sygus_unif_rl.h"
 
 namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
-using BoolNodePair = std::pair<bool, Node>;
-using BoolNodePairHashFunction =
-    PairHashFunction<bool, Node, BoolHashFunction, NodeHashFunction>;
-using BoolNodePairMap =
-    std::unordered_map<BoolNodePair, Node, BoolNodePairHashFunction>;
-
 /** Synthesizes functions in a data-driven SyGuS approach
  *
  * Data is derived from refinement lemmas generated through the regular CEGIS
  * approach. SyGuS is used to generate terms for classifying the data
- * (e.g. using decision tree learning) and thus generate a candidate for a
- * function-to-synthesize.
+ * (e.g. using decision tree learning) and thus generate a candidates for
+ * functions-to-synthesize.
  *
  * This approach is inspired by the divide and conquer synthesis through
  * unification approach by Alur et al. TACAS 2017, by ICE-based invariant
  * synthesis from Garg et al. CAV 2014 and POPL 2016, and Padhi et al. PLDI 2016
  *
- * This module mantains a function-to-synthesize and a set of term
- * enumerators. When new terms are enumerated it tries to learn new candidate
- * function, which is verified outside this module. If verification fails a
+ * This module mantains a set of functions-to-synthesize and a set of term
+ * enumerators. When new terms are enumerated it tries to learn new candidate
+ * solutions, which are verified outside this module. If verification fails a
  * refinement lemma is generated, which this module sends to the utility that
  * learns candidates.
  */
-class CegisUnif : public SygusModule
+class CegisUnif : public Cegis
 {
  public:
   CegisUnif(QuantifiersEngine* qe, CegConjecture* p);
   ~CegisUnif();
-  /** initialize this class
-   *
-   * The module takes ownership of a conjecture when it contains a single
-   * function-to-synthesize
-  */
+  /** initialize this class */
   bool initialize(Node n,
                   const std::vector<Node>& candidates,
                   std::vector<Node>& lemmas) override;
-  /** adds the candidate itself to enums */
+  /** Retrieves enumerators for constructing solutions
+   *
+   * Non-unification candidates have themselves as enumerators, while for
+   * unification candidates we add their conditonal enumerators to enums if
+   * their respective guards are set in the current model
+   */
   void getTermList(const std::vector<Node>& candidates,
                    std::vector<Node>& enums) override;
-  /** Tries to build a new candidate solution with new enumerated expresion
+  /** Tries to build new candidate solutions with new enumerated expressions
    *
    * This function relies on a data-driven unification-based approach for
-   * constructing a solutions for the function-to-synthesize. See SygusUnifRl
-   * for more details.
+   * constructing solutions for the functions-to-synthesize. See SygusUnifRl for
+   * more details.
    *
    * Calls to this function are such that terms is the list of active
    * enumerators (returned by getTermList), and term_values are their current
@@ -93,7 +88,7 @@ class CegisUnif : public SygusModule
                            std::vector<Node>& candidate_values,
                            std::vector<Node>& lems) override;
 
-  /** Communicate refinement lemma to unification utility and external modules
+  /** Communicates refinement lemma to unification utility and external modules
    *
    * For the lemma to be sent to the external modules it adds a guard from the
    * parent conjecture which establishes that if the conjecture has a solution
@@ -124,32 +119,15 @@ class CegisUnif : public SygusModule
    * tree learning) that this module relies upon.
    */
   SygusUnifRl d_sygus_unif;
-  /* Function-to-synthesize (in deep embedding) */
-  Node d_candidate;
   /**
-   * list of enumerators being used to build solutions for candidate by the
-   * above utility.
+   * list of conditonal enumerators to build solutions for candidates being
+   * synthesized with unification techniques
    */
-  std::vector<Node> d_enums;
+  std::vector<Node> d_cond_enums;
   /** map from enumerators to active guards */
   std::map<Node, Node> d_enum_to_active_guard;
-  /* list of learned refinement lemmas */
-  std::vector<Node> d_refinement_lemmas;
-  /**
-  * This is called on the refinement lemma and will replace the arguments of the
-  * function-to-synthesize by their model values (constants).
-  *
-  * When the traversal hits a function application of the function-to-synthesize
-  * it will proceed to ensure that the arguments of that function application
-  * are constants (the ensureConst becomes "true"). It populates a vector of
-  * guards with the (negated) equalities between the original arguments and
-  * their model values.
-  */
-  Node purifyLemma(Node n,
-                   bool ensureConst,
-                   std::vector<Node>& model_guards,
-                   BoolNodePairMap& cache);
-
+  /* The null node */
+  Node d_null;
 }; /* class CegisUnif */
 
 /** Cegis Unif Enumeration Manager
index bf23cd0d1ec9d4c9d7b185ba22a43bc7cc416212..3b7cef4b9b5cb8ea16ef7983d14242ca430f57a3 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "theory/quantifiers/sygus/sygus_unif_rl.h"
 
+#include "theory/quantifiers/sygus/ce_guided_conjecture.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
 
 using namespace CVC4::kind;
@@ -22,21 +23,26 @@ namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
-SygusUnifRl::SygusUnifRl() {}
-
+SygusUnifRl::SygusUnifRl(CegConjecture* p) : d_parent(p) {}
 SygusUnifRl::~SygusUnifRl() {}
 void SygusUnifRl::initialize(QuantifiersEngine* qe,
                              const std::vector<Node>& funs,
                              std::vector<Node>& enums,
                              std::vector<Node>& lemmas)
 {
-  d_true = NodeManager::currentNM()->mkConst(true);
-  d_false = NodeManager::currentNM()->mkConst(false);
-  d_prev_rlemmas = d_true;
-  d_rlemmas = d_true;
-  d_hasRLemmas = false;
   d_ecache.clear();
+  d_cand_to_cond_enum.clear();
+  d_cand_to_pt_enum.clear();
+  /* TODO populate d_unif_candidates and remove lemmas cleaning */
   SygusUnif::initialize(qe, funs, enums, lemmas);
+  lemmas.clear();
+  /* Copy candidates and check whether CegisUnif for any of them */
+  for (const Node& c : d_unif_candidates)
+  {
+    d_app_to_pt[c].clear();
+    d_cand_to_pt_enum[c].clear();
+    d_purified_count[c] = 0;
+  }
 }
 
 void SygusUnifRl::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
@@ -52,95 +58,184 @@ void SygusUnifRl::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
   lemmas.push_back(exc_lemma);
 }
 
-void SygusUnifRl::addRefLemma(Node lemma)
-{
-  d_prev_rlemmas = d_rlemmas;
-  d_rlemmas = d_tds->getExtRewriter()->extendedRewrite(
-      NodeManager::currentNM()->mkNode(AND, d_rlemmas, lemma));
-  Trace("sygus-unif-rl-lemma")
-      << "SyGuSUnifRl: New collection of ref lemmas is " << d_rlemmas << "\n";
-  d_hasRLemmas = d_rlemmas != d_true;
-}
-
-void SygusUnifRl::collectPoints(Node n)
+Node SygusUnifRl::purifyLemma(Node n,
+                              bool ensureConst,
+                              std::vector<Node>& model_guards,
+                              BoolNodePairMap& cache)
 {
-  std::unordered_set<TNode, TNodeHashFunction> visited;
-  std::unordered_set<TNode, TNodeHashFunction>::iterator it;
-  std::vector<TNode> visit;
-  TNode cur;
-  visit.push_back(n);
-  do
+  Trace("sygus-unif-rl-purify") << "PurifyLemma : " << n << "\n";
+  BoolNodePairMap::const_iterator it = cache.find(BoolNodePair(ensureConst, n));
+  if (it != cache.end())
   {
-    cur = visit.back();
-    visit.pop_back();
-    if (visited.find(cur) != visited.end())
-    {
-      continue;
-    }
-    visited.insert(cur);
-    unsigned size = cur.getNumChildren();
-    if (cur.getKind() == APPLY_UF && size > 0)
+    Trace("sygus-unif-rl-purify-debug") << "... already visited " << n << "\n";
+    return it->second;
+  }
+  /* Recurse */
+  unsigned size = n.getNumChildren();
+  Kind k = n.getKind();
+  /* Whether application of a function-to-synthesize */
+  bool fapp = k == APPLY_UF && size > 0;
+  Assert(std::find(d_candidates.begin(), d_candidates.end(), n[0])
+         == d_candidates.end());
+  /* Whether application of a (non-)unification function-to-synthesize */
+  bool u_fapp = fapp && usingUnif(n[0]);
+  bool nu_fapp = fapp && !usingUnif(n[0]);
+  /* We retrive model value now because purified node may not have a value */
+  Node nv = n;
+  /* get model value of non-top level applications of functions-to-synthesize
+     occurring under a unification function-to-synthesize */
+  if (ensureConst && fapp)
+  {
+    nv = d_parent->getModelValue(n);
+    Assert(n != nv);
+    Trace("sygus-unif-rl-purify") << "PurifyLemma : model value for " << n
+                                  << " is " << nv << "\n";
+  }
+  /* Travese to purify */
+  bool childChanged = false;
+  std::vector<Node> children;
+  NodeManager* nm = NodeManager::currentNM();
+  for (unsigned i = 0; i < size; ++i)
+  {
+    if (i == 0 && fapp)
     {
-      std::vector<Node> pt;
-      for (unsigned i = 1; i < size; ++i)
-      {
-        Assert(cur[i].isConst());
-        pt.push_back(cur[i]);
-      }
-      d_app_to_pt[cur] = pt;
+      children.push_back(n[i]);
       continue;
     }
-    for (const TNode& child : cur)
+    /* Arguments of non-unif functions do not need to be constant */
+    Node child = purifyLemma(
+        n[i], !nu_fapp && (ensureConst || u_fapp), model_guards, cache);
+    children.push_back(child);
+    childChanged = childChanged || child != n[i];
+  }
+  Node nb;
+  if (childChanged)
+  {
+    if (n.hasOperator())
     {
-      visit.push_back(child);
+      children.insert(children.begin(), n.getOperator());
     }
-  } while (!visit.empty());
-}
-
-void SygusUnifRl::initializeConstructSol()
-{
-  if (d_hasRLemmas && d_rlemmas != d_prev_rlemmas)
+    nb = NodeManager::currentNM()->mkNode(k, children);
+    Trace("sygus-unif-rl-purify") << "PurifyLemma : transformed " << n
+                                  << " into " << nb << "\n";
+  }
+  else
   {
-    collectPoints(d_rlemmas);
-    if (Trace.isOn("sygus-unif-rl-sol"))
+    nb = n;
+  }
+  /* Map to point enumerator every unification function-to-synthesize  */
+  if (u_fapp)
+  {
+    Node np;
+    std::map<Node, Node>::const_iterator it = d_app_to_purified.find(nb);
+    if (it == d_app_to_purified.end())
     {
-      Trace("sygus-unif-rl-sol") << "SyGuSUnifRl: Points from " << d_rlemmas
-                                 << "\n";
-      for (const std::pair<Node, std::vector<Node>>& pair : d_app_to_pt)
+      if (!childChanged)
+      {
+        Assert(nb.hasOperator());
+        children.insert(children.begin(), n.getOperator());
+      }
+      /* Build purified head with fresh skolem and recreate node */
+      std::stringstream ss;
+      ss << nb[0] << "_" << d_purified_count[nb[0]]++;
+      Node new_f = nm->mkSkolem(ss.str(), nb[0].getType());
+      /* Adds new enumerator to map from candidate */
+      Trace("sygus-unif-rl-purify") << "...new enum " << new_f
+                                        << " for candidate " << nb[0] << "\n";
+      d_cand_to_pt_enum[nb[0]].push_back(new_f);
+      /* Maps new enumerator to its respective tuple of arguments */
+      d_app_to_pt[new_f] =
+          std::vector<Node>(children.begin() + 2, children.end());
+      if (Trace.isOn("sygus-unif-rl-purify"))
       {
-        Trace("sygus-unif-rl-sol") << "...[" << pair.first << "] --> (";
-        for (const Node& pt_i : pair.second)
+        Trace("sygus-unif-rl-purify") << "...[" << new_f << "] --> (";
+        for (const Node& pt_i : d_app_to_pt[new_f])
         {
-          Trace("sygus-unif-rl-sol") << pt_i << " ";
+          Trace("sygus-unif-rl-purify") << pt_i << " ";
         }
-        Trace("sygus-unif-rl-sol") << ")\n";
+        Trace("sygus-unif-rl-purify") << ")\n";
       }
+      /* replace first child and rebulid node */
+      children[1] = new_f;
+      np = NodeManager::currentNM()->mkNode(k, children);
+      d_app_to_purified[nb] = np;
     }
+    else
+    {
+      np = it->second;
+    }
+    Trace("sygus-unif-rl-purify")
+        << "PurifyLemma : purified head and transformed " << nb << " into "
+        << np << "\n";
+    nb = np;
   }
+  /* Add equality between purified fapp and model value */
+  if (ensureConst && fapp)
+  {
+    model_guards.push_back(
+        NodeManager::currentNM()->mkNode(EQUAL, nv, nb).negate());
+    nb = nv;
+    Trace("sygus-unif-rl-purify") << "PurifyLemma : adding model eq "
+                                  << model_guards.back() << "\n";
+  }
+  nb = Rewriter::rewrite(nb);
+  /* every non-top level application of function-to-synthesize must be reduced
+     to a concrete constant */
+  Assert(!ensureConst || nb.isConst());
+  Trace("sygus-unif-rl-purify-debug") << "... caching [" << n << "] = " << nb
+                                      << "\n";
+  cache[BoolNodePair(ensureConst, n)] = nb;
+  return nb;
 }
 
-void SygusUnifRl::initializeConstructSolFor(Node f) {}
-Node SygusUnifRl::constructSol(Node f, Node e, NodeRole nrole, int ind)
+Node SygusUnifRl::addRefLemma(Node lemma)
 {
-  Node solution = canCloseBranch(e);
-  if (!solution.isNull())
+  Trace("sygus-unif-rl-purify") << "Registering lemma at SygusUnif : " << lemma
+                               << "\n";
+  std::vector<Node> model_guards;
+  BoolNodePairMap cache;
+  /* Make the purified lemma which will guide the unification utility. */
+  Node plem = purifyLemma(lemma, false, model_guards, cache);
+  if (!model_guards.empty())
   {
-    return solution;
+    model_guards.push_back(plem);
+    plem = NodeManager::currentNM()->mkNode(OR, model_guards);
   }
-  return Node::null();
+  plem = Rewriter::rewrite(plem);
+  Trace("sygus-unif-rl-purify") << "Purified lemma : " << plem << "\n";
+  return plem;
 }
 
-Node SygusUnifRl::canCloseBranch(Node e)
+void SygusUnifRl::initializeConstructSol() {}
+void SygusUnifRl::initializeConstructSolFor(Node f) {}
+bool SygusUnifRl::constructSolution(std::vector<Node>& sols)
 {
-  if (!d_hasRLemmas && !d_ecache[e].d_enum_vals.empty())
+  for (const Node& c : d_candidates)
   {
-    Trace("sygus-unif-rl-sol") << "SyGuSUnifRl: Closed branch and yielded "
-                                  << d_ecache[e].d_enum_vals[0] << "\n";
-    return d_ecache[e].d_enum_vals[0];
+    if (!usingUnif(c))
+    {
+      Node v = d_parent->getModelValue(c);
+      Trace("sygus-unif-rl-sol") << "Adding solution " << v
+                                 << " to non-unif candidate " << c << "\n";
+      sols.push_back(v);
+    }
+    else
+    {
+      return false;
+    }
   }
+  return true;
+}
+
+Node SygusUnifRl::constructSol(Node f, Node e, NodeRole nrole, int ind)
+{
   return Node::null();
 }
 
+bool SygusUnifRl::usingUnif(Node f)
+{
+  return d_unif_candidates.find(f) != d_unif_candidates.end();
+}
 
 } /* CVC4::theory::quantifiers namespace */
 } /* CVC4::theory namespace */
index 13d0d1e566e43db725d0270a81544eede3c6e83b..dc1b14641a2058a0f1a9d1f49ee478b9eed06d27 100644 (file)
 #include <map>
 #include "theory/quantifiers/sygus/sygus_unif.h"
 
+#include "theory/quantifiers_engine.h"
+
 namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
+using BoolNodePair = std::pair<bool, Node>;
+using BoolNodePairHashFunction =
+    PairHashFunction<bool, Node, BoolHashFunction, NodeHashFunction>;
+using BoolNodePairMap =
+    std::unordered_map<BoolNodePair, Node, BoolNodePairHashFunction>;
+
+class CegConjecture;
+
 /** Sygus unification Refinement Lemmas utility
  *
  * This class implement synthesis-by-unification, where the specification is a
@@ -33,7 +43,7 @@ namespace quantifiers {
 class SygusUnifRl : public SygusUnif
 {
  public:
-  SygusUnifRl();
+  SygusUnifRl(CegConjecture* p);
   ~SygusUnifRl();
 
   /** initialize */
@@ -43,25 +53,28 @@ class SygusUnifRl : public SygusUnif
                   std::vector<Node>& lemmas) override;
   /** Notify enumeration */
   void notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas) override;
+  /** Construct solution */
+  bool constructSolution(std::vector<Node>& sols) override;
+  Node constructSol(Node f, Node e, NodeRole nrole, int ind) override;
   /** add refinement lemma
    *
    * This adds a lemma to the specification for f.
    */
-  void addRefLemma(Node lemma);
+  Node addRefLemma(Node lemma);
+  /**
+   * whether f is being synthesized with unification strategies. This can be
+   * checked through wehether f has conditional or point enumerators (we use the
+   * former)
+    */
+  bool usingUnif(Node f);
 
  protected:
-  /** true and false nodes */
-  Node d_true, d_false;
-  /** current collecton of refinement lemmas */
-  Node d_rlemmas;
-  /** previous collecton of refinement lemmas */
-  Node d_prev_rlemmas;
-  /** whether there are refinement lemmas to satisfy when building solutions */
-  bool d_hasRLemmas;
-  /**
-   * maps applications of the function-to-synthesize to their tuple of arguments
-   * (which constitute a "data point") */
-  std::map<Node, std::vector<Node>> d_app_to_pt;
+  /** reference to the parent conjecture */
+  CegConjecture* d_parent;
+  /* Functions-to-synthesize (a.k.a. candidates) with unification strategies */
+  std::unordered_set<Node, NodeHashFunction> d_unif_candidates;
+  /* Maps unif candidates to their conditonal enumerators */
+  std::map<Node, Node> d_cand_to_cond_enum;
   /**
    * This class stores information regarding an enumerator, including: a
    * database
@@ -78,9 +91,6 @@ class SygusUnifRl : public SygusUnif
   /** maps enumerators to the information above */
   std::map<Node, EnumCache> d_ecache;
 
-  /** Traverses n and populates d_app_to_pt */
-  void collectPoints(Node n);
-
   /** collects data from refinement lemmas to drive solution construction
    *
    * In particular it rebuilds d_app_to_pt whenever d_prev_rlemmas is different
@@ -89,15 +99,58 @@ class SygusUnifRl : public SygusUnif
   void initializeConstructSol() override;
   /** initialize construction solution for function-to-synthesize f */
   void initializeConstructSolFor(Node f) override;
+  /*
+    --------------------------------------------------------------
+        Purification
+    --------------------------------------------------------------
+  */
+  /* Maps unif candidates to their point enumerators */
+  std::map<Node, std::vector<Node>> d_cand_to_pt_enum;
+  /**
+   * maps applications of the function-to-synthesize to their tuple of arguments
+   * (which constitute a "data point") */
+  std::map<Node, std::vector<Node>> d_app_to_pt;
+  /** Maps applications of unif functions-to-synthesize to purified symbols*/
+  std::map<Node, Node> d_app_to_purified;
+  /** Maps unif functions-to-synthesize to counters of purified symbols */
+  std::map<Node, unsigned> d_purified_count;
   /**
-   * Returns a term covering all data points in the current branch, on null if
-   * none can be found among the currently enumerated values for the respective
-   * enumerator
+   * This is called on the refinement lemma and will rewrite applications of
+   * functions-to-synthesize to their respective purified form, i.e. such that
+   * all unification functions are applied over concrete values. Moreover
+   * unification functions are also rewritten such that every different tuple of
+   * arguments has a fresh function symbol applied to it.
+   *
+   * Non-unification functions are also equated to their model values when they
+   * occur as arguments of unification functions.
+   *
+   * A vector of guards with the (negated) equalities between the original
+   * arguments and their model values is populated accordingly.
+   *
+   * When the traversal encounters an application of a unification
+   * function-to-synthesize it will proceed to ensure that the arguments of that
+   * function application are constants (the ensureConst becomes "true"). If an
+   * applicatin of a non-unif function-to-synthesize is reached, the requirement
+   * is lifted (the ensureConst becomes "false"). This avoides introducing
+   * spurious equalities in model_guards.
+   *
+   * For example if "f" is being synthesized with a unification strategy and "g"
+   * is not then the application
+   *   f(g(f(g(0))))=1
+   * would be purified into
+   *   g(0) = c1 ^ g(f1(c1)) = c3 => f2(c3)
+   *
+   * Similarly
+   *   f(+(0,f(g(0))))
+   * would be purified into
+   *   g(0) = c1 ^ f1(c1) = c2 => f2(+(0,c2))
+   *
+   * This function also populates the maps for point enumerators
    */
-  Node canCloseBranch(Node e);
-
-  /** construct solution */
-  Node constructSol(Node f, Node e, NodeRole nrole, int ind) override;
+  Node purifyLemma(Node n,
+                   bool ensureConst,
+                   std::vector<Node>& model_guards,
+                   BoolNodePairMap& cache);
 };
 
 } /* CVC4::theory::quantifiers namespace */