Improve the separation resolution scheme in cegis unif (#1931)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 16 May 2018 23:22:35 +0000 (18:22 -0500)
committerGitHub <noreply@github.com>
Wed, 16 May 2018 23:22:35 +0000 (18:22 -0500)
src/theory/quantifiers/sygus/cegis_unif.cpp
src/theory/quantifiers/sygus/sygus_pbe.cpp
src/theory/quantifiers/sygus/sygus_unif.cpp
src/theory/quantifiers/sygus/sygus_unif.h
src/theory/quantifiers/sygus/sygus_unif_io.cpp
src/theory/quantifiers/sygus/sygus_unif_io.h
src/theory/quantifiers/sygus/sygus_unif_rl.cpp
src/theory/quantifiers/sygus/sygus_unif_rl.h

index ac70a97b24b113f3b521141e2f2153a11d196add..ab98342545da5d7bc5cacdbc3e66ef97f41e3c59 100644 (file)
@@ -127,6 +127,7 @@ bool CegisUnif::constructCandidates(const std::vector<Node>& enums,
   // values
   NodeManager* nm = NodeManager::currentNM();
   bool addedUnifEnumSymBreakLemma = false;
+  Node cost_lit = d_u_enum_manager.getCurrentLiteral();
   std::map<Node, std::vector<Node>> unif_enums[2];
   std::map<Node, std::vector<Node>> unif_values[2];
   for (const Node& c : candidates)
@@ -157,40 +158,47 @@ bool CegisUnif::constructCandidates(const std::vector<Node>& enums,
           }
           unif_values[index][e].push_back(m_eu);
         }
-        // inter-enumerator symmetry breaking
-        // given a pool of unification enumerators eu_1, ..., eu_n,
-        // CegisUnifEnumManager insists that size(eu_1) <= ... <= size(eu_n).
-        // We additionally insist that M(eu_i) < M(eu_{i+1}) when
-        // size(eu_i) = size(eu_{i+1}), where < is pointer comparison.
-        // We enforce this below by adding symmetry breaking lemmas of the form
-        //  ~( eu_i = M(eu_i) ^ eu_{i+1} = M(eu_{i+1} ) )
-        // when applicable.
-        for (unsigned j = 1, nenum = unif_values[index][e].size(); j < nenum;
-             j++)
+        if (index == 0)
         {
-          Node prev_val = unif_values[index][e][j - 1];
-          Node curr_val = unif_values[index][e][j];
-          // compare the node values
-          if (curr_val < prev_val)
+          // inter-enumerator symmetry breaking
+          // given a pool of unification enumerators eu_1, ..., eu_n,
+          // CegisUnifEnumManager insists that size(eu_1) <= ... <= size(eu_n).
+          // We additionally insist that M(eu_i) < M(eu_{i+1}) when
+          // size(eu_i) = size(eu_{i+1}), where < is pointer comparison.
+          // We enforce this below by adding symmetry breaking lemmas of the
+          // form ~( eu_i = M(eu_i) ^ eu_{i+1} = M(eu_{i+1} ) )
+          // when applicable.
+          // we only do this for return value enumerators, since condition
+          // enumerators cannot be ordered (their order is based on the
+          // seperation resolution scheme during model construction).
+          for (unsigned j = 1, nenum = unif_values[index][e].size(); j < nenum;
+               j++)
           {
-            // must have the same size
-            unsigned prev_size = d_tds->getSygusTermSize(prev_val);
-            unsigned curr_size = d_tds->getSygusTermSize(curr_val);
-            Assert(prev_size <= curr_size);
-            if (curr_size == prev_size)
+            Node prev_val = unif_values[index][e][j - 1];
+            Node curr_val = unif_values[index][e][j];
+            // compare the node values
+            if (curr_val < prev_val)
             {
-              Node slem = nm->mkNode(AND,
-                                     unif_enums[index][e][j - 1].eqNode(
-                                         unif_values[index][e][j - 1]),
-                                     unif_enums[index][e][j].eqNode(
-                                         unif_values[index][e][j]))
-                              .negate();
-              Trace("cegis-unif") << "CegisUnif::lemma, inter-unif-enumerator "
-                                     "symmetry breaking lemma : "
-                                  << slem << "\n";
-              d_qe->getOutputChannel().lemma(slem);
-              addedUnifEnumSymBreakLemma = true;
-              break;
+              // must have the same size
+              unsigned prev_size = d_tds->getSygusTermSize(prev_val);
+              unsigned curr_size = d_tds->getSygusTermSize(curr_val);
+              Assert(prev_size <= curr_size);
+              if (curr_size == prev_size)
+              {
+                Node slem = nm->mkNode(AND,
+                                       unif_enums[index][e][j - 1].eqNode(
+                                           unif_values[index][e][j - 1]),
+                                       unif_enums[index][e][j].eqNode(
+                                           unif_values[index][e][j]))
+                                .negate();
+                Trace("cegis-unif")
+                    << "CegisUnif::lemma, inter-unif-enumerator "
+                       "symmetry breaking lemma : "
+                    << slem << "\n";
+                d_qe->getOutputChannel().lemma(slem);
+                addedUnifEnumSymBreakLemma = true;
+                break;
+              }
             }
           }
         }
@@ -206,12 +214,14 @@ bool CegisUnif::constructCandidates(const std::vector<Node>& enums,
   {
     for (const Node& e : d_cand_to_strat_pt[c])
     {
-      d_sygus_unif.setConditions(e, unif_values[1][e]);
+      d_sygus_unif.setConditions(
+          e, cost_lit, unif_enums[1][e], unif_values[1][e]);
     }
   }
   // build solutions (for unif candidates a divide-and-conquer approach is used)
   std::vector<Node> sols;
-  if (d_sygus_unif.constructSolution(sols))
+  std::vector<Node> lemmas;
+  if (d_sygus_unif.constructSolution(sols, lemmas))
   {
     candidate_values.insert(candidate_values.end(), sols.begin(), sols.end());
     if (Trace.isOn("cegis-unif"))
@@ -226,51 +236,13 @@ bool CegisUnif::constructCandidates(const std::vector<Node>& enums,
     }
     return true;
   }
-  std::map<Node, std::vector<Node>> sepPairs;
-  if (d_sygus_unif.getSeparationPairs(sepPairs))
+
+  Assert(!lemmas.empty());
+  for (const Node& lem : lemmas)
   {
-    // Build separation lemma based on current size, and for each heads that
-    // could not be separated, the condition values currently enumerated for its
-    // decision tree
-    Node neg_cost_lit = d_u_enum_manager.getCurrentLiteral().negate();
-    std::vector<Node> cenums, cond_eqs;
-    for (std::pair<const Node, std::vector<Node>>& np : sepPairs)
-    {
-      Node e = np.first;
-      // Build equalities between condition enumerators associated with the
-      // strategy point whose decision tree could not separate the given heads
-      std::vector<Node> cond_eqs;
-      std::map<Node, std::vector<Node>>::iterator itue = unif_enums[1].find(e);
-      Assert(itue != unif_enums[1].end());
-      std::map<Node, std::vector<Node>>::iterator ituv = unif_values[1].find(e);
-      Assert(ituv != unif_values[1].end());
-      Assert(itue->second.size() == ituv->second.size());
-      for (unsigned i = 0, size = itue->second.size(); i < size; i++)
-      {
-        cond_eqs.push_back(itue->second[i].eqNode(ituv->second[i]));
-      }
-      Assert(!cond_eqs.empty());
-      Node neg_conds_lit =
-          cond_eqs.size() > 1 ? nm->mkNode(AND, cond_eqs) : cond_eqs[0];
-      neg_conds_lit = neg_conds_lit.negate();
-      for (const Node& eq : np.second)
-      {
-        // A separation lemma is of the shape
-        //   (cost_n+1 ^ (c_1 = M(c_1) ^ ... ^ M(c_n))) => e_i = e_j
-        // in which cost_n+1 is the cost function for the size n+1, each c_k is
-        // a conditional enumerator associated with the respective decision
-        // tree, each M(c_k) its current model value, and e_i, e_j are two
-        // distinct heads that could not be separated by these condition values
-        //
-        // Such a lemma will force the ground solver, within the restrictions of
-        // the respective cost function, to make e_i and e_j equal or to come up
-        // with new values for the conditional enumerators such that we can try
-        Node sep_lemma = nm->mkNode(OR, neg_cost_lit, neg_conds_lit, eq);
-        Trace("cegis-unif")
-            << "CegisUnif::lemma, separation lemma : " << sep_lemma << "\n";
-        d_qe->getOutputChannel().lemma(sep_lemma);
-      }
-    }
+    Trace("cegis-unif") << "CegisUnif::lemma, separation lemma : " << lem
+                        << "\n";
+    d_qe->getOutputChannel().lemma(lem);
   }
   return false;
 }
@@ -478,7 +450,7 @@ void CegisUnifEnumManager::incrementNumEnumerators()
           d_qe->getOutputChannel().lemma(sym_break_red_ops);
         }
         // symmetry breaking between enumerators
