Cleanup and additions for candidate generator (#2173)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 20 Jul 2018 23:42:25 +0000 (01:42 +0200)
committerGitHub <noreply@github.com>
Fri, 20 Jul 2018 23:42:25 +0000 (01:42 +0200)
src/theory/quantifiers/ematching/candidate_generator.cpp
src/theory/quantifiers/ematching/candidate_generator.h
src/theory/quantifiers/ematching/inst_match_generator.cpp

index 96719cc0fe0f29cec2c64869bea381ac40702fcb..4208b11ae221614990e112d587bf4cc1a838f217 100644 (file)
@@ -33,38 +33,10 @@ bool CandidateGenerator::isLegalCandidate( Node n ){
   return d_qe->getTermDatabase()->isTermActive( n ) && ( !options::cbqi() || !quantifiers::TermUtil::hasInstConstAttr(n) );
 }
 
-void CandidateGeneratorQueue::addCandidate( Node n ) {
-  if( isLegalCandidate( n ) ){
-    d_candidates.push_back( n );
-  }
-}
-
-void CandidateGeneratorQueue::reset( Node eqc ){
-  if( d_candidate_index>0 ){
-    d_candidates.erase( d_candidates.begin(), d_candidates.begin() + d_candidate_index );
-    d_candidate_index = 0;
-  }
-  if( !eqc.isNull() ){
-    d_candidates.push_back( eqc );
-  }
-}
-Node CandidateGeneratorQueue::getNextCandidate(){
-  if( d_candidate_index<(int)d_candidates.size() ){
-    Node n = d_candidates[d_candidate_index];
-    d_candidate_index++;
-    return n;
-  }else{
-    d_candidate_index = 0;
-    d_candidates.clear();
-    return Node::null();
-  }
-}
-
 CandidateGeneratorQE::CandidateGeneratorQE( QuantifiersEngine* qe, Node pat ) :
 CandidateGenerator( qe ), d_term_iter( -1 ){
   d_op = qe->getTermDatabase()->getMatchOperator( pat );
   Assert( !d_op.isNull() );
-  d_op_arity = pat.getNumChildren();
 }
 
 void CandidateGeneratorQE::resetInstantiationRound(){
@@ -83,22 +55,16 @@ void CandidateGeneratorQE::reset( Node eqc ){
       if( ee->hasTerm( eqc ) ){
         quantifiers::TermArgTrie * tat = d_qe->getTermDatabase()->getTermArgTrie( eqc, d_op );
         if( tat ){
-#if 1
           //create an equivalence class iterator in eq class eqc
           Node rep = ee->getRepresentative( eqc );
           d_eqc_iter = eq::EqClassIterator( rep, ee );
           d_mode = cand_term_eqc;
-#else
-          d_tindex.push_back( tat );
-          d_tindex_iter.push_back( tat->d_data.begin() );
-          d_mode = cand_term_tindex;
-#endif     
         }else{
           d_mode = cand_term_none;
         }   
       }else{
         //the only match is this term itself
-        d_n = eqc;
+        d_eqc = eqc;
         d_mode = cand_term_ident;
       }
     }
@@ -144,41 +110,12 @@ Node CandidateGeneratorQE::getNextCandidate(){
         return n;
       }
     }
-  }else if( d_mode==cand_term_tindex ){
-    Debug("cand-gen-qe") << "...get next candidate in tindex " << d_op << " " << d_op_arity << std::endl;
-    //increment the term index iterator
-    if( !d_tindex.empty() ){
-      //populate the vector
-      while( d_tindex_iter.size()<=d_op_arity ){
-        Assert( !d_tindex_iter.empty() );
-        Assert( !d_tindex_iter.back()->second.d_data.empty() );
-        d_tindex.push_back( &(d_tindex_iter.back()->second) );
-        d_tindex_iter.push_back( d_tindex_iter.back()->second.d_data.begin() );
-      }
-      //get the current node
-      Assert( d_tindex_iter.back()->second.hasNodeData() );
-      Node n = d_tindex_iter.back()->second.getNodeData();
-      Debug("cand-gen-qe") << "...returning " << n << std::endl;
-      Assert( !n.isNull() );
-      Assert( isLegalOpCandidate( n ) );
-      //increment
-      bool success = false;
-      do{
-        ++d_tindex_iter.back();
-        if( d_tindex_iter.back()==d_tindex.back()->d_data.end() ){
-          d_tindex.pop_back();
-          d_tindex_iter.pop_back();
-        }else{
-          success = true;
-        }
-      }while( !success && !d_tindex.empty() );
-      return n;   
-    } 
   }else if( d_mode==cand_term_ident ){
     Debug("cand-gen-qe") << "...get next candidate identity" << std::endl;
-    if( !d_n.isNull() ){
-      Node n = d_n;
-      d_n = Node::null();
+    if (!d_eqc.isNull())
+    {
+      Node n = d_eqc;
+      d_eqc = Node::null();
       if( isLegalOpCandidate( n ) ){
         return n;
       }
@@ -187,45 +124,6 @@ Node CandidateGeneratorQE::getNextCandidate(){
   return Node::null();
 }
 
-CandidateGeneratorQELitEq::CandidateGeneratorQELitEq( QuantifiersEngine* qe, Node mpat ) :
-  CandidateGenerator( qe ), d_match_pattern( mpat ){
-  Assert( mpat.getKind()==EQUAL );
-  for( unsigned i=0; i<2; i++ ){
-    if( !quantifiers::TermUtil::hasInstConstAttr(mpat[i]) ){
-      d_match_gterm = mpat[i];
-    }
-  }
-}
-void CandidateGeneratorQELitEq::resetInstantiationRound(){
-
-}
-void CandidateGeneratorQELitEq::reset( Node eqc ){
-  if( d_match_gterm.isNull() ){
-    d_eq = eq::EqClassesIterator( d_qe->getEqualityQuery()->getEngine() );
-  }else{
-    d_do_mgt = true;
-  }
-}
-Node CandidateGeneratorQELitEq::getNextCandidate(){
-  if( d_match_gterm.isNull() ){
-    while( !d_eq.isFinished() ){
-      Node n = (*d_eq);
-      ++d_eq;
-      if( n.getType().isComparableTo( d_match_pattern[0].getType() ) ){
-        //an equivalence class with the same type as the pattern, return reflexive equality
-        return NodeManager::currentNM()->mkNode( d_match_pattern.getKind(), n, n );
-      }
-    }
-  }else{
-    if( d_do_mgt ){
-      d_do_mgt = false;
-      return NodeManager::currentNM()->mkNode( d_match_pattern.getKind(), d_match_gterm, d_match_gterm );
-    }
-  }
-  return Node::null();
-}
-
-
 CandidateGeneratorQELitDeq::CandidateGeneratorQELitDeq( QuantifiersEngine* qe, Node mpat ) :
 CandidateGenerator( qe ), d_match_pattern( mpat ){
 
@@ -233,10 +131,6 @@ CandidateGenerator( qe ), d_match_pattern( mpat ){
   d_match_pattern_type = d_match_pattern[0].getType();
 }
 
-void CandidateGeneratorQELitDeq::resetInstantiationRound(){
-
-}
-
 void CandidateGeneratorQELitDeq::reset( Node eqc ){
   Node false_term = d_qe->getEqualityQuery()->getEngine()->getRepresentative( NodeManager::currentNM()->mkConst<bool>(false) );
   d_eqc_false = eq::EqClassIterator( false_term, d_qe->getEqualityQuery()->getEngine() );
@@ -269,10 +163,6 @@ CandidateGeneratorQEAll::CandidateGeneratorQEAll( QuantifiersEngine* qe, Node mp
   d_firstTime = false;
 }
 
-void CandidateGeneratorQEAll::resetInstantiationRound() {
-
-}
-
 void CandidateGeneratorQEAll::reset( Node eqc ) {
   d_eq = eq::EqClassesIterator( d_qe->getEqualityQuery()->getEngine() );
   d_firstTime = true;
@@ -307,3 +197,56 @@ Node CandidateGeneratorQEAll::getNextCandidate() {
   }
   return Node::null();
 }
+
+CandidateGeneratorConsExpand::CandidateGeneratorConsExpand(
+    QuantifiersEngine* qe, Node mpat)
+    : CandidateGeneratorQE(qe, mpat)
+{
+  Assert(mpat.getKind() == APPLY_CONSTRUCTOR);
+  d_mpat_type = static_cast<DatatypeType>(mpat.getType().toType());
+}
+
+void CandidateGeneratorConsExpand::reset(Node eqc)
+{
+  d_term_iter = 0;
+  if (eqc.isNull())
+  {
+    d_mode = cand_term_db;
+  }
+  else
+  {
+    d_eqc = eqc;
+    d_mode = cand_term_ident;
+    Assert(d_eqc.getType().toType() == d_mpat_type);
+  }
+}
+
+Node CandidateGeneratorConsExpand::getNextCandidate()
+{
+  // get the next term from the base class
+  Node curr = CandidateGeneratorQE::getNextCandidate();
+  if (curr.isNull() || (curr.hasOperator() && curr.getOperator() == d_op))
+  {
+    return curr;
+  }
+  // expand it
+  NodeManager* nm = NodeManager::currentNM();
+  std::vector<Node> children;
+  const Datatype& dt = d_mpat_type.getDatatype();
+  Assert(dt.getNumConstructors() == 1);
+  children.push_back(d_op);
+  for (unsigned i = 0, nargs = dt[0].getNumArgs(); i < nargs; i++)
+  {
+    Node sel =
+        nm->mkNode(APPLY_SELECTOR_TOTAL,
+                   Node::fromExpr(dt[0].getSelectorInternal(d_mpat_type, i)),
+                   curr);
+    children.push_back(sel);
+  }
+  return nm->mkNode(APPLY_CONSTRUCTOR, children);
+}
+
+bool CandidateGeneratorConsExpand::isLegalOpCandidate(Node n)
+{
+  return isLegalCandidate(n);
+}
index dc188062f6ea6635611819990c40270ea42a9c86..da4ec2d83b8806b79508f957f9e96ba28b3a7738 100644 (file)
 namespace CVC4 {
 namespace theory {
 
-namespace quantifiers {
-  class TermArgTrie;
-}
-
 class QuantifiersEngine;
 
 namespace inst {
 
-/** base class for generating candidates for matching */
+/** Candidate generator
+ *
+ * This is the base class for generating a stream of candidate terms for
+ * E-matching. Depending on the kind of trigger we are processing and its
+ * overall context, we are interested in several different criteria for
+ * terms. This includes:
+ * - Generating a stream of all ground terms with a given operator,
+ * - Generating a stream of all ground terms with a given operator in a
+ * particular equivalence class,
+ * - Generating a stream of all terms of a particular type,
+ * - Generating all terms that are disequal from a fixed ground term,
+ * and so on.
+ *
+ * A typical use case of an instance cg of this class is the following. Given
+ * an equivalence class representative eqc:
+ *
+ *  cg->reset( eqc );
+ *  do{
+ *    Node cand = cg->getNextCandidate();
+ *    ; ...if non-null, cand is a candidate...
+ *  }while( !cand.isNull() );
+ *
+ */
 class CandidateGenerator {
 protected:
   QuantifiersEngine* d_qe;
 public:
   CandidateGenerator( QuantifiersEngine* qe ) : d_qe( qe ){}
   virtual ~CandidateGenerator(){}
-
-  /** Get candidates functions.  These set up a context to get all match candidates.
-      cg->reset( eqc );
-      do{
-        Node cand = cg->getNextCandidate();
-        //.......
-      }while( !cand.isNull() );
-
-      eqc is the equivalence class you are searching in
-  */
+  /** reset instantiation round
+   *
+   * This is called at the beginning of each instantiation round.
+   */
+  virtual void resetInstantiationRound() {}
+  /** reset for equivalence class eqc
+   *
+   * This indicates that this class should generate a stream of candidate terms
+   * based on its criteria that occur in the equivalence class of eqc, or
+   * any equivalence class if eqc is null.
+   */
   virtual void reset( Node eqc ) = 0;
+  /** get the next candidate */
   virtual Node getNextCandidate() = 0;
-  /** add candidate to list of nodes returned by this generator */
-  virtual void addCandidate( Node n ) {}
-  /** call this at the beginning of each instantiation round */
-  virtual void resetInstantiationRound() = 0;
 public:
 /** legal candidate */
 bool isLegalCandidate( Node n );
/** is n a legal candidate? */
bool isLegalCandidate(Node n);
 };/* class CandidateGenerator */
 
-/** candidate generator queue (for manual candidate generation) */
-class CandidateGeneratorQueue : public CandidateGenerator {
- private:
-  std::vector< Node > d_candidates;
-  int d_candidate_index;
-
- public:
-  CandidateGeneratorQueue( QuantifiersEngine* qe ) : CandidateGenerator( qe ), d_candidate_index( 0 ){}
-
-  void addCandidate(Node n) override;
-
-  void resetInstantiationRound() override {}
-  void reset(Node eqc) override;
-  Node getNextCandidate() override;
-};/* class CandidateGeneratorQueue */
-
-//the default generator
+/* the default candidate generator class
+ *
+ * This class may generate candidates for E-matching based on several modes:
+ * (1) cand_term_db: iterate over all ground terms for the given operator,
+ * (2) cand_term_ident: generate the given input term as a candidate,
+ * (3) cand_term_eqc: iterate over all terms in an equivalence class, returning
+ * those with the proper operator as candidates.
+ */
 class CandidateGeneratorQE : public CandidateGenerator
 {
   friend class CandidateGeneratorQEDisequal;
 
- private:
-  //operator you are looking for
+ public:
+  CandidateGeneratorQE(QuantifiersEngine* qe, Node pat);
+  /** reset instantiation round */
+  void resetInstantiationRound() override;
+  /** reset */
+  void reset(Node eqc) override;
+  /** get next candidate */
+  Node getNextCandidate() override;
+  /** tell this class to exclude candidates from equivalence class r */
+  void excludeEqc(Node r) { d_exclude_eqc[r] = true; }
+  /** is r an excluded equivalence class? */
+  bool isExcludedEqc(Node r)
+  {
+    return d_exclude_eqc.find(r) != d_exclude_eqc.end();
+  }
+
+ protected:
+  /** operator you are looking for */
   Node d_op;
-  //the equality class iterator
-  unsigned d_op_arity;
-  std::vector< quantifiers::TermArgTrie* > d_tindex;
-  std::vector< std::map< TNode, quantifiers::TermArgTrie >::iterator > d_tindex_iter;
+  /** the equality class iterator (for cand_term_eqc) */
   eq::EqClassIterator d_eqc_iter;
-  //std::vector< Node > d_eqc;
+  /** the TermDb index of the current ground term (for cand_term_db) */
   int d_term_iter;
+  /** the TermDb index of the current ground term (for cand_term_db) */
   int d_term_iter_limit;
-  bool d_using_term_db;
+  /** the term we are matching (for cand_term_ident) */
+  Node d_eqc;
+  /** candidate generation modes */
   enum {
     cand_term_db,
     cand_term_ident,
     cand_term_eqc,
-    cand_term_tindex,
     cand_term_none,
   };
+  /** the current mode of this candidate generator */
   short d_mode;
-  bool isLegalOpCandidate( Node n );
-  Node d_n;
+  /** is n a legal candidate of the required operator? */
+  virtual bool isLegalOpCandidate(Node n);
+  /** the equivalence classes that we have excluded from candidate generation */
   std::map< Node, bool > d_exclude_eqc;
 
- public:
-  CandidateGeneratorQE( QuantifiersEngine* qe, Node pat );
-
-  void resetInstantiationRound() override;
-  void reset(Node eqc) override;
-  Node getNextCandidate() override;
-  void excludeEqc( Node r ) { d_exclude_eqc[r] = true; }
-  bool isExcludedEqc( Node r ) { return d_exclude_eqc.find( r )!=d_exclude_eqc.end(); }
 };
 
-class CandidateGeneratorQELitEq : public CandidateGenerator
+/**
+ * Generate terms based on a disequality, that is, we match (= t[x] s[x])
+ * with equalities (= g1 g2) in the equivalence class of false.
+ */
+class CandidateGeneratorQELitDeq : public CandidateGenerator
 {
- private:
-  //the equality classes iterator
-  eq::EqClassesIterator d_eq;
-  //equality you are trying to match equalities for
-  Node d_match_pattern;
-  Node d_match_gterm;
-  bool d_do_mgt;
-
  public:
-  CandidateGeneratorQELitEq( QuantifiersEngine* qe, Node mpat );
-
-  void resetInstantiationRound() override;
+  /**
+   * mpat is an equality that we are matching to equalities in the equivalence
+   * class of false
+   */
+  CandidateGeneratorQELitDeq(QuantifiersEngine* qe, Node mpat);
+  /** reset */
   void reset(Node eqc) override;
+  /** get next candidate */
   Node getNextCandidate() override;
-};
 
-class CandidateGeneratorQELitDeq : public CandidateGenerator
-{
  private:
-  //the equality class iterator for false
+  /** the equality class iterator for false */
   eq::EqClassIterator d_eqc_false;
-  //equality you are trying to match disequalities for
+  /**
+   * equality you are trying to match against ground equalities that are
+   * assigned to false
+   */
   Node d_match_pattern;
-  //type of disequality
+  /** type of the terms we are generating */
   TypeNode d_match_pattern_type;
-
- public:
-  CandidateGeneratorQELitDeq( QuantifiersEngine* qe, Node mpat );
-
-  void resetInstantiationRound() override;
-  void reset(Node eqc) override;
-  Node getNextCandidate() override;
 };
 
+/**
+ * Generate all terms of the proper sort that occur in the current context.
+ */
 class CandidateGeneratorQEAll : public CandidateGenerator
 {
  private:
@@ -166,10 +178,34 @@ class CandidateGeneratorQEAll : public CandidateGenerator
 
  public:
   CandidateGeneratorQEAll( QuantifiersEngine* qe, Node mpat );
+  /** reset */
+  void reset(Node eqc) override;
+  /** get next candidate */
+  Node getNextCandidate() override;
+};
 
-  void resetInstantiationRound() override;
+/** candidate generation constructor expand
+ *
+ * This modifies the candidates t1, ..., tn generated by CandidateGeneratorQE
+ * so that they are "expansions" of a fixed datatype constructor C. Assuming
+ * C has arity m, we instead return the stream:
+ *   C(sel_1( t1 ), ..., sel_m( tn )) ... C(sel_1( t1 ), ..., C( sel_m( tn ))
+ * where sel_1 ... sel_m are the selectors of C.
+ */
+class CandidateGeneratorConsExpand : public CandidateGeneratorQE
+{
+ public:
+  CandidateGeneratorConsExpand(QuantifiersEngine* qe, Node mpat);
+  /** reset */
   void reset(Node eqc) override;
+  /** get next candidate */
   Node getNextCandidate() override;
+
+ protected:
+  /** the (datatype) type of the input match pattern */
+  DatatypeType d_mpat_type;
+  /** we don't care about the operator of n */
+  bool isLegalOpCandidate(Node n) override;
 };
 
 }/* CVC4::theory::inst namespace */
index 90d1815a431a56bdd0046e17bcaec4c30610d00a..192a6b433fc4cabbb749fe70d2417306378137f5 100644 (file)
@@ -174,8 +174,24 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector<
 
     //create candidate generator
     if( Trigger::isAtomicTrigger( d_match_pattern ) ){
-      //we will be scanning lists trying to find d_match_pattern.getOperator()
-      d_cg = new inst::CandidateGeneratorQE( qe, d_match_pattern );
+      if (d_match_pattern.getKind() == APPLY_CONSTRUCTOR)
+      {
+        // 1-constructors have a trivial way of generating candidates in a
+        // given equivalence class
+        const Datatype& dt =
+            static_cast<DatatypeType>(d_match_pattern.getType().toType())
+                .getDatatype();
+        if (dt.getNumConstructors() == 1)
+        {
+          d_cg = new inst::CandidateGeneratorConsExpand(qe, d_match_pattern);
+        }
+      }
+      if (d_cg == nullptr)
+      {
+        // we will be scanning lists trying to find
+        // d_match_pattern.getOperator()
+        d_cg = new inst::CandidateGeneratorQE(qe, d_match_pattern);
+      }
       //if matching on disequality, inform the candidate generator not to match on eqc
       if( d_pattern.getKind()==NOT && d_pattern[0].getKind()==EQUAL ){
         ((inst::CandidateGeneratorQE*)d_cg)->excludeEqc( d_eq_class_rel );
@@ -196,13 +212,9 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector<
     }else if( d_match_pattern.getKind()==EQUAL &&
               d_match_pattern[0].getKind()==INST_CONSTANT && d_match_pattern[1].getKind()==INST_CONSTANT ){
       //we will be producing candidates via literal matching heuristics
-      if( d_pattern.getKind()!=NOT ){
-        //candidates will be all equalities
-        d_cg = new inst::CandidateGeneratorQELitEq( qe, d_match_pattern );
-      }else{
-        //candidates will be all disequalities
-        d_cg = new inst::CandidateGeneratorQELitDeq( qe, d_match_pattern );
-      }
+      Assert(d_pattern.getKind() == NOT);
+      // candidates will be all disequalities
+      d_cg = new inst::CandidateGeneratorQELitDeq(qe, d_match_pattern);
     }else{
       Trace("inst-match-gen-warn") << "(?) Unknown matching pattern is " << d_match_pattern << std::endl;
     }