Information gain heuristic for PBE (#2719)
[cvc5.git] / src / theory / quantifiers / sygus / sygus_unif_io.cpp
index 8f2038d318c4b533c1eaf2c03a6fce1510428719..6daeb1706759e4656132fa8b781a84b5ad1f4087 100644 (file)
@@ -2,9 +2,9 @@
 /*! \file sygus_unif_io.cpp
  ** \verbatim
  ** Top contributors (to current version):
- **   Andrew Reynolds
+ **   Andrew Reynolds, Haniel Barbosa
  ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
  ** in the top-level source directory) and their institutional affiliations.
  ** All rights reserved.  See the file COPYING in the top-level source
  ** directory for licensing information.\endverbatim
 
 #include "theory/quantifiers/sygus/sygus_unif_io.h"
 
+#include "options/quantifiers_options.h"
 #include "theory/datatypes/datatypes_rewriter.h"
+#include "theory/evaluator.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
 #include "theory/quantifiers/term_util.h"
 #include "util/random.h"
 
+#include <math.h>
+
 using namespace CVC4::kind;
 
 namespace CVC4 {
@@ -87,7 +91,6 @@ void UnifContextIo::initialize(SygusUnifIo* sui)
   d_str_pos.clear();
   d_curr_role = role_equal;
   d_visit_role.clear();
-  d_uinfo.clear();
 
   // initialize with #examples
   unsigned sz = sui->d_examples.size();
@@ -175,9 +178,14 @@ bool UnifContextIo::getStringIncrement(SygusUnifIo* sui,
         Trace("sygus-sui-dt-debug") << "X";
         return false;
       }
+      Trace("sygus-sui-dt-debug") << ival;
+      tot += ival;
+    }
+    else
+    {
+      // inactive in this context
+      Trace("sygus-sui-dt-debug") << "-";
     }
-    Trace("sygus-sui-dt-debug") << ival;
-    tot += ival;
     inc.push_back(ival);
   }
   return true;
@@ -459,26 +467,30 @@ void SubsumeTrie::getLeaves(const std::vector<Node>& vals,
   getLeavesInternal(vals, pol, v, 0, -2);
 }
 
-SygusUnifIo::SygusUnifIo() : d_check_sol(false), d_cond_count(0)
+SygusUnifIo::SygusUnifIo()
+    : d_check_sol(false),
+      d_cond_count(0),
+      d_sol_cons_nondet(false),
+      d_solConsUsingInfoGain(false)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
 }
 
 SygusUnifIo::~SygusUnifIo() {}
-void SygusUnifIo::initialize(QuantifiersEngine* qe,
-                             const std::vector<Node>& funs,
-                             std::vector<Node>& enums,
-                             std::vector<Node>& lemmas)
+void SygusUnifIo::initializeCandidate(
+    QuantifiersEngine* qe,
+    Node f,
+    std::vector<Node>& enums,
+    std::map<Node, std::vector<Node>>& strategy_lemmas)
 {
-  Assert(funs.size() == 1);
   d_examples.clear();
   d_examples_out.clear();
   d_ecache.clear();
-  d_candidate = funs[0];
-  SygusUnif::initialize(qe, funs, enums, lemmas);
+  d_candidate = f;
+  SygusUnif::initializeCandidate(qe, f, enums, strategy_lemmas);
   // learn redundant operators based on the strategy
-  d_strategy[d_candidate].staticLearnRedundantOps(lemmas);
+  d_strategy[f].staticLearnRedundantOps(strategy_lemmas);
 }
 
 void SygusUnifIo::addExample(const std::vector<Node>& input, Node output)
@@ -487,6 +499,32 @@ void SygusUnifIo::addExample(const std::vector<Node>& input, Node output)
   d_examples_out.push_back(output);
 }
 
+void SygusUnifIo::computeExamples(Node e, Node bv, std::vector<Node>& exOut)
+{
+  std::map<Node, std::vector<Node>>& eoc = d_exOutCache[e];
+  std::map<Node, std::vector<Node>>::iterator it = eoc.find(bv);
+  if (it != eoc.end())
+  {
+    exOut.insert(exOut.end(), it->second.begin(), it->second.end());
+    return;
+  }
+  TypeNode xtn = e.getType();
+  std::vector<Node>& eocv = eoc[bv];
+  for (size_t j = 0, size = d_examples.size(); j < size; j++)
+  {
+    Node res = d_tds->evaluateBuiltin(xtn, bv, d_examples[j]);
+    exOut.push_back(res);
+    eocv.push_back(res);
+  }
+}
+
+void SygusUnifIo::clearExampleCache(Node e, Node bv)
+{
+  std::map<Node, std::vector<Node>>& eoc = d_exOutCache[e];
+  Assert(eoc.find(bv) != eoc.end());
+  eoc.erase(bv);
+}
+
 void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
 {
   Trace("sygus-sui-enum") << "Notify enumeration for " << e << " : " << v
@@ -500,25 +538,69 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
   // iterations.
   Node exp_exc;
 
+  std::vector<Node> base_results;
   TypeNode xtn = e.getType();
   Node bv = d_tds->sygusToBuiltin(v, xtn);
-  std::vector<Node> base_results;
-  // compte the results
-  for (unsigned j = 0, size = d_examples.size(); j < size; j++)
+  bv = d_tds->getExtRewriter()->extendedRewrite(bv);
+  Trace("sygus-sui-enum") << "PBE Compute Examples for " << bv << std::endl;
+  // compte the results (should be cached)
+  computeExamples(e, bv, base_results);
+  // don't need it after this
+  clearExampleCache(e, bv);
+  // get the results for each slave enumerator
+  std::map<Node, std::vector<Node>> srmap;
+  Evaluator* ev = d_tds->getEvaluator();
+  bool tryEval = options::sygusEvalOpt();
+  for (const Node& xs : ei.d_enum_slave)
   {
-    Node res = d_tds->evaluateBuiltin(xtn, bv, d_examples[j]);
-    Trace("sygus-sui-enum-debug")
-        << "...got res = " << res << " from " << bv << std::endl;
-    base_results.push_back(res);
+    Assert(srmap.find(xs) == srmap.end());
+    EnumInfo& eiv = d_strategy[c].getEnumInfo(xs);
+    Node templ = eiv.d_template;
+    if (!templ.isNull())
+    {
+      TNode templ_var = eiv.d_template_arg;
+      std::vector<Node> args;
+      args.push_back(templ_var);
+      std::vector<Node> sresults;
+      for (const Node& res : base_results)
+      {
+        TNode tres = res;
+        std::vector<Node> vals;
+        vals.push_back(tres);
+        Node sres;
+        if (tryEval)
+        {
+          sres = ev->eval(templ, args, vals);
+        }
+        if (sres.isNull())
+        {
+          // fall back on rewriter
+          sres = templ.substitute(templ_var, tres);
+          sres = Rewriter::rewrite(sres);
+        }
+        sresults.push_back(sres);
+      }
+      srmap[xs] = sresults;
+    }
+    else
+    {
+      srmap[xs] = base_results;
+    }
   }
+
   // is it excluded for domain-specific reason?
   std::vector<Node> exp_exc_vec;
+  Assert(d_tds->isEnumerator(e));
+  bool isPassive = d_tds->isPassiveEnumerator(e);
   if (getExplanationForEnumeratorExclude(e, v, base_results, exp_exc_vec))
   {
-    Assert(!exp_exc_vec.empty());
-    exp_exc = exp_exc_vec.size() == 1
-                  ? exp_exc_vec[0]
-                  : NodeManager::currentNM()->mkNode(AND, exp_exc_vec);
+    if (isPassive)
+    {
+      Assert(!exp_exc_vec.empty());
+      exp_exc = exp_exc_vec.size() == 1
+                    ? exp_exc_vec[0]
+                    : NodeManager::currentNM()->mkNode(AND, exp_exc_vec);
+    }
     Trace("sygus-sui-enum")
         << "  ...fail : term is excluded (domain-specific)" << std::endl;
   }
@@ -530,11 +612,8 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
     for (unsigned s = 0; s < ei.d_enum_slave.size(); s++)
     {
       Node xs = ei.d_enum_slave[s];
-
       EnumInfo& eiv = d_strategy[c].getEnumInfo(xs);
-
       EnumCache& ecv = d_ecache[xs];
-
       Trace("sygus-sui-enum") << "Process " << xs << " from " << s << std::endl;
       // bool prevIsCover = false;
       if (eiv.getRole() == enum_io)
@@ -549,20 +628,13 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
       Trace("sygus-sui-enum") << xs << " : ";
       // evaluate all input/output examples
       std::vector<Node> results;
-      Node templ = eiv.d_template;
-      TNode templ_var = eiv.d_template_arg;
       std::map<Node, bool> cond_vals;
-      for (unsigned j = 0, size = base_results.size(); j < size; j++)
+      std::map<Node, std::vector<Node>>::iterator itsr = srmap.find(xs);
+      Assert(itsr != srmap.end());
+      for (unsigned j = 0, size = itsr->second.size(); j < size; j++)
       {
-        Node res = base_results[j];
+        Node res = itsr->second[j];
         Assert(res.isConst());
-        if (!templ.isNull())
-        {
-          TNode tres = res;
-          res = templ.substitute(templ_var, res);
-          res = Rewriter::rewrite(res);
-          Assert(res.isConst());
-        }
         Node resb;
         if (eiv.getRole() == enum_io)
         {
@@ -657,32 +729,37 @@ void SygusUnifIo::notifyEnumeration(Node e, Node v, std::vector<Node>& lemmas)
           Trace("sygus-sui-enum")
               << "  ...fail : term is not unique" << std::endl;
         }
-        d_cond_count++;
       }
       if (keep)
       {
         // notify to retry the build of solution
         d_check_sol = true;
+        d_cond_count++;
         ecv.addEnumValue(v, results);
       }
     }
   }
 
-  // exclude this value on subsequent iterations
-  if (exp_exc.isNull())
+  if (isPassive)
   {
-    // if we did not already explain why this should be excluded, use default
-    exp_exc = d_tds->getExplain()->getExplanationForEquality(e, v);
+    // exclude this value on subsequent iterations
+    if (exp_exc.isNull())
+    {
+      Trace("sygus-sui-enum-lemma") << "Use basic exclusion." << std::endl;
+      // if we did not already explain why this should be excluded, use default
+      exp_exc = d_tds->getExplain()->getExplanationForEquality(e, v);
+    }
+    exp_exc = exp_exc.negate();
+    Trace("sygus-sui-enum-lemma")
+        << "SygusUnifIo : enumeration exclude lemma : " << exp_exc << std::endl;
+    lemmas.push_back(exp_exc);
   }
-  exp_exc = exp_exc.negate();
-  Trace("sygus-sui-enum-lemma")
-      << "SygusUnifIo : enumeration exclude lemma : " << exp_exc << std::endl;
-  lemmas.push_back(exp_exc);
 }
 
-bool SygusUnifIo::constructSolution(std::vector<Node>& sols)
+bool SygusUnifIo::constructSolution(std::vector<Node>& sols,
+                                    std::vector<Node>& lemmas)
 {
-  Node sol = constructSolutionNode();
+  Node sol = constructSolutionNode(lemmas);
   if (!sol.isNull())
   {
     sols.push_back(sol);
@@ -691,7 +768,7 @@ bool SygusUnifIo::constructSolution(std::vector<Node>& sols)
   return false;
 }
 
-Node SygusUnifIo::constructSolutionNode()
+Node SygusUnifIo::constructSolutionNode(std::vector<Node>& lemmas)
 {
   Node c = d_candidate;
   if (!d_solution.isNull())
@@ -705,9 +782,10 @@ Node SygusUnifIo::constructSolutionNode()
     Trace("sygus-pbe") << "Construct solution, #iterations = " << d_cond_count
                        << std::endl;
     d_check_sol = false;
+    d_solConsUsingInfoGain = false;
     // try multiple times if we have done multiple conditions, due to
     // non-determinism
-    Node vc;
+    unsigned sol_term_size = 0;
     for (unsigned i = 0; i <= d_cond_count; i++)
     {
       Trace("sygus-pbe-dt") << "ConstructPBE for candidate: " << c << std::endl;
@@ -716,24 +794,36 @@ Node SygusUnifIo::constructSolutionNode()
       initializeConstructSolFor(c);
       // call the virtual construct solution method
       Node e = d_strategy[c].getRootEnumerator();
-      Node vcc = constructSol(c, e, role_equal, 1);
+      Node vcc = constructSol(c, e, role_equal, 1, lemmas);
       // if we constructed the solution, and we either did not previously have
       // a solution, or the new solution is better (smaller).
       if (!vcc.isNull()
-          && (vc.isNull() || (!vc.isNull()
-                              && d_tds->getSygusTermSize(vcc)
-                                     < d_tds->getSygusTermSize(vc))))
+          && (d_solution.isNull()
+              || (!d_solution.isNull()
+                  && d_tds->getSygusTermSize(vcc) < sol_term_size)))
       {
         Trace("sygus-pbe") << "**** SygusUnif SOLVED : " << c << " = " << vcc
                            << std::endl;
         Trace("sygus-pbe") << "...solved at iteration " << i << std::endl;
-        vc = vcc;
+        d_solution = vcc;
+        sol_term_size = d_tds->getSygusTermSize(vcc);
+        // We've determined its feasible, now, enable information gain and
+        // retry. We do this since information gain comes with an overhead,
+        // and we want testing feasibility to be fast.
+        if (!d_solConsUsingInfoGain)
+        {
+          d_solConsUsingInfoGain = true;
+          i = 0;
+        }
+      }
+      else if (!d_sol_cons_nondet)
+      {
+        break;
       }
     }
-    if (!vc.isNull())
+    if (!d_solution.isNull())
     {
-      d_solution = vc;
-      return vc;
+      return d_solution;
     }
     Trace("sygus-pbe") << "...failed to solve." << std::endl;
   }
@@ -783,14 +873,15 @@ bool SygusUnifIo::useStrContainsEnumeratorExclude(Node e)
   return false;
 }
 
-bool SygusUnifIo::getExplanationForEnumeratorExclude(Node e,
-                                                     Node v,
-                                                     std::vector<Node>& results,
-                                                     std::vector<Node>& exp)
+bool SygusUnifIo::getExplanationForEnumeratorExclude(
+    Node e,
+    Node v,
+    std::vector<Node>& results,
+    std::vector<Node>& exp)
 {
+  NodeManager* nm = NodeManager::currentNM();
   if (useStrContainsEnumeratorExclude(e))
   {
-    NodeManager* nm = NodeManager::currentNM();
     // This check whether the example evaluates to something that is larger than
     // the output for some input/output pair. If so, then this term is never
     // useful. We generalize its explanation below.
@@ -848,13 +939,19 @@ void SygusUnifIo::EnumCache::addEnumValue(Node v, std::vector<Node>& results)
   d_enum_vals_res.push_back(results);
 }
 
-void SygusUnifIo::initializeConstructSol() { d_context.initialize(this); }
+void SygusUnifIo::initializeConstructSol()
+{
+  d_context.initialize(this);
+  d_sol_cons_nondet = false;
+}
+
 void SygusUnifIo::initializeConstructSolFor(Node f)
 {
   Assert(d_candidate == f);
 }
 
-Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
+Node SygusUnifIo::constructSol(
+    Node f, Node e, NodeRole nrole, int ind, std::vector<Node>& lemmas)
 {
   Assert(d_candidate == f);
   UnifContextIo& x = d_context;
@@ -907,7 +1004,7 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
         ecache.d_term_trie.getSubsumedBy(x.d_vals, true, subsumed_by);
         if (!subsumed_by.empty())
         {
-          ret_dt = constructBestSolvedTerm(subsumed_by);
+          ret_dt = constructBestSolvedTerm(e, subsumed_by);
           indent("sygus-sui-dt", ind);
           Trace("sygus-sui-dt") << "return PBE: success : conditionally solved"
                                 << d_tds->sygusToBuiltin(ret_dt) << std::endl;
@@ -949,7 +1046,7 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
           }
           if (!str_solved.empty())
           {
-            ret_dt = constructBestStringSolvedTerm(str_solved);
+            ret_dt = constructBestSolvedTerm(e, str_solved);
             indent("sygus-sui-dt", ind);
             Trace("sygus-sui-dt") << "return PBE: success : string solved "
                                   << d_tds->sygusToBuiltin(ret_dt) << std::endl;
@@ -996,8 +1093,9 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
         Node val_t = ecache.d_enum_vals[i];
         Assert(incr.find(val_t) == incr.end());
         indent("sygus-sui-dt-debug", ind);
-        Trace("sygus-sui-dt-debug")
-            << "increment string values : " << val_t << " : ";
+        Trace("sygus-sui-dt-debug") << "increment string values : ";
+        TermDbSygus::toStreamSygus("sygus-sui-dt-debug", val_t);
+        Trace("sygus-sui-dt-debug") << " : ";
         Assert(ecache.d_enum_vals_res[i].size() == d_examples_out.size());
         unsigned tot = 0;
         bool exsuccess = x.getStringIncrement(this,
@@ -1022,6 +1120,9 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
 
       if (!incr.empty())
       {
+        // solution construction for strings concatenation is non-deterministic
+        // with respect to failure/success.
+        d_sol_cons_nondet = true;
         ret_dt = constructBestStringToConcat(inc_strs, total_inc, incr);
         Assert(!ret_dt.isNull());
         indent("sygus-sui-dt", ind);
@@ -1048,40 +1149,70 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
           << std::endl;
     }
   }
-  if (ret_dt.isNull() && !einfo.isTemplated())
+  if (!ret_dt.isNull() || einfo.isTemplated())
+  {
+    Assert(ret_dt.isNull() || ret_dt.getType() == e.getType());
+    indent("sygus-sui-dt", ind);
+    Trace("sygus-sui-dt") << "ConstructPBE: returned (pre-strategy) " << ret_dt
+                          << std::endl;
+    return ret_dt;
+  }
+  // we will try a single strategy
+  EnumTypeInfoStrat* etis = nullptr;
+  std::map<NodeRole, StrategyNode>::iterator itsn = tinfo.d_snodes.find(nrole);
+  if (itsn == tinfo.d_snodes.end())
+  {
+    indent("sygus-sui-dt", ind);
+    Trace("sygus-sui-dt") << "ConstructPBE: returned (no-strategy) " << ret_dt
+                          << std::endl;
+    return ret_dt;
+  }
+  // strategy info
+  StrategyNode& snode = itsn->second;
+  if (x.d_visit_role[e].find(nrole) != x.d_visit_role[e].end())
+  {
+    // already visited and context not changed (notice d_visit_role is cleared
+    // when the context changes).
+    indent("sygus-sui-dt", ind);
+    Trace("sygus-sui-dt") << "ConstructPBE: returned (already visited) "
+                          << ret_dt << std::endl;
+    return ret_dt;
+  }
+  x.d_visit_role[e][nrole] = true;
+  // try a random strategy
+  if (snode.d_strats.size() > 1)
+  {
+    std::shuffle(
+        snode.d_strats.begin(), snode.d_strats.end(), Random::getRandom());
+  }
+  // ITE always first if we have not yet solved
+  // the reasoning is that splitting on conditions only subdivides the problem
+  // and cannot be the source of failure, whereas the wrong choice for a
+  // concatenation term may lead to failure
+  if (d_solution.isNull())
   {
-    // we will try a single strategy
-    EnumTypeInfoStrat* etis = nullptr;
-    std::map<NodeRole, StrategyNode>::iterator itsn =
-        tinfo.d_snodes.find(nrole);
-    if (itsn != tinfo.d_snodes.end())
+    for (unsigned i = 0; i < snode.d_strats.size(); i++)
     {
-      // strategy info
-      StrategyNode& snode = itsn->second;
-      if (x.d_visit_role[e].find(nrole) == x.d_visit_role[e].end())
+      if (snode.d_strats[i]->d_this == strat_ITE)
       {
-        x.d_visit_role[e][nrole] = true;
-        // try a random strategy
-        if (snode.d_strats.size() > 1)
-        {
-          std::random_shuffle(snode.d_strats.begin(), snode.d_strats.end());
-        }
-        // get an eligible strategy index
-        unsigned sindex = 0;
-        while (sindex < snode.d_strats.size()
-               && !snode.d_strats[sindex]->isValid(x))
-        {
-          sindex++;
-        }
-        // if we found a eligible strategy
-        if (sindex < snode.d_strats.size())
-        {
-          etis = snode.d_strats[sindex];
-        }
+        // flip the two
+        EnumTypeInfoStrat* etis = snode.d_strats[i];
+        snode.d_strats[i] = snode.d_strats[0];
+        snode.d_strats[0] = etis;
+        break;
       }
     }
-    if (etis != nullptr)
+  }
+
+  // iterate over the strategies
+  unsigned sindex = 0;
+  bool did_recurse = false;
+  while (ret_dt.isNull() && !did_recurse && sindex < snode.d_strats.size())
+  {
+    if (snode.d_strats[sindex]->isValid(x))
     {
+      etis = snode.d_strats[sindex];
+      Assert(etis != nullptr);
       StrategyType strat = etis->d_this;
       indent("sygus-sui-dt", ind + 1);
       Trace("sygus-sui-dt")
@@ -1093,7 +1224,8 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
 
       // for ITE
       Node split_cond_enum;
-      int split_cond_res_index = -1;
+      unsigned split_cond_res_index = 0;
+      CVC4_UNUSED bool set_split_cond_res_index = false;
 
       for (unsigned sc = 0, size = etis->d_cenum.size(); sc < size; sc++)
       {
@@ -1120,9 +1252,8 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
           if (strat == strat_ITE && sc > 0)
           {
             EnumCache& ecache_cond = d_ecache[split_cond_enum];
-            Assert(split_cond_res_index >= 0);
-            Assert(split_cond_res_index
-                   < (int)ecache_cond.d_enum_vals_res.size());
+            Assert(set_split_cond_res_index);
+            Assert(split_cond_res_index < ecache_cond.d_enum_vals_res.size());
             prev = x.d_vals;
             bool ret = x.updateContext(
                 this,
@@ -1138,53 +1269,6 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
 
             EnumCache& ecache_child = d_ecache[ce];
 
-            // only used if the return value is not modified
-            if (!x.isReturnValueModified())
-            {
-              if (x.d_uinfo.find(ce) == x.d_uinfo.end())
-              {
-                Trace("sygus-sui-dt-debug2")
-                    << "  reg : PBE: Look for direct solutions for conditional "
-                       "enumerator "
-                    << ce << " ... " << std::endl;
-                Assert(ecache_child.d_enum_vals.size()
-                       == ecache_child.d_enum_vals_res.size());
-                for (unsigned i = 1; i <= 2; i++)
-                {
-                  std::pair<Node, NodeRole>& te_pair = etis->d_cenum[i];
-                  Node te = te_pair.first;
-                  EnumCache& ecache_te = d_ecache[te];
-                  bool branch_pol = (i == 1);
-                  // for each condition, get terms that satisfy it in this
-                  // branch
-                  for (unsigned k = 0, size = ecache_child.d_enum_vals.size();
-                       k < size;
-                       k++)
-                  {
-                    Node cond = ecache_child.d_enum_vals[k];
-                    std::vector<Node> solved;
-                    ecache_te.d_term_trie.getSubsumedBy(
-                        ecache_child.d_enum_vals_res[k], branch_pol, solved);
-                    Trace("sygus-sui-dt-debug2")
-                        << "  reg : PBE: " << d_tds->sygusToBuiltin(cond)
-                        << " has " << solved.size() << " solutions in branch "
-                        << i << std::endl;
-                    if (!solved.empty())
-                    {
-                      Node slv = constructBestSolvedTerm(solved);
-                      Trace("sygus-sui-dt-debug2")
-                          << "  reg : PBE: ..." << d_tds->sygusToBuiltin(slv)
-                          << " is a solution under branch " << i;
-                      Trace("sygus-sui-dt-debug2")
-                          << " of condition " << d_tds->sygusToBuiltin(cond)
-                          << std::endl;
-                      x.d_uinfo[ce].d_look_ahead_sols[cond][i] = slv;
-                    }
-                  }
-                }
-              }
-            }
-
             // get the conditionals in the current context : they must be
             // distinguishable
             std::map<int, std::vector<Node> > possible_cond;
@@ -1209,57 +1293,14 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
                 }
               }
 
-              // static look ahead conditional : choose conditionals that have
-              // solved terms in at least one branch
-              //    only applicable if we have not modified the return value
-              std::map<int, std::vector<Node> > solved_cond;
-              if (!x.isReturnValueModified())
-              {
-                Assert(x.d_uinfo.find(ce) != x.d_uinfo.end());
-                int solve_max = 0;
-                for (Node& cond : itpc->second)
-                {
-                  std::map<Node, std::map<unsigned, Node> >::iterator itla =
-                      x.d_uinfo[ce].d_look_ahead_sols.find(cond);
-                  if (itla != x.d_uinfo[ce].d_look_ahead_sols.end())
-                  {
-                    int nsolved = itla->second.size();
-                    solve_max = nsolved > solve_max ? nsolved : solve_max;
-                    solved_cond[nsolved].push_back(cond);
-                  }
-                }
-                int n = solve_max;
-                while (n > 0)
-                {
-                  if (!solved_cond[n].empty())
-                  {
-                    rec_c = constructBestSolvedConditional(solved_cond[n]);
-                    indent("sygus-sui-dt", ind + 1);
-                    Trace("sygus-sui-dt")
-                        << "PBE: ITE strategy : choose solved conditional "
-                        << d_tds->sygusToBuiltin(rec_c) << " with " << n
-                        << " solved children..." << std::endl;
-                    std::map<Node, std::map<unsigned, Node> >::iterator itla =
-                        x.d_uinfo[ce].d_look_ahead_sols.find(rec_c);
-                    Assert(itla != x.d_uinfo[ce].d_look_ahead_sols.end());
-                    for (std::pair<const unsigned, Node>& las : itla->second)
-                    {
-                      look_ahead_solved_children[las.first] = las.second;
-                    }
-                    break;
-                  }
-                  n--;
-                }
-              }
-
               // otherwise, guess a conditional
               if (rec_c.isNull())
               {
-                rec_c = constructBestConditional(itpc->second);
+                rec_c = constructBestConditional(ce, itpc->second);
                 Assert(!rec_c.isNull());
                 indent("sygus-sui-dt", ind);
                 Trace("sygus-sui-dt")
-                    << "PBE: ITE strategy : choose random conditional "
+                    << "PBE: ITE strategy : choose best conditional "
                     << d_tds->sygusToBuiltin(rec_c) << std::endl;
               }
             }
@@ -1277,15 +1318,16 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
               Assert(ecache_child.d_enum_val_to_index.find(rec_c)
                      != ecache_child.d_enum_val_to_index.end());
               split_cond_res_index = ecache_child.d_enum_val_to_index[rec_c];
+              set_split_cond_res_index = true;
               split_cond_enum = ce;
-              Assert(split_cond_res_index >= 0);
               Assert(split_cond_res_index
-                     < (int)ecache_child.d_enum_vals_res.size());
+                     < ecache_child.d_enum_vals_res.size());
             }
           }
           else
           {
-            rec_c = constructSol(f, cenum.first, cenum.second, ind + 2);
+            did_recurse = true;
+            rec_c = constructSol(f, cenum.first, cenum.second, ind + 2, lemmas);
           }
 
           // undo update the context
@@ -1325,17 +1367,110 @@ Node SygusUnifIo::constructSol(Node f, Node e, NodeRole nrole, int ind)
             << "PBE: failed for strategy " << strat << std::endl;
       }
     }
+    // increment
+    sindex++;
   }
 
-  if (!ret_dt.isNull())
-  {
-    Assert(ret_dt.getType() == e.getType());
-  }
+  Assert(ret_dt.isNull() || ret_dt.getType() == e.getType());
   indent("sygus-sui-dt", ind);
   Trace("sygus-sui-dt") << "ConstructPBE: returned " << ret_dt << std::endl;
   return ret_dt;
 }
 
+Node SygusUnifIo::constructBestConditional(Node ce,
+                                           const std::vector<Node>& conds)
+{
+  if (!d_solConsUsingInfoGain)
+  {
+    return SygusUnif::constructBestConditional(ce, conds);
+  }
+  UnifContextIo& x = d_context;
+  // use information gain heuristic
+  Trace("sygus-sui-dt-igain") << "Best information gain in context ";
+  print_val("sygus-sui-dt-igain", x.d_vals);
+  Trace("sygus-sui-dt-igain") << std::endl;
+  // set of indices that are active in this branch, i.e. x.d_vals[i] is true
+  std::vector<unsigned> activeIndices;
+  // map (j,t,s) -> n, such that the j^th condition in the vector conds
+  // evaluates to t (typically true/false) on n active I/O pairs with output s.
+  std::map<unsigned, std::map<Node, std::map<Node, unsigned>>> eval;
+  // map (j,t) -> m, such that the j^th condition in the vector conds
+  // evaluates to t (typically true/false) for m active I/O pairs.
+  std::map<unsigned, std::map<Node, unsigned>> evalCount;
+  unsigned nconds = conds.size();
+  EnumCache& ecache = d_ecache[ce];
+  // Get the index of conds[j] in the enumerator cache, this is to look up
+  // its evaluation on each point.
+  std::vector<unsigned> eindex;
+  for (unsigned j = 0; j < nconds; j++)
+  {
+    eindex.push_back(ecache.d_enum_val_to_index[conds[j]]);
+  }
+  unsigned activePoints = 0;
+  for (unsigned i = 0, npoints = x.d_vals.size(); i < npoints; i++)
+  {
+    if (x.d_vals[i].getConst<bool>())
+    {
+      activePoints++;
+      Node eo = d_examples_out[i];
+      for (unsigned j = 0; j < nconds; j++)
+      {
+        Node resn = ecache.d_enum_vals_res[eindex[j]][i];
+        Assert(resn.isConst());
+        eval[j][resn][eo]++;
+        evalCount[j][resn]++;
+      }
+    }
+  }
+  AlwaysAssert(activePoints > 0);
+  // find the condition that leads to the lowest entropy
+  // initially set minEntropy to > 1.0.
+  double minEntropy = 2.0;
+  unsigned bestIndex = 0;
+  for (unsigned j = 0; j < nconds; j++)
+  {
+    // To compute the entropy for a condition C, for pair of terms (s, t), let
+    //   prob(t) be the probability C evaluates to t on an active point,
+    //   prob(s|t) be the probability that an active point on which C
+    //     evaluates to t has output s.
+    // Then, the entropy of C is:
+    //   sum{t}. prob(t)*( sum{s}. -prob(s|t)*log2(prob(s|t)) )
+    // where notice this is always between 0 and 1.
+    double entropySum = 0.0;
+    Trace("sygus-sui-dt-igain") << j << " : ";
+    for (std::pair<const Node, std::map<Node, unsigned>>& ej : eval[j])
+    {
+      unsigned ecount = evalCount[j][ej.first];
+      if (ecount > 0)
+      {
+        double probBranch = double(ecount) / double(activePoints);
+        Trace("sygus-sui-dt-igain") << ej.first << " -> ( ";
+        for (std::pair<const Node, unsigned>& eej : ej.second)
+        {
+          if (eej.second > 0)
+          {
+            double probVal = double(eej.second) / double(ecount);
+            Trace("sygus-sui-dt-igain")
+                << eej.first << ":" << eej.second << " ";
+            double factor = -probVal * log2(probVal);
+            entropySum += probBranch * factor;
+          }
+        }
+        Trace("sygus-sui-dt-igain") << ") ";
+      }
+    }
+    Trace("sygus-sui-dt-igain") << "..." << entropySum << std::endl;
+    if (entropySum < minEntropy)
+    {
+      minEntropy = entropySum;
+      bestIndex = j;
+    }
+  }
+
+  Assert(!conds.empty());
+  return conds[bestIndex];
+}
+
 } /* CVC4::theory::quantifiers namespace */
 } /* CVC4::theory namespace */
 } /* CVC4 namespace */