-        if (!ci.second.d_enums[index].empty())
+        if (!ci.second.d_enums[index].empty() && index == 0)
         {
           Node e_prev = ci.second.d_enums[index].back();
           Node size_e = nm->mkNode(DT_SIZE, e);
index 0afd7a82cef920d461392e3c0f73db86a862688f..cd011ef4450eadf80099ee31601e61e5d9c6fada 100644 (file)
@@ -441,7 +441,7 @@ bool CegConjecturePbe::constructCandidates(const std::vector<Node>& enums,
     Node c = candidates[i];
     //build decision tree for candidate
     std::vector<Node> sol;
-    if (d_sygus_unif[c].constructSolution(sol))
+    if (d_sygus_unif[c].constructSolution(sol, lems))
     {
       Assert(sol.size() == 1);
       candidate_values.push_back(sol[0]);
index 15606c9a4ebea4a45458e6bee1cdbbeafb7c1859..76ca94e058c25d6bdba9f051e76da4b1d6969b1b 100644 (file)
@@ -42,7 +42,8 @@ void SygusUnif::initializeCandidate(
   d_strategy[f].initialize(qe, f, enums);
 }
 
-bool SygusUnif::constructSolution(std::vector<Node>& sols)
+bool SygusUnif::constructSolution(std::vector<Node>& sols,
+                                  std::vector<Node>& lemmas)
 {
   // initialize a call to construct solution
   initializeConstructSol();
@@ -52,7 +53,7 @@ bool SygusUnif::constructSolution(std::vector<Node>& sols)
     initializeConstructSolFor(f);
     // call the virtual construct solution method
     Node e = d_strategy[f].getRootEnumerator();
-    Node sol = constructSol(f, e, role_equal, 1);
+    Node sol = constructSol(f, e, role_equal, 1, lemmas);
     if (sol.isNull())
     {
       return false;
index 1c7972978b060d43c99f8598f45faf3b182def10..a19f8e41b177af1ae2a559e3cd96265f93381c32 100644 (file)
@@ -80,8 +80,12 @@ class SygusUnif
    * based on the current set of enumerated values. Returns null if it cannot
    * for some function (for example, if the set of enumerated values is
    * insufficient, or if a non-deterministic strategy aborts).
+   *
+   * This call may add lemmas to lemmas that should be sent out on an output
+   * channel by the caller.
    */
-  virtual bool constructSolution(std::vector<Node>& sols);
+  virtual bool constructSolution(std::vector<Node>& sols,
+                                 std::vector<Node>& lemmas);
 
  protected:
   /** reference to quantifier engine */
@@ -150,7 +154,8 @@ class SygusUnif
    *
    * ind is the term depth of the context (for debugging).
    */
-  virtual Node constructSol(Node f, Node e, NodeRole nrole, int ind) = 0;
+  virtual Node constructSol(
+      Node f, Node e, NodeRole nrole, int ind, std::vector<Node>& lemmas) = 0;
   /** Heuristically choose the best solved term from solved in context x,
    * currently return the first. */
   virtual Node constructBestSolvedTerm(const std::vector<Node>& solved);
index 1b43b77bac57205f5a43b607580a99cc568069ed..f9c59711872cd9c59adc2b95ff2ea51c97810bc1 100644 (file)
@@ -680,9 +680,10 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
   lemmas.push_back(exp_exc);
 }
 
-bool SygusUnifIo::constructSolution(std::vector<Node>& sols)
+bool SygusUnifIo::constructSolution(std::vector<Node>& sols,
+                                    std::vector<Node>& lemmas)
 {
-  Node sol = constructSolutionNode();
+  Node sol = constructSolutionNode(lemmas);
   if (!sol.isNull())
   {
     sols.push_back(sol);
@@ -691,7 +692,7 @@ bool SygusUnifIo::constructSolution(std::vector<Node>& sols)
   return false;
 }
 
-Node SygusUnifIo::constructSolutionNode()
+Node SygusUnifIo::constructSolutionNode(std::vector<Node>& lemmas)
 {
   Node c = d_candidate;
   if (!d_solution.isNull())
@@ -716,7 +717,7 @@ Node SygusUnifIo::constructSolutionNode()
       initializeConstructSolFor(c);
       // call the virtual construct solution method
       Node e = d_strategy[c].getRootEnumerator();
-      Node vcc = constructSol(c, e, role_equal, 1);
+      Node vcc = constructSol(c, e, role_equal, 1, lemmas);
       // if we constructed the solution, and we either did not previously have
       // a solution, or the new solution is better (smaller).
       if (!vcc.isNull()
@@ -854,7 +855,8 @@ void SygusUnifIo::initializeConstructSolFor(Node f)
   Assert(d_candidate == f);
 }
 
-Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
+Node SygusUnifIo::constructSol(
+    Node f, Node e, NodeRole nrole, int ind, std::vector<Node>& lemmas)
 {
   Assert(d_candidate == f);
   UnifContextIo& x = d_context;
@@ -1285,7 +1287,7 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
           }
           else
           {
-            rec_c = constructSol(f, cenum.first, cenum.second, ind + 2);
+            rec_c = constructSol(f, cenum.first, cenum.second, ind + 2, lemmas);
           }
 
           // undo update the context
index 9a6c0242109e97cefecce2fd34fe8782cd4c5243..a8e7fc011b78e0cc322cadbaa3a699cb9dcc8b4a 100644 (file)
@@ -285,7 +285,8 @@ class SygusUnifIo : public SygusUnif
   void notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas) override;
 
   /** Construct solution */
-  bool constructSolution(std::vector<Node>& sols) override;
+  bool constructSolution(std::vector<Node>& sols,
+                         std::vector<Node>& lemmas) override;
 
   /** add example
    *
@@ -377,7 +378,7 @@ class SygusUnifIo : public SygusUnif
    * constructSolution. If this returns a non-null node, then that term is a
    * solution for the function-to-synthesize in the overall conjecture.
    */
-  Node constructSolutionNode();
+  Node constructSolutionNode(std::vector<Node>& lemmas);
   /** domain-specific enumerator exclusion techniques
    *
    * Returns true if the value v for e can be excluded based on a
@@ -414,7 +415,11 @@ class SygusUnifIo : public SygusUnif
   /** initialize construct solution for */
   void initializeConstructSolFor(Node f) override;
   /** construct solution */
-  Node constructSol(Node f, Node e, NodeRole nrole, int ind) override;
+  Node constructSol(Node f,
+                    Node e,
+                    NodeRole nrole,
+                    int ind,
+                    std::vector<Node>& lemmas) override;
 };
 
 } /* CVC4::theory::quantifiers namespace */
index f7337a92de18b0e098e4e744a90f7d95bb7b2682..3fbb4b2b79cfb7652663cc7e4c5ea55d53c9a91a 100644 (file)
@@ -14,6 +14,8 @@
 
 #include "theory/quantifiers/sygus/sygus_unif_rl.h"
 
+#include "options/base_options.h"
+#include "printer/printer.h"
 #include "theory/quantifiers/sygus/ce_guided_conjecture.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
 
@@ -276,9 +278,10 @@ Node SygusUnifRl::addRefLemma(Node lemma,
   return plem;
 }
 
-void SygusUnifRl::initializeConstructSol() { d_sepPairs.clear(); }
+void SygusUnifRl::initializeConstructSol() {}
 void SygusUnifRl::initializeConstructSolFor(Node f) {}
-bool SygusUnifRl::constructSolution(std::vector<Node>& sols)
+bool SygusUnifRl::constructSolution(std::vector<Node>& sols,
+                                    std::vector<Node>& lemmas)
 {
   initializeConstructSol();
   bool successful = true;
@@ -291,7 +294,8 @@ bool SygusUnifRl::constructSolution(std::vector<Node>& sols)
       continue;
     }
     initializeConstructSolFor(c);
-    Node v = constructSol(c, d_strategy[c].getRootEnumerator(), role_equal, 0);
+    Node v = constructSol(
+        c, d_strategy[c].getRootEnumerator(), role_equal, 0, lemmas);
     if (v.isNull())
     {
       // we continue trying to build solutions to accumulate potentitial
@@ -305,7 +309,8 @@ bool SygusUnifRl::constructSolution(std::vector<Node>& sols)
   return successful;
 }
 
-Node SygusUnifRl::constructSol(Node f, Node e, NodeRole nrole, int ind)
+Node SygusUnifRl::constructSol(
+    Node f, Node e, NodeRole nrole, int ind, std::vector<Node>& lemmas)
 {
   indent("sygus-unif-sol", ind);
   Trace("sygus-unif-sol") << "ConstructSol: SygusRL : " << e << std::endl;
@@ -334,23 +339,14 @@ Node SygusUnifRl::constructSol(Node f, Node e, NodeRole nrole, int ind)
     return d_parent->getModelValue(e);
   }
   EnumTypeInfoStrat* etis = snode.d_strats[itd->second.getStrategyIndex()];
-  std::vector<Node> toSeparate;
-  Node sol = itd->second.buildSol(etis->d_cons, toSeparate);
+  Node sol = itd->second.buildSol(etis->d_cons, lemmas);
   if (sol.isNull())
   {
-    Assert(!toSeparate.empty());
-    d_sepPairs[e] = toSeparate;
+    Assert(!lemmas.empty());
   }
   return sol;
 }
 
-bool SygusUnifRl::getSeparationPairs(
-    std::map<Node, std::vector<Node>>& sepPairs)
-{
-  sepPairs = d_sepPairs;
-  return !sepPairs.empty();
-}
-
 bool SygusUnifRl::usingUnif(Node f) const
 {
   return d_unif_candidates.find(f) != d_unif_candidates.end();
@@ -363,12 +359,15 @@ Node SygusUnifRl::getConditionForEvaluationPoint(Node e) const
   return it->second.getConditionEnumerator();
 }
 
-void SygusUnifRl::setConditions(Node e, const std::vector<Node>& conds)
+void SygusUnifRl::setConditions(Node e,
+                                Node guard,
+                                const std::vector<Node>& enums,
+                                const std::vector<Node>& conds)
 {
   std::map<Node, DecisionTreeInfo>::iterator it = d_stratpt_to_dt.find(e);
   Assert(it != d_stratpt_to_dt.end());
-  // Clear previous trie
-  it->second.resetPointSeparator(conds);
+  // set the conditions for the appropriate tree
+  it->second.setConditions(guard, enums, conds);
 }
 
 std::vector<Node> SygusUnifRl::getEvalPointHeads(Node c)
@@ -487,27 +486,20 @@ void SygusUnifRl::DecisionTreeInfo::initialize(Node cond_enum,
   d_pt_sep.initialize(this);
 }
 
-void SygusUnifRl::DecisionTreeInfo::resetPointSeparator(
-    const std::vector<Node>& conds)
+void SygusUnifRl::DecisionTreeInfo::setConditions(
+    Node guard, const std::vector<Node>& enums, const std::vector<Node>& conds)
 {
-  // clear old condition values and trie
+  Assert(enums.size() == conds.size());
+  // set the guard
+  d_guard = guard;
+  // clear old condition values
+  d_enums.clear();
   d_conds.clear();
-  d_pt_sep.d_trie.clear();
   // set new condition values
+  d_enums.insert(d_enums.end(), enums.begin(), enums.end());
   d_conds.insert(d_conds.end(), conds.begin(), conds.end());
 }
 
-void SygusUnifRl::DecisionTreeInfo::addPoint(Node f)
-{
-  d_pt_sep.d_trie.add(f, &d_pt_sep, d_conds.size());
-}
-
-void SygusUnifRl::DecisionTreeInfo::addCondValue(Node condv)
-{
-  d_conds.push_back(condv);
-  d_pt_sep.d_trie.addClassifier(&d_pt_sep, d_conds.size() - 1);
-}
-
 unsigned SygusUnifRl::DecisionTreeInfo::getStrategyIndex() const
 {
   return d_strategy_index;
@@ -516,21 +508,234 @@ unsigned SygusUnifRl::DecisionTreeInfo::getStrategyIndex() const
 using UNodePair = std::pair<unsigned, Node>;
 
 Node SygusUnifRl::DecisionTreeInfo::buildSol(Node cons,
-                                             std::vector<Node>& toSeparate)
+                                             std::vector<Node>& lemmas)
 {
   if (!d_template.first.isNull())
   {
     Trace("sygus-unif-sol") << "...templated conditions unsupported\n";
     return Node::null();
   }
-  if (!isSeparated(toSeparate))
+  Trace("sygus-unif-sol") << "Decision::buildSol with " << d_hds.size()
+                          << " evaluation heads and " << d_conds.size()
+                          << " conditions..." << std::endl;
+  NodeManager* nm = NodeManager::currentNM();
+  // model values for evaluation heads
+  std::map<Node, Node> hd_mv;
+  // reset the trie
+  d_pt_sep.d_trie.clear();
+  // the current explanation of why there has not yet been a separation conflict
+  std::vector<Node> exp;
+  // is the above explanation ready to be sent out as a lemma?
+  bool exp_conflict = false;
+  // the index of the head we are considering
+  unsigned hd_counter = 0;
+  // the index of the condition we are considering
+  unsigned c_counter = 0;
+  // do we need to resolve a separation conflict?
+  bool needs_sep_resolve = false;
+  // This loop simultaneously builds the solution in terms of a lazy trie
+  // (LazyTrieMulti), and checks whether a separation conflict exists. We
+  // enforce that the separation conflicts we encounter while building
+  // this solution are resolved, in order, by the condition enumerators.
+  // If not, then we add a (conflict) lemma stating that the current model
+  // value of the condition enumerator must be different. We also call this
+  // a "separation lemma".
+  //
+  // As a simple example, say we have:
+  //   evalution heads: (eval e1 0 0), (eval e2 1 2)
+  //   conditions: c1
+  // where M(e1) = x, M(e2) = y, and M(c1) = x>1. After adding e1 and e2, we are
+  // in conflict since { e1, e2 } form a separation class, M(e1)!=M(e2), and
+  // M(c1) does not separate e1 and e2 since:
+  //   (x>1){x->0,y->0} = (x>1){x->1,y->2} = false
+  // Hence, we would fail to build a solution in this case, and instead send a
+  // separation lemma of the form:
+  //   ~( e1 != e2 ^ c1 = [x<1] )
+  //
+  // Say we have:
+  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
+  //   conditions: c1 c2
+  // where M(e1) = x, M(e2) = y, M(e3) = x+1, M(c1) = x>0 and M(c2) = x<0.
+  // After adding e1 and e2, { e1, e2 } form a separation class, M(e1)!=M(e2),
+  // but M(c1) separates e1 and e2 since
+  //   (x>0){x->0,y->0} = false, and
+  //   (x>1){x->1,y->2} = true
+  // Hence, we get new separation classes { e1 } and { e2 }, and afterwards
+  // add e3. We then get { e2, e3 } as a separation class, which is also a
+  // conflict since M(e2)!=M(e3). We check if M(c2) resolves this conflict.
+  // It does not, since (x<1){x->0,y->0} = (x<1){x->1,y->2} = false. Hence,
+  // we get a separation lemma:
+  //  ~( c1 = [x>1] ^ e2 != e3 ^ c2 = [x<1] )
+  //
+  // Say we have:
+  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
+  //   conditions: c1
+  // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = x>0.
+  // After adding e1 and e2, we have separation class { e1, e2 }. This is not a
+  // conflict since M(e1)=M(e2). We then add e3, obtaining separation class
+  // { e1, e2, e3 }, which is in conflict since M(e3)!=M(e1), and the condition
+  // c1 does not separate e3 and the representative of this class, e1. Hence we
+  // get a separation lemma of the form:
+  //  ~( e1 = e2 ^ e1 != e3 ^ c1 = [x>0] )
+  //
+  // It also may be the case that we exhaust the pool of condition enumerators.
+  // Say we have:
+  //   evalution heads: (eval e1 0 0), (eval e2 1 2), (eval e3 1 3)
+  //   conditions: c1
+  // where M(e1) = x, M(e2) = x, M(e3) = y, M(c1) = y>0. After adding e1, e2,
+  // and e3, we have a separation class { e1, e2, e3 } that is in conflict
+  // since M(e3)!=M(e1). We add the condition c1, which separates into new
+  // equivalence classes { e1 }, { e2, e3 }. We are still in separation conflict
+  // since M(e3)!=M(e2). However, we do not have any further conditions to use
+  // to resolve this conflict. Thus, we add the separation lemma:
+  //  ~( e1 = e2 ^ e1 != e3 ^ e2 != e3 ^ c1 = [y>0] ^ G_1 )
+  // where G_1 is a guard stating that we use at most 1 condition.
+  Node e;
+  Node er;
+  while (hd_counter < d_hds.size() || needs_sep_resolve)
+  {
+    if (!needs_sep_resolve)
+    {
+      // add the head to the trie
+      e = d_hds[hd_counter];
+      hd_mv[e] = d_unif->d_parent->getModelValue(e);
+      if (Trace.isOn("sygus-unif-sol"))
+      {
+        std::stringstream ss;
+        Printer::getPrinter(options::outputLanguage())
+            ->toStreamSygus(ss, hd_mv[e]);
+        Trace("sygus-unif-sol")
+            << "  add evaluation head (" << hd_counter << "/" << d_hds.size()
+            << "): " << e << " -> " << ss.str() << std::endl;
+      }
+      hd_counter++;
+      // get the representative of the trie
+      er = d_pt_sep.d_trie.add(e, &d_pt_sep, c_counter);
+      Trace("sygus-unif-sol") << "  ...separation class " << er << std::endl;
+      // are we in conflict?
+      if (er == e)
+      {
+        // new separation class, no conflict
+        continue;
+      }
+      Assert(hd_mv.find(er) != hd_mv.end());
+      if (hd_mv[er] == hd_mv[e])
+      {
+        // merged into separation class with same model value, no conflict
+        // add to explanation
+        // this states that it mattered that (er = e) at the time that e was
+        // added to the trie. Notice that er and e may become separated later,
+        // but to ensure the overall invariant, this equality must persist in
+        // the explanation.
+        exp.push_back(er.eqNode(e));
+        Trace("sygus-unif-sol") << "  ...equal model values " << std::endl;
+        Trace("sygus-unif-sol")
+            << "  ...add to explanation " << er.eqNode(e) << std::endl;
+        continue;
+      }
+    }
+    // must include in the explanation that we hit a conflict at this point in
+    // the construction
+    exp.push_back(e.eqNode(er).negate());
+    // we are in separation conflict, does the next condition resolve this?
+    // check whether we have have exhausted our condition pool. If so, we
+    // are in conflict and this conflict depends on the guard.
+    if (c_counter >= d_conds.size())
+    {
+      // truncated separation lemma
+      Assert(!d_guard.isNull());
+      exp.push_back(d_guard);
+      exp_conflict = true;
+      break;
+    }
+    Assert(c_counter < d_conds.size());
+    Node ce = d_enums[c_counter];
+    Node cv = d_conds[c_counter];
+    Assert(ce.getType() == cv.getType());
+    if (Trace.isOn("sygus-unif-sol"))
+    {
+      std::stringstream ss;
+      Printer::getPrinter(options::outputLanguage())->toStreamSygus(ss, cv);
+      Trace("sygus-unif-sol")
+          << "  add condition (" << c_counter << "/" << d_conds.size()
+          << "): " << ce << " -> " << ss.str() << std::endl;
+    }
+    // cache the separation class
+    std::vector<Node> prev_sep_c = d_pt_sep.d_trie.d_rep_to_class[er];
+    // add new classifier
+    d_pt_sep.d_trie.addClassifier(&d_pt_sep, c_counter);
+    c_counter++;
+    // add to explanation
+    // c_exp is a conjunction of testers applied to shared selector chains
+    Node c_exp = d_unif->d_tds->getExplain()->getExplanationForEquality(ce, cv);
+    exp.push_back(c_exp);
+    std::map<Node, std::vector<Node>>::iterator itr =
+        d_pt_sep.d_trie.d_rep_to_class.find(e);
+    // since e is last in its separation class, if it becomes a representative,
+    // then it is separated from all values in prev_sep_c
+    if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
+    {
+      Trace("sygus-unif-sol")
+          << "  ...resolves separation conflict with all" << std::endl;
+      needs_sep_resolve = false;
+      continue;
+    }
+    itr = d_pt_sep.d_trie.d_rep_to_class.find(er);
+    // since er is first in its separation class, it remains a representative
+    Assert(itr != d_pt_sep.d_trie.d_rep_to_class.end());
+    // is e still in the separation class of er?
+    if (std::find(itr->second.begin(), itr->second.end(), e)
+        != itr->second.end())
+    {
+      Trace("sygus-unif-sol")
+          << "  ...does not resolve separation conflict with current"
+          << std::endl;
+      // the condition does not separate e and er
+      // this violates the invariant that the i^th conditional enumerator
+      // resolves the i^th separation conflict
+      exp_conflict = true;
+      break;
+    }
+    Trace("sygus-unif-sol")
+        << "  ...resolves separation conflict with current, but not all"
+        << std::endl;
+    // find the new term to resolve a separation
+    Node new_er = Node::null();
+    // scan the previous list and find the representative of the class that e is
+    // now in
+    for (unsigned i = 0, size = prev_sep_c.size(); i < size; i++)
+    {
+      Node check_er = prev_sep_c[i];
+      if (check_er != er && check_er != e)
+      {
+        itr = d_pt_sep.d_trie.d_rep_to_class.find(check_er);
+        if (itr != d_pt_sep.d_trie.d_rep_to_class.end())
+        {
+          if (std::find(itr->second.begin(), itr->second.end(), e)
+              != itr->second.end())
+          {
+            new_er = check_er;
+            break;
+          }
+        }
+      }
+    }
+    // should find exactly one
+    Assert(!new_er.isNull());
+    er = new_er;
+    needs_sep_resolve = true;
+  }
+  if (exp_conflict)
   {
-    Trace("sygus-unif-sol") << "...separation check failed\n";
+    Node lemma = exp.size() == 1 ? exp[0] : nm->mkNode(AND, exp);
+    lemma = lemma.negate();
+    Trace("sygus-unif-sol") << "  ......conflict is " << lemma << std::endl;
+    lemmas.push_back(lemma);
     return Node::null();
   }
+
   Trace("sygus-unif-sol") << "...ready to build solution from DT\n";
   // Traverse trie and build ITE with cons
-  NodeManager* nm = NodeManager::currentNM();
   std::map<IndTriePair, Node> cache;
   std::map<IndTriePair, Node>::iterator it;
   std::vector<IndTriePair> visit;
@@ -551,8 +756,8 @@ Node SygusUnifRl::DecisionTreeInfo::buildSol(Node cons,
       // leaf
       if (trie->d_children.empty())
       {
-        Assert(d_hd_values.find(trie->d_lazy_child) != d_hd_values.end());
-        cache[cur] = d_hd_values[trie->d_lazy_child];
+        Assert(hd_mv.find(trie->d_lazy_child) != hd_mv.end());
+        cache[cur] = hd_mv[trie->d_lazy_child];
         Trace("sygus-unif-sol-debug")
             << "......leaf, build "
             << d_unif->d_tds->sygusToBuiltin(cache[cur], cache[cur].getType())
@@ -607,70 +812,6 @@ Node SygusUnifRl::DecisionTreeInfo::buildSol(Node cons,
   return cache[root];
 }
 
-bool SygusUnifRl::DecisionTreeInfo::isSeparated(std::vector<Node>& toSeparate)
-{
-  // build point separator
-  for (const Node& f : d_hds)
-  {
-    addPoint(f);
-  }
-  // check separation
-  d_hd_values.clear();
-  NodeManager* nm = NodeManager::currentNM();
-  for (const std::pair<const Node, std::vector<Node>>& rep_to_class :
-       d_pt_sep.d_trie.d_rep_to_class)
-  {
-    Assert(!rep_to_class.second.empty());
-    Node v = d_unif->d_parent->getModelValue(rep_to_class.second[0]);
-    if (Trace.isOn("sygus-unif-rl-dt-debug"))
-    {
-      Trace("sygus-unif-rl-dt-debug") << "...class of ("
-                                      << rep_to_class.second[0];
-      Assert(d_unif->d_hd_to_pt.find(rep_to_class.second[0])
-             != d_unif->d_hd_to_pt.end());
-      for (const Node& pti : d_unif->d_hd_to_pt[rep_to_class.second[0]])
-      {
-        Trace("sygus-unif-rl-dt-debug") << " " << pti << " ";
-      }
-      Trace("sygus-unif-rl-dt-debug") << ") "
-                                      << " with hd value " << v << "\n";
-    }
-    d_hd_values[rep_to_class.second[0]] = v;
-    unsigned i, size = rep_to_class.second.size();
-    for (i = 1; i < size; ++i)
-    {
-      Node vi = d_unif->d_parent->getModelValue(rep_to_class.second[i]);
-      Assert(d_hd_values.find(rep_to_class.second[i]) == d_hd_values.end());
-      d_hd_values[rep_to_class.second[i]] = vi;
-      if (Trace.isOn("sygus-unif-rl-dt-debug"))
-      {
-        Trace("sygus-unif-rl-dt-debug") << "...class of ("
-                                        << rep_to_class.second[i];
-        Assert(d_unif->d_hd_to_pt.find(rep_to_class.second[i])
-               != d_unif->d_hd_to_pt.end());
-        for (const Node& pti : d_unif->d_hd_to_pt[rep_to_class.second[i]])
-        {
-          Trace("sygus-unif-rl-dt-debug") << " " << pti << " ";
-        }
-        Trace("sygus-unif-rl-dt-debug") << ") "
-                                        << " with hd value " << vi << "\n";
-      }
-      // Heads with different model values
-      if (v != vi)
-      {
-        Trace("sygus-unif-rl-dt") << "...in sep class heads with diff values: "
-                                  << rep_to_class.second[0] << " and "
-                                  << rep_to_class.second[i] << "\n";
-        toSeparate.push_back(
-            nm->mkNode(EQUAL, rep_to_class.second[0], rep_to_class.second[i]));
-        // For non-separation suffices to consider one head pair per sep class
-        break;
-      }
-    }
-  }
-  return toSeparate.empty();
-}
-
 void SygusUnifRl::DecisionTreeInfo::PointSeparator::initialize(
     DecisionTreeInfo* dt)
 {
index 5bd6cdc1ef6e82f2df1a9e9e3a4fb2d791eddc91..8a5230d15da2b746051f6e2800379c792973428c 100644 (file)
@@ -59,7 +59,8 @@ class SygusUnifRl : public SygusUnif
   /** Notify enumeration (unused) */
   void notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas) override;
   /** Construct solution */
-  bool constructSolution(std::vector<Node>& sols) override;
+  bool constructSolution(std::vector<Node>& sols,
+                         std::vector<Node>& lemmas) override;
   /** add refinement lemma
    *
    * This adds a lemma to the specification. It returns the purified form
@@ -90,29 +91,29 @@ class SygusUnifRl : public SygusUnif
   /** set conditional enumerators
    *
    * This informs this class that the current set of conditions for evaluation
-   * point e is conds.
+   * point e are enumerated by "enums" and have values "conds"; "guard" is
+   * Boolean variable whose semantics correspond to "there is a solution using
+   * at most enums.size() conditions."
    */
-  void setConditions(Node e, const std::vector<Node>& conds);
+  void setConditions(Node e,
+                     Node guard,
+                     const std::vector<Node>& enums,
+                     const std::vector<Node>& conds);
 
   /** retrieve the head of evaluation points for candidate c, if any */
   std::vector<Node> getEvalPointHeads(Node c);
 
-  /**
-   * if a separation condition is necessary after a failed solution
-   * construction, then sepCond is assigned a pair (e, fi = fj) in which e is
-   * the strategy point and fi, fj head of evaluation points of a unif
-   * function-to-synthesize, such that fi could not be separated from fj by the
-   * current condition values
-   */
-  bool getSeparationPairs(std::map<Node, std::vector<Node>>& sepPairs);
-
  protected:
   /** reference to the parent conjecture */
   CegConjecture* d_parent;
   /* Functions-to-synthesize (a.k.a. candidates) with unification strategies */
   std::unordered_set<Node, NodeHashFunction> d_unif_candidates;
   /** construct sol */
-  Node constructSol(Node f, Node e, NodeRole nrole, int ind) override;
+  Node constructSol(Node f,
+                    Node e,
+                    NodeRole nrole,
+                    int ind,
+                    std::vector<Node>& lemmas) override;
   /** collects data from refinement lemmas to drive solution construction
    *
    * In particular it rebuilds d_app_to_pt whenever d_prev_rlemmas is different
@@ -123,12 +124,6 @@ class SygusUnifRl : public SygusUnif
   void initializeConstructSolFor(Node f) override;
   /** maps unif functions-to~synhesize to their last built solutions */
   std::map<Node, Node> d_cand_to_sol;
-  /** pair of strategy point and equality between evaluation point heads
-   *
-   * this pair is set when a unif solution cannot be built because a two
-   * evaluation point heads cannot be separated
-   */
-  std::map<Node, std::vector<Node>> d_sepPairs;
   /*
     --------------------------------------------------------------
         Purification
@@ -211,32 +206,39 @@ class SygusUnifRl : public SygusUnif
      *
      * The DT contains a solution when no class contains two heads of evaluation
      * points with different model values, i.e. when all points that must be
-     * separated indeed are separated.
-     */
-    Node buildSol(Node cons, std::vector<Node>& toSeparate);
-    /** whether all points that must be separated are separated
+     * separated indeed are separated by the current set of conditions.
      *
-     * This function tests separation of the points in the above sense and in
-     * case two heads cannot be separated, an equality between them is created
-     * and stored in toSeparate, so that a separation lemma can be generated to
-     * guide the synthesis search to yield either conditions that will separate
-     * these heads or equal values to them.
+     * This method either returns a solution (if all points are separated).
+     * It it fails, it adds a conflict lemma to lemmas.
      */
-    bool isSeparated(std::vector<Node>& toSeparate);
+    Node buildSol(Node cons, std::vector<Node>& lemmas);
     /** reference to parent unif util */
     SygusUnifRl* d_unif;
     /** enumerator template (if no templates, nodes in pair are Node::null()) */
     NodePair d_template;
-    /** enumerated condition values */
+    /** enumerated condition values, this is set by setConditions(...). */
     std::vector<Node> d_conds;
     /** gathered evaluation point heads */
     std::vector<Node> d_hds;
     /** get condition enumerator */
     Node getConditionEnumerator() const { return d_cond_enum; }
-    /** clear trie and registered condition values */
-    void resetPointSeparator(const std::vector<Node>& conds);
+    /** set conditions */
+    void setConditions(Node guard,
+                       const std::vector<Node>& enums,
+                       const std::vector<Node>& conds);
 
    private:
+    /**
+     * Conditional enumerator variables corresponding to the condition values in
+     * d_conds. These are used for generating separation lemmas during
+     * buildSol. This is set by setConditions(...).
+     */
+    std::vector<Node> d_enums;
+    /**
+     * The guard literal whose semantics is "we need at most d_enums.size()
+     * conditions in our solution. This is set by setConditions(...).
+     */
+    Node d_guard;
     /**
      * reference to inferred strategy for the function-to-synthesize this DT is
      * associated with
@@ -254,8 +256,6 @@ class SygusUnifRl : public SygusUnif
      * decision tree.
      */
     Node d_cond_enum;
-    /** chache of model values of heads of evaluation points */
-    NodePairMap d_hd_values;
     /** Classifies evaluation points according to enumerated condition values
      *
      * Maintains the invariant that points evaluated in the same way in the
@@ -284,10 +284,6 @@ class SygusUnifRl : public SygusUnif
      * enumerated condiotion values
      */
     PointSeparator d_pt_sep;
-    /** adds the respective evaluation point of the head f to d_pt_sep */
-    void addPoint(Node f);
-    /** adds a value to the pool of condition values and to d_pt_sep */
-    void addCondValue(Node condv);
   };
   /** maps strategy points in the strategy tree to the above data */
   std::map<Node, DecisionTreeInfo> d_stratpt_to_dt;