Skolemize candidate rewrite rule checks (#1777)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 16 Apr 2018 13:14:53 +0000 (08:14 -0500)
committerGitHub <noreply@github.com>
Mon, 16 Apr 2018 13:14:53 +0000 (08:14 -0500)
src/theory/quantifiers/cegqi/inst_strategy_cbqi.cpp
src/theory/quantifiers/ematching/inst_match_generator.cpp
src/theory/quantifiers/ematching/trigger.cpp
src/theory/quantifiers/macros.cpp
src/theory/quantifiers/sygus/ce_guided_conjecture.cpp
src/theory/quantifiers/sygus/ce_guided_conjecture.h
src/theory/quantifiers/term_util.cpp
src/theory/quantifiers/term_util.h

index df04a743b4b5ae6579e487213f64aa99f0d508bd..d2aa75288c9e253dd5c4b88f46440189d2e1e410 100644 (file)
@@ -113,7 +113,7 @@ bool InstStrategyCbqi::registerCbqiLemma( Node q ) {
       //compute dependencies between quantified formulas
       if( options::cbqiLitDepend() || options::cbqiInnermost() ){
         std::vector< Node > ics;
-        TermUtil::computeVarContains( q, ics );
+        TermUtil::computeInstConstContains(q, ics);
         d_parent_quant[q].clear();
         d_children_quant[q].clear();
         std::vector< Node > dep;
index 0252def60e1db698e205fc5588f082bbd30ad253..9c3095e59b4fc92d6ebe0495e5d09cf447c20d58 100644 (file)
@@ -597,7 +597,11 @@ int VarMatchGeneratorTermSubs::getNextMatch(Node q,
 InstMatchGeneratorMultiLinear::InstMatchGeneratorMultiLinear( Node q, std::vector< Node >& pats, QuantifiersEngine* qe ) {
   //order patterns to maximize eager matching failures
   std::map< Node, std::vector< Node > > var_contains;
-  qe->getTermUtil()->getVarContains( q, pats, var_contains );
+  for (const Node& pat : pats)
+  {
+    quantifiers::TermUtil::computeInstConstContainsForQuant(
+        q, pat, var_contains[pat]);
+  }
   std::map< Node, std::vector< Node > > var_to_node;
   for( std::map< Node, std::vector< Node > >::iterator it = var_contains.begin(); it != var_contains.end(); ++it ){
     for( unsigned i=0; i<it->second.size(); i++ ){
@@ -710,7 +714,11 @@ InstMatchGeneratorMulti::InstMatchGeneratorMulti(Node q,
 {
   Trace("multi-trigger-cache") << "Making smart multi-trigger for " << q << std::endl;
   std::map< Node, std::vector< Node > > var_contains;
-  qe->getTermUtil()->getVarContains( q, pats, var_contains );
+  for (const Node& pat : pats)
+  {
+    quantifiers::TermUtil::computeInstConstContainsForQuant(
+        q, pat, var_contains[pat]);
+  }
   //convert to indicies
   for( std::map< Node, std::vector< Node > >::iterator it = var_contains.begin(); it != var_contains.end(); ++it ){
     Trace("multi-trigger-cache") << "Pattern " << it->first << " contains: ";
index cb5afbfabb12ee7c3cd11737b8e6679abc210b78..3928cf485a872eb5786f8bc72b23ba3d4e9ee0fe 100644 (file)
@@ -36,7 +36,7 @@ namespace inst {
 
 void TriggerTermInfo::init( Node q, Node n, int reqPol, Node reqPolEq ){
   if( d_fv.empty() ){
-    quantifiers::TermUtil::getVarContainsNode( q, n, d_fv );
+    quantifiers::TermUtil::computeInstConstContainsForQuant(q, n, d_fv);
   }
   if( d_reqPol==0 ){
     d_reqPol = reqPol;
@@ -134,7 +134,11 @@ bool Trigger::mkTriggerTerms( Node q, std::vector< Node >& nodes, unsigned n_var
   std::map< Node, std::vector< Node > > patterns;
   size_t varCount = 0;
   std::map< Node, std::vector< Node > > varContains;
-  quantifiers::TermUtil::getVarContains( q, temp, varContains );
+  for (const Node& pat : temp)
+  {
+    quantifiers::TermUtil::computeInstConstContainsForQuant(
+        q, pat, varContains[pat]);
+  }
   for( unsigned i=0; i<temp.size(); i++ ){
     bool foundVar = false;
     for( unsigned j=0; j<varContains[ temp[i] ].size(); j++ ){
@@ -744,7 +748,7 @@ void Trigger::filterTriggerInstances(std::vector<Node>& nodes)
   std::map<unsigned, std::vector<Node> > fvs;
   for (unsigned i = 0, size = nodes.size(); i < size; i++)
   {
-    quantifiers::TermUtil::computeVarContains(nodes[i], fvs[i]);
+    quantifiers::TermUtil::computeInstConstContains(nodes[i], fvs[i]);
   }
   std::vector<bool> active;
   active.resize(nodes.size(), true);
@@ -870,8 +874,9 @@ void Trigger::getTriggerVariables(Node n, Node q, std::vector<Node>& t_vars)
   std::vector< Node > exclude;
   collectPatTerms(q, n, patTerms, quantifiers::TRIGGER_SEL_ALL, exclude, tinfo);
   //collect all variables from all patterns in patTerms, add to t_vars
-  for( unsigned i=0; i<patTerms.size(); i++ ){
-    quantifiers::TermUtil::getVarContainsNode( q, patTerms[i], t_vars );
+  for (const Node& pat : patTerms)
+  {
+    quantifiers::TermUtil::computeInstConstContainsForQuant(q, pat, t_vars);
   }
 }
 
index cafd6e5798208f30a763c813bd55403a2aef7e56..9a6cc6e97d58d99a345fbc6770875c28169e2b5d 100644 (file)
@@ -156,7 +156,7 @@ bool QuantifierMacros::isGroundUfTerm( Node f, Node n ) {
   Node icn = d_qe->getTermUtil()->substituteBoundVariablesToInstConstants(n, f);
   Trace("macros-debug2") << "Get free variables in " << icn << std::endl;
   std::vector< Node > var;
-  d_qe->getTermUtil()->getVarContainsNode( f, icn, var );
+  quantifiers::TermUtil::computeInstConstContainsForQuant(f, icn, var);
   Trace("macros-debug2") << "Get trigger variables for " << icn << std::endl;
   std::vector< Node > trigger_var;
   inst::Trigger::getTriggerVariables( icn, f, trigger_var );
index d160581bfbc9eb87745c9eea5655d55d434d5d0a..1e0f728173d70c2e52293b8fd9bfb6734f395bfe 100644 (file)
@@ -620,11 +620,39 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
               // Notice we don't set produce-models. rrChecker takes the same
               // options as the SmtEngine we belong to, where we ensure that
               // produce-models is set.
-              SmtEngine rrChecker(NodeManager::currentNM()->toExprManager());
+              NodeManager* nm = NodeManager::currentNM();
+              SmtEngine rrChecker(nm->toExprManager());
               rrChecker.setLogic(smt::currentSmtEngine()->getLogicInfo());
               Node crr = solbr.eqNode(eq_solr).negate();
-              Trace("rr-check")
-                  << "Check candidate rewrite : " << crr << std::endl;
+              Trace("rr-check") << "Check candidate rewrite : " << crr
+                                << std::endl;
+              // quantify over the free variables in crr
+              std::vector<Node> fvs;
+              TermUtil::computeVarContains(crr, fvs);
+              std::map<Node, unsigned> fv_index;
+              std::vector<Node> sks;
+              if (!fvs.empty())
+              {
+                // map to skolems
+                for (unsigned i = 0, size = fvs.size(); i < size; i++)
+                {
+                  Node v = fvs[i];
+                  fv_index[v] = i;
+                  std::map<Node, Node>::iterator itf = d_fv_to_skolem.find(v);
+                  if (itf == d_fv_to_skolem.end())
+                  {
+                    Node sk = nm->mkSkolem("rrck", v.getType());
+                    d_fv_to_skolem[v] = sk;
+                    sks.push_back(sk);
+                  }
+                  else
+                  {
+                    sks.push_back(itf->second);
+                  }
+                }
+                crr = crr.substitute(
+                    fvs.begin(), fvs.end(), sks.begin(), sks.end());
+              }
               rrChecker.assertFormula(crr.toExpr());
               Result r = rrChecker.checkSat();
               Trace("rr-check") << "...result : " << r << std::endl;
@@ -639,15 +667,28 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
                 std::vector<Node> pt;
                 for (const Node& v : vars)
                 {
-                  Node val = Node::fromExpr(rrChecker.getValue(v.toExpr()));
-                  Trace("rr-check") << "  " << v << " -> " << val << std::endl;
+                  std::map<Node, unsigned>::iterator itf = fv_index.find(v);
+                  Node val;
+                  if (itf == fv_index.end())
+                  {
+                    // not in conjecture, can use arbitrary value
+                    val = v.getType().mkGroundTerm();
+                  }
+                  else
+                  {
+                    // get the model value of its skolem
+                    Node sk = sks[itf->second];
+                    val = Node::fromExpr(rrChecker.getValue(sk.toExpr()));
+                    Trace("rr-check") << "  " << v << " -> " << val
+                                      << std::endl;
+                  }
                   pt.push_back(val);
                 }
                 d_sampler[prog].addSamplePoint(pt);
                 // add the solution again
+                // by construction of the above point, we should be unique now
                 Node eq_sol_new = its->second.registerTerm(sol);
-                Assert(!r.asSatisfiabilityResult().isSat()
-                       || eq_sol_new == sol);
+                Assert(eq_sol_new == sol);
               }
               else
               {
index 215a4d161cfb86d3390886364e3890c62684ca48..b6812a18ae5446e1600bac650cec61152a7414b5 100644 (file)
@@ -247,6 +247,11 @@ private:
    * rewrite rules.
    */
   std::map<Node, SygusSamplerExt> d_sampler;
+  /**
+   * Cache of skolems for each free variable that appears in a synthesis check
+   * (for --sygus-rr-synth-check).
+   */
+  std::map<Node, Node> d_fv_to_skolem;
 };
 
 } /* namespace CVC4::theory::quantifiers */
index 7cebf0e1edd3422174338172e47233296bbce784..b3915bd5de5ef1f801d31b2816c2f4a6221b070b 100644 (file)
@@ -267,51 +267,74 @@ Node TermUtil::substituteInstConstants(Node n, Node q, std::vector<Node>& terms)
                       terms.end());
 }
 
-void TermUtil::computeVarContains( Node n, std::vector< Node >& varContains ) {
-  std::map< Node, bool > visited;
-  computeVarContains2( n, INST_CONSTANT, varContains, visited );
+void TermUtil::computeInstConstContains(Node n, std::vector<Node>& ics)
+{
+  computeVarContainsInternal(n, INST_CONSTANT, ics);
 }
 
-void TermUtil::computeQuantContains( Node n, std::vector< Node >& quantContains ) {
-  std::map< Node, bool > visited;
-  computeVarContains2( n, FORALL, quantContains, visited );
+void TermUtil::computeVarContains(Node n, std::vector<Node>& vars)
+{
+  computeVarContainsInternal(n, BOUND_VARIABLE, vars);
 }
 
+void TermUtil::computeQuantContains(Node n, std::vector<Node>& quants)
+{
+  computeVarContainsInternal(n, FORALL, quants);
+}
 
-void TermUtil::computeVarContains2( Node n, Kind k, std::vector< Node >& varContains, std::map< Node, bool >& visited ){
-  if( visited.find( n )==visited.end() ){
-    visited[n] = true;
-    if( n.getKind()==k ){
-      if( std::find( varContains.begin(), varContains.end(), n )==varContains.end() ){
-        varContains.push_back( n );
-      }
-    }else{
-      if (n.hasOperator())
+void TermUtil::computeVarContainsInternal(Node n,
+                                          Kind k,
+                                          std::vector<Node>& vars)
+{
+  std::unordered_set<TNode, TNodeHashFunction> visited;
+  std::unordered_set<TNode, TNodeHashFunction>::iterator it;
+  std::vector<TNode> visit;
+  TNode cur;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    it = visited.find(cur);
+
+    if (it == visited.end())
+    {
+      visited.insert(cur);
+      if (cur.getKind() == k)
       {
-        computeVarContains2(n.getOperator(), k, varContains, visited);
+        if (std::find(vars.begin(), vars.end(), cur) == vars.end())
+        {
+          vars.push_back(cur);
+        }
       }
-      for( unsigned i=0; i<n.getNumChildren(); i++ ){
-        computeVarContains2( n[i], k, varContains, visited );
+      else
+      {
+        if (cur.hasOperator())
+        {
+          visit.push_back(cur.getOperator());
+        }
+        for (const Node& cn : cur)
+        {
+          visit.push_back(cn);
+        }
       }
     }
-  }
+  } while (!visit.empty());
 }
 
-void TermUtil::getVarContains( Node f, std::vector< Node >& pats, std::map< Node, std::vector< Node > >& varContains ){
-  for( unsigned i=0; i<pats.size(); i++ ){
-    varContains[ pats[i] ].clear();
-    getVarContainsNode( f, pats[i], varContains[ pats[i] ] );
-  }
-}
-
-void TermUtil::getVarContainsNode( Node f, Node n, std::vector< Node >& varContains ){
-  std::vector< Node > vars;
-  computeVarContains( n, vars );
-  for( unsigned j=0; j<vars.size(); j++ ){
-    Node v = vars[j];
-    if( v.getAttribute(InstConstantAttribute())==f ){
-      if( std::find( varContains.begin(), varContains.end(), v )==varContains.end() ){
-        varContains.push_back( v );
+void TermUtil::computeInstConstContainsForQuant(Node q,
+                                                Node n,
+                                                std::vector<Node>& vars)
+{
+  std::vector<Node> ics;
+  computeInstConstContains(n, ics);
+  for (const Node& v : ics)
+  {
+    if (v.getAttribute(InstConstantAttribute()) == q)
+    {
+      if (std::find(vars.begin(), vars.end(), v) == vars.end())
+      {
+        vars.push_back(v);
       }
     }
   }
index 6b83ad639fc1ed5fd1795d5fc25729ddbc9fc402..df88c1b30f4f651119c761061745594d1d7c4240 100644 (file)
@@ -180,23 +180,28 @@ public:
   static Node getQuantSimplify( Node n );
 
  private:
-  /** helper function for compute var contains */
-  static void computeVarContains2( Node n, Kind k, std::vector< Node >& varContains, std::map< Node, bool >& visited );
+  /** adds the set of nodes of kind k in n to vars */
+  static void computeVarContainsInternal(Node n,
+                                         Kind k,
+                                         std::vector<Node>& vars);
+
  public:
-  /** compute var contains */
-  static void computeVarContains( Node n, std::vector< Node >& varContains );
-  /** get var contains for each of the patterns in pats */
-  static void getVarContains( Node f, std::vector< Node >& pats, std::map< Node, std::vector< Node > >& varContains );
-  /** get var contains for node n */
-  static void getVarContainsNode( Node f, Node n, std::vector< Node >& varContains );
-  /** compute quant contains */
-  static void computeQuantContains( Node n, std::vector< Node >& quantContains );
-  // TODO (#1216) : this should be in trigger.h
-  /** filter all nodes that have instances */
-  static void filterInstances( std::vector< Node >& nodes );
-
-//for term ordering
-private:
+  /** adds the set of nodes of kind INST_CONSTANT in n to ics */
+  static void computeInstConstContains(Node n, std::vector<Node>& ics);
+  /** adds the set of nodes of kind BOUND_VARIABLE in n to vars */
+  static void computeVarContains(Node n, std::vector<Node>& vars);
+  /** adds the set of (top-level) nodes of kind FORALL in n to quants */
+  static void computeQuantContains(Node n, std::vector<Node>& quants);
+  /**
+   * Adds the set of nodes of kind INST_CONSTANT in n that belong to quantified
+   * formula q to vars.
+   */
+  static void computeInstConstContainsForQuant(Node q,
+                                               Node n,
+                                               std::vector<Node>& vars);
+
+  // for term ordering
+ private:
   /** operator id count */
   int d_op_id_count;
   /** map from operators to id */