Give rewrite engine pointer to conflict-based instantiation module (#3174)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 12 Aug 2019 19:23:31 +0000 (14:23 -0500)
committerGitHub <noreply@github.com>
Mon, 12 Aug 2019 19:23:31 +0000 (14:23 -0500)
src/theory/quantifiers/rewrite_engine.cpp
src/theory/quantifiers/rewrite_engine.h
src/theory/quantifiers_engine.cpp
src/theory/quantifiers_engine.h

index 10157c3b30c4e4ba2be24dbfe1a029323c48f1c8..07ff9ee4656ad49814d3bb658e3de8ed6edc6630 100644 (file)
 #include "theory/quantifiers/rewrite_engine.h"
 
 #include "options/quantifiers_options.h"
-#include "theory/quantifiers/ematching/inst_match_generator.h"
-#include "theory/quantifiers/first_order_model.h"
-#include "theory/quantifiers/fmf/model_engine.h"
-#include "theory/quantifiers/instantiate.h"
-#include "theory/quantifiers/quant_conflict_find.h"
-#include "theory/quantifiers/quant_util.h"
-#include "theory/quantifiers/quantifiers_attributes.h"
-#include "theory/quantifiers/term_database.h"
-#include "theory/quantifiers/term_util.h"
 #include "theory/quantifiers_engine.h"
-#include "theory/theory_engine.h"
 
 using namespace CVC4;
 using namespace std;
@@ -42,8 +32,11 @@ struct PrioritySort {
   }
 };
 
