Refactoring of initial lemmas in datatypes (#8666)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 27 Apr 2022 13:25:36 +0000 (08:25 -0500)
committerGitHub <noreply@github.com>
Wed, 27 Apr 2022 13:25:36 +0000 (13:25 +0000)
This is work towards revising how/when the datatypes instantiate rule is applied.

This simplifies the management of when new terms are registered in the theory of datatypes.

We now use the equality engine's eqNotifyNewClass callback to know when a term should be considered. Previously, this was split over two methods (additionally, collectTerms).

Most importantly, this eliminates the need for a manual addition of the "instantiated constructor term" in getInstantiateCons, which complicates the logic for the impact of applying the datatypes instantiate rule.

src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes.h

index 741d2506c58860c5a02292a33322bc905836017f..7d135ee59087cea23d1d54e64cd6a231fa2bb93f 100644 (file)
@@ -56,8 +56,7 @@ TheoryDatatypes::TheoryDatatypes(Env& env,
       d_term_sk(userContext()),
       d_labels(context()),
       d_selector_apps(context()),
-      d_collectTermsCache(context()),
-      d_collectTermsCacheU(userContext()),
+      d_initialLemmaCache(userContext()),
       d_functionTerms(context()),
       d_singleton_eq(userContext()),
       d_sygusExtension(nullptr),
@@ -464,7 +463,6 @@ void TheoryDatatypes::preRegisterTerm(TNode n)
       Trace("dt-expand") << "...nested recursion ok" << std::endl;
     }
   }
-  collectTerms( n );
   switch (n.getKind()) {
   case kind::EQUAL:
   case kind::APPLY_TESTER:
@@ -473,6 +471,8 @@ void TheoryDatatypes::preRegisterTerm(TNode n)
     d_equalityEngine->addTriggerPredicate(n);
     break;
   default:
+    // do initial lemmas (e.g. for dt.size)
+    registerInitialLemmas(n);
     // Function applications/predicates
     d_equalityEngine->addTerm(n);
     if (d_sygusExtension)
@@ -521,9 +521,28 @@ TrustNode TheoryDatatypes::explain(TNode literal)
 }
 
 /** called when a new equivalance class is created */
