Information gain heuristic for PBE (#2719)
[cvc5.git] / src / theory / quantifiers / sygus / sygus_unif_io.cpp
index 4fe3cfbed0a938797aa8de2e3f12c5fa04eb4fb5..6daeb1706759e4656132fa8b781a84b5ad1f4087 100644 (file)
@@ -21,6 +21,8 @@
 #include "theory/quantifiers/term_util.h"
 #include "util/random.h"
 
+#include <math.h>
+
 using namespace CVC4::kind;
 
 namespace CVC4 {
@@ -89,7 +91,6 @@ void UnifContextIo::initialize(SygusUnifIo* sui)
   d_str_pos.clear();
   d_curr_role = role_equal;
   d_visit_role.clear();
-  d_uinfo.clear();
 
   // initialize with #examples
   unsigned sz = sui->d_examples.size();
@@ -467,7 +468,10 @@ void SubsumeTrie::getLeaves(const std::vector<Node>& vals,
 }
 
 SygusUnifIo::SygusUnifIo()
-    : d_check_sol(false), d_cond_count(0), d_sol_cons_nondet(false)
+    : d_check_sol(false),
+      d_cond_count(0),
+      d_sol_cons_nondet(false),
+      d_solConsUsingInfoGain(false)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
@@ -495,6 +499,32 @@ void SygusUnifIo::addExample(const std::vector<Node>& input, Node output)
   d_examples_out.push_back(output);
 }
 
+void SygusUnifIo::computeExamples(Node e, Node bv, std::vector<Node>& exOut)
+{
+  std::map<Node, std::vector<Node>>& eoc = d_exOutCache[e];
+  std::map<Node, std::vector<Node>>::iterator it = eoc.find(bv);
+  if (it != eoc.end())
+  {
+    exOut.insert(exOut.end(), it->second.begin(), it->second.end());
+    return;
+  }
+  TypeNode xtn = e.getType();
+  std::vector<Node>& eocv = eoc[bv];
+  for (size_t j = 0, size = d_examples.size(); j < size; j++)
+  {
+    Node res = d_tds->evaluateBuiltin(xtn, bv, d_examples[j]);
+    exOut.push_back(res);
+    eocv.push_back(res);
+  }
+}
+
+void SygusUnifIo::clearExampleCache(Node e, Node bv)
+{
+  std::map<Node, std::vector<Node>>& eoc = d_exOutCache[e];
+  Assert(eoc.find(bv) != eoc.end());
+  eoc.erase(bv);
+}
+
 void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
 {
   Trace("sygus-sui-enum") << "Notify enumeration for " << e << " : " << v
@@ -508,17 +538,15 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
   // iterations.
   Node exp_exc;
 
+  std::vector<Node> base_results;
   TypeNode xtn = e.getType();
   Node bv = d_tds->sygusToBuiltin(v, xtn);
-  std::vector<Node> base_results;
-  // compte the results
-  for (unsigned j = 0, size = d_examples.size(); j < size; j++)
-  {
-    Node res = d_tds->evaluateBuiltin(xtn, bv, d_examples[j]);
-    Trace("sygus-sui-enum-debug")
-        << "...got res = " << res << " from " << bv << std::endl;
-    base_results.push_back(res);
-  }
+  bv = d_tds->getExtRewriter()->extendedRewrite(bv);
+  Trace("sygus-sui-enum") << "PBE Compute Examples for " << bv << std::endl;
+  // compte the results (should be cached)
+  computeExamples(e, bv, base_results);
+  // don't need it after this
+  clearExampleCache(e, bv);
   // get the results for each slave enumerator
   std::map<Node, std::vector<Node>> srmap;
   Evaluator* ev = d_tds->getEvaluator();
@@ -564,13 +592,15 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
   std::vector<Node> exp_exc_vec;
   Assert(d_tds->isEnumerator(e));
   bool isPassive = d_tds->isPassiveEnumerator(e);
-  if (isPassive
-      && getExplanationForEnumeratorExclude(e, v, base_results, exp_exc_vec))
+  if (getExplanationForEnumeratorExclude(e, v, base_results, exp_exc_vec))
   {
-    Assert(!exp_exc_vec.empty());
-    exp_exc = exp_exc_vec.size() == 1
-                  ? exp_exc_vec[0]
-                  : NodeManager::currentNM()->mkNode(AND, exp_exc_vec);
+    if (isPassive)
+    {
+      Assert(!exp_exc_vec.empty());
+      exp_exc = exp_exc_vec.size() == 1
+                    ? exp_exc_vec[0]
+                    : NodeManager::currentNM()->mkNode(AND, exp_exc_vec);
+    }
     Trace("sygus-sui-enum")
         << "  ...fail : term is excluded (domain-specific)" << std::endl;
   }
@@ -699,12 +729,12 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
           Trace("sygus-sui-enum")
               << "  ...fail : term is not unique" << std::endl;
         }
-        d_cond_count++;
       }
       if (keep)
       {
         // notify to retry the build of solution
         d_check_sol = true;
+        d_cond_count++;
         ecv.addEnumValue(v, results);
       }
     }