-
-RewriteEngine::RewriteEngine( context::Context* c, QuantifiersEngine* qe ) : QuantifiersModule(qe) {
+RewriteEngine::RewriteEngine(context::Context* c,
+                             QuantifiersEngine* qe,
+                             QuantConflictFind* qcf)
+    : QuantifiersModule(qe), d_qcf(qcf)
+{
   d_needsSort = false;
 }
 
@@ -115,101 +108,98 @@ void RewriteEngine::check(Theory::Effort e, QEffort quant_e)
 }
 
 int RewriteEngine::checkRewriteRule( Node f, Theory::Effort e ) {
-  int addedLemmas = 0;
   Trace("rewrite-engine-inst") << "Check " << d_qinfo_n[f] << ", priority = " << getPriority( f ) << ", effort = " << e << "..." << std::endl;
-  QuantConflictFind * qcf = d_quantEngine->getConflictFind();
-  if( qcf ){
-    //reset QCF module
-    qcf->setEffort(QuantConflictFind::EFFORT_CONFLICT);
-    //get the proper quantifiers info
-    std::map< Node, QuantInfo >::iterator it = d_qinfo.find( f );
-    if( it!=d_qinfo.end() ){
-      QuantInfo * qi = &it->second;
-      if( qi->matchGeneratorIsValid() ){
-        Node rr = QuantAttributes::getRewriteRule( f );
-        Trace("rewrite-engine-inst-debug") << "   Reset round..." << std::endl;
-        qi->reset_round( qcf );
-        Trace("rewrite-engine-inst-debug") << "   Get matches..." << std::endl;
-        while( !d_quantEngine->inConflict() && qi->getNextMatch( qcf ) &&
-               ( addedLemmas==0 || !options::rrOneInstPerRound() ) ){
-          Trace("rewrite-engine-inst-debug") << "   Got match to complete..." << std::endl;
-          qi->debugPrintMatch( "rewrite-engine-inst-debug" );
-          std::vector< int > assigned;
-          if( !qi->isMatchSpurious( qcf ) ){
-            bool doContinue = false;
-            bool success = true;
-            int tempAddedLemmas = 0;
-            while( !d_quantEngine->inConflict() && tempAddedLemmas==0 && success && ( addedLemmas==0 || !options::rrOneInstPerRound() ) ){
-              success = qi->completeMatch( qcf, assigned, doContinue );
-              doContinue = true;
-              if( success ){
-                Trace("rewrite-engine-inst-debug") << "   Construct match..." << std::endl;
-                std::vector< Node > inst;
-                qi->getMatch( inst );
-                Trace("rewrite-engine-inst-debug") << "   Add instantiation..." << std::endl;
-                for( unsigned i=0; i<f[0].getNumChildren(); i++ ){
-                  Trace("rewrite-engine-inst-debug") << "  " << f[0][i] << " -> ";
-                  if( i<inst.size() ){
-                    Trace("rewrite-engine-inst-debug") << inst[i] << std::endl;
-                  }else{
-                    Trace("rewrite-engine-inst-debug") << "OUT_OF_RANGE" << std::endl;
-                    Assert( false );
-                  }
-                }
-                //resize to remove auxiliary variables
-                if( inst.size()>f[0].getNumChildren() ){
-                  inst.resize( f[0].getNumChildren() );
-                }
-                if (d_quantEngine->getInstantiate()->addInstantiation(f, inst))
-                {
-                  addedLemmas++;
-                  tempAddedLemmas++;
-                  /*
-                  //remove rewritten terms from consideration
-                  std::vector< Node > to_remove;
-                  switch( rr[2].getKind() ){
-                  case kind::RR_REWRITE:
-                    to_remove.push_back( rr[2][0] );
-                    break;
-                  case kind::RR_REDUCTION:
-                    for( unsigned i=0; i<rr[2][0].getNumChildren(); i++ ){
-                      to_remove.push_back( rr[2][0][i] );
-                    }
-                    break;
-                  default:
-                    break;
-                  }
-                  for( unsigned j=0; j<to_remove.size(); j++ ){
-                    Node ns = d_quantEngine->getSubstitute( to_remove[j], inst );
-                    Trace("rewrite-engine-inst-debug") << "Will remove : " << ns << std::endl;
-                    d_quantEngine->getTermDatabase()->setTermInactive( ns );
-                  }
-                  */
-                }else{
-                  Trace("rewrite-engine-inst-debug") << "   - failed." << std::endl;
-                }
-                Trace("rewrite-engine-inst-debug") << "   Get next completion..." << std::endl;
+
+  // get the proper quantifiers info
+  std::map<Node, QuantInfo>::iterator it = d_qinfo.find(f);
+  if (it == d_qinfo.end())
+  {
+    Trace("rewrite-engine-inst-debug") << "...No qinfo." << std::endl;
+    return 0;
+  }
+  // reset QCF module
+  QuantInfo* qi = &it->second;
+  if (!qi->matchGeneratorIsValid())
+  {
+    Trace("rewrite-engine-inst-debug") << "...Invalid qinfo." << std::endl;
+    return 0;
+  }
+  d_qcf->setEffort(QuantConflictFind::EFFORT_CONFLICT);
+  Node rr = QuantAttributes::getRewriteRule(f);
+  Trace("rewrite-engine-inst-debug") << "   Reset round..." << std::endl;
+  qi->reset_round(d_qcf);
+  Trace("rewrite-engine-inst-debug") << "   Get matches..." << std::endl;
+  int addedLemmas = 0;
+  while (!d_quantEngine->inConflict() && qi->getNextMatch(d_qcf)
+         && (addedLemmas == 0 || !options::rrOneInstPerRound()))
+  {
+    Trace("rewrite-engine-inst-debug")
+        << "   Got match to complete..." << std::endl;
+    qi->debugPrintMatch("rewrite-engine-inst-debug");
+    std::vector<int> assigned;
+    if (!qi->isMatchSpurious(d_qcf))
+    {
+      bool doContinue = false;
+      bool success = true;
+      int tempAddedLemmas = 0;
+      while (!d_quantEngine->inConflict() && tempAddedLemmas == 0 && success
+             && (addedLemmas == 0 || !options::rrOneInstPerRound()))
+      {
+        success = qi->completeMatch(d_qcf, assigned, doContinue);
+        doContinue = true;
+        if (success)
+        {
+          Trace("rewrite-engine-inst-debug")
+              << "   Construct match..." << std::endl;
+          std::vector<Node> inst;
+          qi->getMatch(inst);
+          if (Trace.isOn("rewrite-engine-inst-debug"))
+          {
+            Trace("rewrite-engine-inst-debug")
+                << "   Add instantiation..." << std::endl;
+            for (unsigned i = 0, nchild = f[0].getNumChildren(); i < nchild;
+                 i++)
+            {
+              Trace("rewrite-engine-inst-debug") << "  " << f[0][i] << " -> ";
+              if (i < inst.size())
+              {
+                Trace("rewrite-engine-inst-debug") << inst[i] << std::endl;
+              }
+              else
+              {
+                Trace("rewrite-engine-inst-debug")
+                    << "OUT_OF_RANGE" << std::endl;
+                Assert(false);
               }
             }
-            //Trace("rewrite-engine-inst-debug") << "   Reverted assigned variables : ";
-            //for( unsigned a=0; a<assigned.size(); a++ ) {
-            //  Trace("rewrite-engine-inst-debug") << assigned[a] << " ";
-            //}
-            //Trace("rewrite-engine-inst-debug") << std::endl;
-            //qi->revertMatch( assigned );
-            //Assert( assigned.empty() );
-            Trace("rewrite-engine-inst-debug") << "   - failed to complete." << std::endl;
-          }else{
-            Trace("rewrite-engine-inst-debug") << "   - match is spurious." << std::endl;
           }
-          Trace("rewrite-engine-inst-debug") << "   Get next match..." << std::endl;
+          // resize to remove auxiliary variables
+          if (inst.size() > f[0].getNumChildren())
+          {
+            inst.resize(f[0].getNumChildren());
+          }
+          if (d_quantEngine->getInstantiate()->addInstantiation(f, inst))
+          {
+            addedLemmas++;
+            tempAddedLemmas++;
+          }
+          else
+          {
+            Trace("rewrite-engine-inst-debug") << "   - failed." << std::endl;
+          }
+          Trace("rewrite-engine-inst-debug")
+              << "   Get next completion..." << std::endl;
         }
-      }else{
-        Trace("rewrite-engine-inst-debug") << "...Invalid qinfo." << std::endl;
       }
-    }else{
-      Trace("rewrite-engine-inst-debug") << "...No qinfo." << std::endl;
+      Trace("rewrite-engine-inst-debug")
+          << "   - failed to complete." << std::endl;
     }
+    else
+    {
+      Trace("rewrite-engine-inst-debug")
+          << "   - match is spurious." << std::endl;
+    }
+    Trace("rewrite-engine-inst-debug") << "   Get next match..." << std::endl;
   }
   d_quantEngine->d_statistics.d_instantiations_rr += addedLemmas;
   Trace("rewrite-engine-inst") << "-> Generated " << addedLemmas << " lemmas." << std::endl;
@@ -218,71 +208,82 @@ int RewriteEngine::checkRewriteRule( Node f, Theory::Effort e ) {
 
 void RewriteEngine::registerQuantifier( Node f ) {
   Node rr = QuantAttributes::getRewriteRule( f );
-  if( !rr.isNull() ){
-    Trace("rr-register") << "Register quantifier " << f << std::endl;
-    Trace("rr-register") << "  rewrite rule is : " << rr << std::endl;
-    d_rr_quant.push_back( f );
-    d_rr[f] = rr;
-    d_needsSort = true;
-    Trace("rr-register") << "  guard is : " << d_rr[f][1] << std::endl;
+  if (rr.isNull())
+  {
+    return;
+  }
+  Trace("rr-register") << "Register quantifier " << f << std::endl;
+  Trace("rr-register") << "  rewrite rule is : " << rr << std::endl;
+  d_rr_quant.push_back(f);
+  d_rr[f] = rr;
+  d_needsSort = true;
+  Trace("rr-register") << "  guard is : " << d_rr[f][1] << std::endl;
 
-    QuantConflictFind * qcf = d_quantEngine->getConflictFind();
-    if( qcf ){
-      std::vector< Node > qcfn_c;
+  std::vector<Node> qcfn_c;
 
-      std::vector< Node > bvl;
-      for( unsigned i=0; i<f[0].getNumChildren(); i++ ){
-        bvl.push_back( f[0][i] );
-      }
+  std::vector<Node> bvl;
+  bvl.insert(bvl.end(), f[0].begin(), f[0].end());
 
-      std::vector< Node > cc;
-      //add patterns
-      for( unsigned i=1; i<f[2].getNumChildren(); i++ ){
-        std::vector< Node > nc;
-        for( unsigned j=0; j<f[2][i].getNumChildren(); j++ ){
-          Node nn;
-          Node nbv = NodeManager::currentNM()->mkBoundVar( f[2][i][j].getType() );
-          if( f[2][i][j].getType().isBoolean() && f[2][i][j].getKind()!=APPLY_UF ){
-            nn = f[2][i][j].negate();
-          }else{
-            nn = f[2][i][j].eqNode( nbv ).negate();
-            bvl.push_back( nbv );
-          }
-          nc.push_back( nn );
-        }
-        if( !nc.empty() ){
-          Node n = nc.size()==1 ? nc[0] : NodeManager::currentNM()->mkNode( AND, nc );
-          Trace("rr-register-debug") << "  pattern is " << n << std::endl;
-          if( std::find( cc.begin(), cc.end(), n )==cc.end() ){
-            cc.push_back( n );
-          }
-        }
+  NodeManager* nm = NodeManager::currentNM();
+  std::vector<Node> cc;
+  // add patterns
+  for (unsigned i = 1, nchild = f[2].getNumChildren(); i < nchild; i++)
+  {
+    std::vector<Node> nc;
+    for (const Node& pat : f[2][i])
+    {
+      Node nn;
+      Node nbv = nm->mkBoundVar(pat.getType());
+      if (pat.getType().isBoolean() && pat.getKind() != APPLY_UF)
+      {
+        nn = pat.negate();
       }
-      qcfn_c.push_back( NodeManager::currentNM()->mkNode( BOUND_VAR_LIST, bvl ) );
+      else
+      {
+        nn = pat.eqNode(nbv).negate();
+        bvl.push_back(nbv);
+      }
+      nc.push_back(nn);
+    }
+    if (!nc.empty())
+    {
+      Node n = nc.size() == 1 ? nc[0] : nm->mkNode(AND, nc);
+      Trace("rr-register-debug") << "  pattern is " << n << std::endl;
+      if (std::find(cc.begin(), cc.end(), n) == cc.end())
+      {
+        cc.push_back(n);
+      }
+    }
+  }
+  qcfn_c.push_back(nm->mkNode(BOUND_VAR_LIST, bvl));
 
-      std::vector< Node > body_c;
-      //add the guards
-      if( d_rr[f][1].getKind()==AND ){
-        for( unsigned j=0; j<d_rr[f][1].getNumChildren(); j++ ){
-          if( MatchGen::isHandled( d_rr[f][1][j] ) ){
-            body_c.push_back( d_rr[f][1][j].negate() );
-          }
-        }
-      }else if( d_rr[f][1]!=d_quantEngine->getTermUtil()->d_true ){
-        if( MatchGen::isHandled( d_rr[f][1] ) ){
-          body_c.push_back( d_rr[f][1].negate() );
-        }
+  std::vector<Node> body_c;
+  // add the guards
+  if (d_rr[f][1].getKind() == AND)
+  {
+    for (const Node& g : d_rr[f][1])
+    {
+      if (MatchGen::isHandled(g))
+      {
+        body_c.push_back(g.negate());
       }
-      //add the patterns to the body
-      body_c.push_back( cc.size()==1 ? cc[0] : NodeManager::currentNM()->mkNode( AND, cc ) );
-      //make the body
-      qcfn_c.push_back( body_c.size()==1 ? body_c[0] : NodeManager::currentNM()->mkNode( OR, body_c ) );
-      //make the quantified formula
-      d_qinfo_n[f] = NodeManager::currentNM()->mkNode( FORALL, qcfn_c );
-      Trace("rr-register") << "  qcf formula is : " << d_qinfo_n[f] << std::endl;
-      d_qinfo[f].initialize( qcf, d_qinfo_n[f], d_qinfo_n[f][1] );
     }
   }
+  else if (d_rr[f][1] != d_quantEngine->getTermUtil()->d_true)
+  {
+    if (MatchGen::isHandled(d_rr[f][1]))
+    {
+      body_c.push_back(d_rr[f][1].negate());
+    }
+  }
+  // add the patterns to the body
+  body_c.push_back(cc.size() == 1 ? cc[0] : nm->mkNode(AND, cc));
+  // make the body
+  qcfn_c.push_back(body_c.size() == 1 ? body_c[0] : nm->mkNode(OR, body_c));
+  // make the quantified formula
+  d_qinfo_n[f] = nm->mkNode(FORALL, qcfn_c);
+  Trace("rr-register") << "  qcf formula is : " << d_qinfo_n[f] << std::endl;
+  d_qinfo[f].initialize(d_qcf, d_qinfo_n[f], d_qinfo_n[f][1]);
 }
 
 void RewriteEngine::assertNode( Node n ) {
index 717f4009b2bc7a24b6a6fe142ef86826e26df807..5832d2817a82d22c7c89ac6f36028a717b32710b 100644 (file)
@@ -18,8 +18,9 @@
 #ifndef CVC4__REWRITE_ENGINE_H
 #define CVC4__REWRITE_ENGINE_H
 
-#include "context/context.h"
-#include "context/context_mm.h"
+#include <map>
+#include <vector>
+
 #include "theory/quantifiers/ematching/trigger.h"
 #include "theory/quantifiers/quant_conflict_find.h"
 #include "theory/quantifiers/quant_util.h"
@@ -30,9 +31,6 @@ namespace quantifiers {
 
 class RewriteEngine : public QuantifiersModule
 {
-  typedef context::CDHashMap<Node, bool, NodeHashFunction> NodeBoolMap;
-  typedef context::CDHashMap<Node, int, NodeHashFunction> NodeIntMap;
-  typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
   std::vector< Node > d_rr_quant;
   std::vector< Node > d_priority_order;
   std::map< Node, Node > d_rr;
@@ -45,10 +43,14 @@ class RewriteEngine : public QuantifiersModule
   bool d_needsSort;
   std::map< Node, std::map< Node, Node > > d_inst_const_node;
   Node getInstConstNode( Node n, Node q );
-private:
+
+ private:
   int checkRewriteRule( Node f, Theory::Effort e );
-public:
-  RewriteEngine( context::Context* c, QuantifiersEngine* qe );
+
+ public:
+  RewriteEngine(context::Context* c,
+                QuantifiersEngine* qe,
+                QuantConflictFind* qcf);
 
   bool needsCheck(Theory::Effort e) override;
   void check(Theory::Effort e, QEffort quant_e) override;
@@ -57,6 +59,14 @@ public:
   bool checkCompleteFor(Node q) override;
   /** Identify this module */
   std::string identify() const override { return "RewriteEngine"; }
+
+ private:
+  /**
+   * A pointer to the quantifiers conflict find module of the quantifiers
+   * engine. This is the module that computes instantiations for rewrite rule
+   * quantifiers.
+   */
+  QuantConflictFind* d_qcf;
 };
 
 }
index f0b0c31dfc3857ba4e4ddb9e350a14b25007aa02..c17af1e1f6cfe9f66088cfbeea369880534c6006 100644 (file)
@@ -149,7 +149,7 @@ class QuantifiersEnginePrivate
     }
     if (options::quantRewriteRules())
     {
-      d_rr_engine.reset(new quantifiers::RewriteEngine(c, qe));
+      d_rr_engine.reset(new quantifiers::RewriteEngine(c, qe, d_qcf.get()));
       modules.push_back(d_rr_engine.get());
     }
     if (options::ltePartialInst())
@@ -408,10 +408,6 @@ quantifiers::BoundedIntegers* QuantifiersEngine::getBoundedIntegers() const
 {
   return d_private->d_bint.get();
 }
-quantifiers::QuantConflictFind* QuantifiersEngine::getConflictFind() const
-{
-  return d_private->d_qcf.get();
-}
 quantifiers::SynthEngine* QuantifiersEngine::getSynthEngine() const
 {
   return d_private->d_synth_e.get();
index 7e5fe9102b08aeaf1ebd498ca9e1426fbb990933..7a9f5e7dabff27528e8e6764000e1bf43c4abdf9 100644 (file)
@@ -123,8 +123,6 @@ public:
   //---------------------- modules (TODO remove these #1163)
   /** get bounded integers utility */
   quantifiers::BoundedIntegers* getBoundedIntegers() const;
-  /** Conflict find mechanism for quantifiers */
-  quantifiers::QuantConflictFind* getConflictFind() const;
   /** ceg instantiation */
   quantifiers::SynthEngine* getSynthEngine() const;
   /** get inst strategy cbqi */