-void TheoryDatatypes::eqNotifyNewClass(TNode t){
-  if( t.getKind()==APPLY_CONSTRUCTOR ){
-    getOrMakeEqcInfo( t, true );
+void TheoryDatatypes::eqNotifyNewClass(TNode n)
+{
+  Kind nk = n.getKind();
+  if (nk == APPLY_CONSTRUCTOR)
+  {
+    Trace("datatypes") << "  Found constructor " << n << endl;
+    getOrMakeEqcInfo(n, true);
+    if (n.getNumChildren() > 0)
+    {
+      d_functionTerms.push_back(n);
+    }
+  }
+  if (nk == APPLY_SELECTOR || nk == DT_SIZE || nk == DT_HEIGHT_BOUND)
+  {
+    d_functionTerms.push_back(n);
+    // we must also record which selectors exist
+    Trace("dt-collapse-sel") << "  Found selector " << n << endl;
+    Node rep = getRepresentative(n[0]);
+    // record it in the selectors
+    EqcInfo* eqc = getOrMakeEqcInfo(rep, true);
+    // add it to the eqc info
+    addSelector(n, eqc, rep);
   }
 }
 
@@ -1256,49 +1275,16 @@ Node TheoryDatatypes::getSingletonLemma( TypeNode tn, bool pol ) {
   }
 }
 
-void TheoryDatatypes::collectTerms( Node n ) {
-  if (d_collectTermsCache.find(n) != d_collectTermsCache.end())
-  {
-    // already processed
-    return;
-  }
-  d_collectTermsCache[n] = true;
-  Kind nk = n.getKind();
-  if (nk == APPLY_CONSTRUCTOR)
-  {
-    Trace("datatypes") << "  Found constructor " << n << endl;
-    if (n.getNumChildren() > 0)
-    {
-      d_functionTerms.push_back(n);
-    }
-    return;
-  }
-  if (nk == APPLY_SELECTOR || nk == DT_SIZE || nk == DT_HEIGHT_BOUND)
-  {
-    d_functionTerms.push_back(n);
-    // we must also record which selectors exist
-    Trace("dt-collapse-sel") << "  Found selector " << n << endl;
-    Node rep = getRepresentative(n[0]);
-    // record it in the selectors
-    EqcInfo* eqc = getOrMakeEqcInfo(rep, true);
-    // add it to the eqc info
-    addSelector(n, eqc, rep);
-  }
-
-  // now, do user-context-dependent lemmas
-  if (nk != DT_SIZE && nk != DT_HEIGHT_BOUND)
-  {
-    // if not one of these kinds, there are no lemmas
-    return;
-  }
-  if (d_collectTermsCacheU.find(n) != d_collectTermsCacheU.end())
+void TheoryDatatypes::registerInitialLemmas(Node n)
+{
+  if (d_initialLemmaCache.find(n) != d_initialLemmaCache.end())
   {
     return;
   }
-  d_collectTermsCacheU[n] = true;
+  d_initialLemmaCache[n] = true;
 
   NodeManager* nm = NodeManager::currentNM();
-
+  Kind nk = n.getKind();
   if (nk == DT_SIZE)
   {
     Node lem = nm->mkNode(LEQ, d_zero, n);
@@ -1341,10 +1327,7 @@ Node TheoryDatatypes::getInstantiateCons(Node n, const DType& dt, int index)
   //add constructor to equivalence class
   Node k = getTermSkolemFor( n );
   Node n_ic = utils::getInstCons(k, dt, index);
-  n_ic = rewrite(n_ic);
-  // it may be a new term, so we collect terms and add it to the equality engine
-  collectTerms( n_ic );
-  d_equalityEngine->addTerm(n_ic);
+  Assert (n_ic == rewrite(n_ic));
   Trace("dt-enum") << "Made instantiate cons " << n_ic << std::endl;
   return n_ic;
 }
index 933e7065ddcf8f81e31ee70a8db3bc336607b9fb..a9072b419f3a6d7a73d0613f2cd841e34d269b94 100644 (file)
@@ -138,16 +138,11 @@ private:
   std::map< Node, std::vector< Node > > d_selector_apps_data;
   /** The conflict node */
   Node d_conflictNode;
-  /**
-   * SAT-context dependent cache for which terms we have called
-   * collectTerms(...) on.
-   */
-  BoolMap d_collectTermsCache;
   /**
    * User-context dependent cache for which terms we have called
-   * collectTerms(...) on.
+   * registerInitialLemmas(...) on.
    */
-  BoolMap d_collectTermsCacheU;
+  BoolMap d_initialLemmaCache;
   /** All the function terms that the theory has seen */
   context::CDList<TNode> d_functionTerms;
   /** uninterpreted constant to variable map */
@@ -241,6 +236,11 @@ private:
   void merge( Node t1, Node t2 );
   /** collapse selector, s is of the form sel( n ) where n = c */
   void collapseSelector( Node s, Node c );
+  /** 
+   * Register initial lemmas. This adds pending lemmas on the inference manager
+   * corresponding to unit lemmas for e.g. dt.size.
+   */
+  void registerInitialLemmas(Node n);
   /** for checking if cycles exist */
   void checkCycles();
   Node searchForCycle(TNode n,
@@ -261,8 +261,6 @@ private:
   Node getCodatatypesValue( Node n, std::map< Node, Node >& eqc_cons, std::map< Node, int >& vmap, int depth );
   /** get singleton lemma */
   Node getSingletonLemma( TypeNode tn, bool pol );
-  /** collect terms */
-  void collectTerms( Node n );
   /** get instantiate cons */
   Node getInstantiateCons(Node n, const DType& dt, int index);
   /** check instantiate, return true if an inference was generated. */