@@ -752,6 +782,7 @@ Node SygusUnifIo::constructSolutionNode(std::vector<Node>& lemmas)
     Trace("sygus-pbe") << "Construct solution, #iterations = " << d_cond_count
                        << std::endl;
     d_check_sol = false;
+    d_solConsUsingInfoGain = false;
     // try multiple times if we have done multiple conditions, due to
     // non-determinism
     unsigned sol_term_size = 0;
@@ -776,6 +807,14 @@ Node SygusUnifIo::constructSolutionNode(std::vector<Node>& lemmas)
         Trace("sygus-pbe") << "...solved at iteration " << i << std::endl;
         d_solution = vcc;
         sol_term_size = d_tds->getSygusTermSize(vcc);
+        // We've determined its feasible, now, enable information gain and
+        // retry. We do this since information gain comes with an overhead,
+        // and we want testing feasibility to be fast.
+        if (!d_solConsUsingInfoGain)
+        {
+          d_solConsUsingInfoGain = true;
+          i = 0;
+        }
       }
       else if (!d_sol_cons_nondet)
       {
@@ -965,7 +1004,7 @@ Node SygusUnifIo::constructSol(
         ecache.d_term_trie.getSubsumedBy(x.d_vals, true, subsumed_by);
         if (!subsumed_by.empty())
         {
-          ret_dt = constructBestSolvedTerm(subsumed_by);
+          ret_dt = constructBestSolvedTerm(e, subsumed_by);
           indent("sygus-sui-dt", ind);
           Trace("sygus-sui-dt") << "return PBE: success : conditionally solved"
                                 << d_tds->sygusToBuiltin(ret_dt) << std::endl;
@@ -1007,7 +1046,7 @@ Node SygusUnifIo::constructSol(
           }
           if (!str_solved.empty())
           {
-            ret_dt = constructBestStringSolvedTerm(str_solved);
+            ret_dt = constructBestSolvedTerm(e, str_solved);
             indent("sygus-sui-dt", ind);
             Trace("sygus-sui-dt") << "return PBE: success : string solved "
                                   << d_tds->sygusToBuiltin(ret_dt) << std::endl;
@@ -1143,7 +1182,8 @@ Node SygusUnifIo::constructSol(
   // try a random strategy
   if (snode.d_strats.size() > 1)
   {
-    std::random_shuffle(snode.d_strats.begin(), snode.d_strats.end());
+    std::shuffle(
+        snode.d_strats.begin(), snode.d_strats.end(), Random::getRandom());
   }
   // ITE always first if we have not yet solved
   // the reasoning is that splitting on conditions only subdivides the problem
@@ -1229,54 +1269,6 @@ Node SygusUnifIo::constructSol(
 
             EnumCache& ecache_child = d_ecache[ce];
 
-            // only used if the return value is not modified
-            if (!x.isReturnValueModified())
-            {
-              if (x.d_uinfo.find(ce) == x.d_uinfo.end())
-              {
-                x.d_uinfo[ce].clear();
-                Trace("sygus-sui-dt-debug2")
-                    << "  reg : PBE: Look for direct solutions for conditional "
-                       "enumerator "
-                    << ce << " ... " << std::endl;
-                Assert(ecache_child.d_enum_vals.size()
-                       == ecache_child.d_enum_vals_res.size());
-                for (unsigned i = 1; i <= 2; i++)
-                {
-                  std::pair<Node, NodeRole>& te_pair = etis->d_cenum[i];
-                  Node te = te_pair.first;
-                  EnumCache& ecache_te = d_ecache[te];
-                  bool branch_pol = (i == 1);
-                  // for each condition, get terms that satisfy it in this
-                  // branch
-                  for (unsigned k = 0, size = ecache_child.d_enum_vals.size();
-                       k < size;
-                       k++)
-                  {
-                    Node cond = ecache_child.d_enum_vals[k];
-                    std::vector<Node> solved;
-                    ecache_te.d_term_trie.getSubsumedBy(
-                        ecache_child.d_enum_vals_res[k], branch_pol, solved);
-                    Trace("sygus-sui-dt-debug2")
-                        << "  reg : PBE: " << d_tds->sygusToBuiltin(cond)
-                        << " has " << solved.size() << " solutions in branch "
-                        << i << std::endl;
-                    if (!solved.empty())
-                    {
-                      Node slv = constructBestSolvedTerm(solved);
-                      Trace("sygus-sui-dt-debug2")
-                          << "  reg : PBE: ..." << d_tds->sygusToBuiltin(slv)
-                          << " is a solution under branch " << i;
-                      Trace("sygus-sui-dt-debug2")
-                          << " of condition " << d_tds->sygusToBuiltin(cond)
-                          << std::endl;
-                      x.d_uinfo[ce].d_look_ahead_sols[cond][i] = slv;
-                    }
-                  }
-                }
-              }
-            }
-
             // get the conditionals in the current context : they must be
             // distinguishable
             std::map<int, std::vector<Node> > possible_cond;
@@ -1301,56 +1293,14 @@ Node SygusUnifIo::constructSol(
                 }
               }
 
-              // static look ahead conditional : choose conditionals that have
-              // solved terms in at least one branch
-              //    only applicable if we have not modified the return value
-              std::map<int, std::vector<Node> > solved_cond;
-              if (!x.isReturnValueModified() && !x.d_uinfo[ce].empty())
-              {
-                int solve_max = 0;
-                for (Node& cond : itpc->second)
-                {
-                  std::map<Node, std::map<unsigned, Node> >::iterator itla =
-                      x.d_uinfo[ce].d_look_ahead_sols.find(cond);
-                  if (itla != x.d_uinfo[ce].d_look_ahead_sols.end())
-                  {
-                    int nsolved = itla->second.size();
-                    solve_max = nsolved > solve_max ? nsolved : solve_max;
-                    solved_cond[nsolved].push_back(cond);
-                  }
-                }
-                int n = solve_max;
-                while (n > 0)
-                {
-                  if (!solved_cond[n].empty())
-                  {
-                    rec_c = constructBestSolvedConditional(solved_cond[n]);
-                    indent("sygus-sui-dt", ind + 1);
-                    Trace("sygus-sui-dt")
-                        << "PBE: ITE strategy : choose solved conditional "
-                        << d_tds->sygusToBuiltin(rec_c) << " with " << n
-                        << " solved children..." << std::endl;
-                    std::map<Node, std::map<unsigned, Node> >::iterator itla =
-                        x.d_uinfo[ce].d_look_ahead_sols.find(rec_c);
-                    Assert(itla != x.d_uinfo[ce].d_look_ahead_sols.end());
-                    for (std::pair<const unsigned, Node>& las : itla->second)
-                    {
-                      look_ahead_solved_children[las.first] = las.second;
-                    }
-                    break;
-                  }
-                  n--;
-                }
-              }
-
               // otherwise, guess a conditional
               if (rec_c.isNull())
               {
-                rec_c = constructBestConditional(itpc->second);
+                rec_c = constructBestConditional(ce, itpc->second);
                 Assert(!rec_c.isNull());
                 indent("sygus-sui-dt", ind);
                 Trace("sygus-sui-dt")
-                    << "PBE: ITE strategy : choose random conditional "
+                    << "PBE: ITE strategy : choose best conditional "
                     << d_tds->sygusToBuiltin(rec_c) << std::endl;
               }
             }
@@ -1427,6 +1377,100 @@ Node SygusUnifIo::constructSol(
   return ret_dt;
 }
 
+Node SygusUnifIo::constructBestConditional(Node ce,
+                                           const std::vector<Node>& conds)
+{
+  if (!d_solConsUsingInfoGain)
+  {
+    return SygusUnif::constructBestConditional(ce, conds);
+  }
+  UnifContextIo& x = d_context;
+  // use information gain heuristic
+  Trace("sygus-sui-dt-igain") << "Best information gain in context ";
+  print_val("sygus-sui-dt-igain", x.d_vals);
+  Trace("sygus-sui-dt-igain") << std::endl;
+  // set of indices that are active in this branch, i.e. x.d_vals[i] is true
+  std::vector<unsigned> activeIndices;
+  // map (j,t,s) -> n, such that the j^th condition in the vector conds
+  // evaluates to t (typically true/false) on n active I/O pairs with output s.
+  std::map<unsigned, std::map<Node, std::map<Node, unsigned>>> eval;
+  // map (j,t) -> m, such that the j^th condition in the vector conds
+  // evaluates to t (typically true/false) for m active I/O pairs.
+  std::map<unsigned, std::map<Node, unsigned>> evalCount;
+  unsigned nconds = conds.size();
+  EnumCache& ecache = d_ecache[ce];
+  // Get the index of conds[j] in the enumerator cache, this is to look up
+  // its evaluation on each point.
+  std::vector<unsigned> eindex;
+  for (unsigned j = 0; j < nconds; j++)
+  {
+    eindex.push_back(ecache.d_enum_val_to_index[conds[j]]);
+  }
+  unsigned activePoints = 0;
+  for (unsigned i = 0, npoints = x.d_vals.size(); i < npoints; i++)
+  {
+    if (x.d_vals[i].getConst<bool>())
+    {
+      activePoints++;
+      Node eo = d_examples_out[i];
+      for (unsigned j = 0; j < nconds; j++)
+      {
+        Node resn = ecache.d_enum_vals_res[eindex[j]][i];
+        Assert(resn.isConst());
+        eval[j][resn][eo]++;
+        evalCount[j][resn]++;
+      }
+    }
+  }
+  AlwaysAssert(activePoints > 0);
+  // find the condition that leads to the lowest entropy
+  // initially set minEntropy to > 1.0.
+  double minEntropy = 2.0;
+  unsigned bestIndex = 0;
+  for (unsigned j = 0; j < nconds; j++)
+  {
+    // To compute the entropy for a condition C, for pair of terms (s, t), let
+    //   prob(t) be the probability C evaluates to t on an active point,
+    //   prob(s|t) be the probability that an active point on which C
+    //     evaluates to t has output s.
+    // Then, the entropy of C is:
+    //   sum{t}. prob(t)*( sum{s}. -prob(s|t)*log2(prob(s|t)) )
+    // where notice this is always between 0 and 1.
+    double entropySum = 0.0;
+    Trace("sygus-sui-dt-igain") << j << " : ";
+    for (std::pair<const Node, std::map<Node, unsigned>>& ej : eval[j])
+    {
+      unsigned ecount = evalCount[j][ej.first];
+      if (ecount > 0)
+      {
+        double probBranch = double(ecount) / double(activePoints);
+        Trace("sygus-sui-dt-igain") << ej.first << " -> ( ";
+        for (std::pair<const Node, unsigned>& eej : ej.second)
+        {
+          if (eej.second > 0)
+          {
+            double probVal = double(eej.second) / double(ecount);
+            Trace("sygus-sui-dt-igain")
+                << eej.first << ":" << eej.second << " ";
+            double factor = -probVal * log2(probVal);
+            entropySum += probBranch * factor;
+          }
+        }
+        Trace("sygus-sui-dt-igain") << ") ";
+      }
+    }
+    Trace("sygus-sui-dt-igain") << "..." << entropySum << std::endl;
+    if (entropySum < minEntropy)
+    {
+      minEntropy = entropySum;
+      bestIndex = j;
+    }
+  }
+
+  Assert(!conds.empty());
+  return conds[bestIndex];
+}
+
 } /* CVC4::theory::quantifiers namespace */
 } /* CVC4::theory namespace */
 } /* CVC4 namespace */