Initial implementation of SygusUnifRL (#1829)
authorHaniel Barbosa <hanielbbarbosa@gmail.com>
Sat, 28 Apr 2018 13:11:09 +0000 (08:11 -0500)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 28 Apr 2018 13:11:09 +0000 (08:11 -0500)
src/theory/quantifiers/sygus/cegis_unif.cpp
src/theory/quantifiers/sygus/sygus_unif_rl.cpp
src/theory/quantifiers/sygus/sygus_unif_rl.h

index f7d970ddff357b8386616616e3d25235771b8745..7794ec9124140525759d5e3a97e8370a372ec635 100644 (file)
@@ -130,6 +130,7 @@ bool CegisUnif::constructCandidates(const std::vector<Node>& enums,
   /* build candidate solution */
   Assert(candidates.size() == 1);
   Node vc = d_sygus_unif.constructSolution();
+  Trace("cegis-unif-enum") << "... candidate solution :" << vc << "\n";
   if (vc.isNull())
   {
     return false;
index bac46997d363d3f365a95217a680d59b4469ff92..723210ca164c9b9a588f4c7c456cd9314498dbb7 100644 (file)
@@ -14,6 +14,8 @@
 
 #include "theory/quantifiers/sygus/sygus_unif_rl.h"
 
+#include "theory/quantifiers/sygus/term_database_sygus.h"
+
 using namespace CVC4::kind;
 
 namespace CVC4 {
@@ -29,19 +31,113 @@ void SygusUnifRl::initialize(QuantifiersEngine* qe,
                              std::vector<Node>& enums,
                              std::vector<Node>& lemmas)
 {
+  d_true = NodeManager::currentNM()->mkConst(true);
+  d_false = NodeManager::currentNM()->mkConst(false);
+  d_prev_rlemmas = d_true;
+  d_rlemmas = d_true;
+  d_hasRLemmas = false;
+  d_ecache.clear();
   SygusUnif::initialize(qe, f, enums, lemmas);
 }
 
 void SygusUnifRl::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
 {
+  Trace("sygus-unif-rl-notify") << "SyGuSUnifRl: Adding to enum " << e
+                                << " value " << v << "\n";
+  d_ecache[e].d_enum_vals.push_back(v);
+  /* Exclude v from next enumerations for e */
+  Node exc_lemma =
+      d_tds->getExplain()->getExplanationForEquality(e, v).negate();
+  Trace("sygus-unif-rl-notify")
+      << "SygusUnifRl : enumeration exclude lemma : " << exc_lemma << std::endl;
+  lemmas.push_back(exc_lemma);
+}
+
+void SygusUnifRl::addRefLemma(Node lemma)
+{
+  d_prev_rlemmas = d_rlemmas;
+  d_rlemmas = d_tds->getExtRewriter()->extendedRewrite(
+      NodeManager::currentNM()->mkNode(AND, d_rlemmas, lemma));
+  Trace("sygus-unif-rl-lemma")
+      << "SyGuSUnifRl: New collection of ref lemmas is " << d_rlemmas << "\n";
+  d_hasRLemmas = d_rlemmas != d_true;
 }
 
-void SygusUnifRl::addRefLemma(Node lemma) {}
+void SygusUnifRl::collectPoints(Node n)
+{
+  std::unordered_set<TNode, TNodeHashFunction> visited;
+  std::unordered_set<TNode, TNodeHashFunction>::iterator it;
+  std::vector<TNode> visit;
+  TNode cur;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    if (visited.find(cur) != visited.end())
+    {
+      continue;
+    }
+    visited.insert(cur);
+    unsigned size = cur.getNumChildren();
+    if (cur.getKind() == APPLY_UF && size > 0)
+    {
+      std::vector<Node> pt;
+      for (unsigned i = 1; i < size; ++i)
+      {
+        Assert(cur[i].isConst());
+        pt.push_back(cur[i]);
+      }
+      d_app_to_pt[cur] = pt;
+      continue;
+    }
+    for (const TNode& child : cur)
+    {
+      visit.push_back(child);
+    }
+  } while (!visit.empty());
+}
 
-void SygusUnifRl::initializeConstructSol() {}
+void SygusUnifRl::initializeConstructSol()
+{
+  if (d_hasRLemmas && d_rlemmas != d_prev_rlemmas)
+  {
+    collectPoints(d_rlemmas);
+    if (Trace.isOn("sygus-unif-rl-sol"))
+    {
+      Trace("sygus-unif-rl-sol") << "SyGuSUnifRl: Points from " << d_rlemmas
+                                 << "\n";
+      for (const std::pair<Node, std::vector<Node>>& pair : d_app_to_pt)
+      {
+        Trace("sygus-unif-rl-sol") << "...[" << pair.first << "] --> (";
+        for (const Node& pt_i : pair.second)
+        {
+          Trace("sygus-unif-rl-sol") << pt_i << " ";
+        }
+        Trace("sygus-unif-rl-sol") << ")\n";
+      }
+    }
+  }
+}
 
 Node SygusUnifRl::constructSol(Node e, NodeRole nrole, int ind)
 {
+  Node solution = canCloseBranch(e);
+  if (!solution.isNull())
+  {
+    return solution;
+  }
+  return Node::null();
+}
+
+Node SygusUnifRl::canCloseBranch(Node e)
+{
+  if (!d_hasRLemmas && !d_ecache[e].d_enum_vals.empty())
+  {
+    Trace("sygus-unif-rl-sol") << "SyGuSUnifRl: Closed branch and yielded "
+                                  << d_ecache[e].d_enum_vals[0] << "\n";
+    return d_ecache[e].d_enum_vals[0];
+  }
   return Node::null();
 }
 
index 8dc1906fbd8a9d9f1e9f94b4871ece0917fda195..0f3871056a1086447d521af928b0395042d80d8b 100644 (file)
@@ -50,10 +50,50 @@ class SygusUnifRl : public SygusUnif
   void addRefLemma(Node lemma);
 
  protected:
-  /** set of refinmente lemmas */
-  std::vector<Node> d_refLemmas;
-  /** initialize construct solution */
+  /** true and false nodes */
+  Node d_true, d_false;
+  /** current collecton of refinement lemmas */
+  Node d_rlemmas;
+  /** previous collecton of refinement lemmas */
+  Node d_prev_rlemmas;
+  /** whether there are refinement lemmas to satisfy when building solutions */
+  bool d_hasRLemmas;
+  /**
+   * maps applications of the function-to-synthesize to their tuple of arguments
+   * (which constitute a "data point") */
+  std::map<Node, std::vector<Node>> d_app_to_pt;
+  /**
+   * This class stores information regarding an enumerator, including: a
+   * database
+   * of values that have been enumerated for this enumerator.
+   */
+  class EnumCache
+  {
+   public:
+    EnumCache() {}
+    ~EnumCache() {}
+    /** Values that have been enumerated for this enumerator */
+    std::vector<Node> d_enum_vals;
+  };
+  /** maps enumerators to the information above */
+  std::map<Node, EnumCache> d_ecache;
+
+  /** Traverses n and populates d_app_to_pt */
+  void collectPoints(Node n);
+
+  /** collects data from refinement lemmas to drive solution construction
+   *
+   * In particular it rebuilds d_app_to_pt whenever d_prev_rlemmas is different
+   * from d_rlemmas, in which case we may have added or removed data points
+   */
   void initializeConstructSol() override;
+  /**
+   * Returns a term covering all data points in the current branch, on null if
+   * none can be found among the currently enumerated values for the respective
+   * enumerator
+   */
+  Node canCloseBranch(Node e);
+
   /** construct solution */
   Node constructSol(Node e, NodeRole nrole, int ind) override;
 };