Cosmetic improvements to theory datatypes (#7020)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 17 Aug 2021 05:44:20 +0000 (00:44 -0500)
committerGitHub <noreply@github.com>
Tue, 17 Aug 2021 05:44:20 +0000 (05:44 +0000)
Mainly just indentation / formatting changes. In preparation for playing around with heuristics to datatypes theory motivated by Facebook benchmarks.

src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes.h

index 3dfd9b458d898108eef7a244fa366e981eea9e2b..2162c4d14a094f13b1a47fa129d2dcd6a08db232 100644 (file)
@@ -58,7 +58,6 @@ TheoryDatatypes::TheoryDatatypes(Env& env,
       d_collectTermsCacheU(getUserContext()),
       d_functionTerms(getSatContext()),
       d_singleton_eq(getUserContext()),
-      d_lemmas_produced_c(getUserContext()),
       d_sygusExtension(nullptr),
       d_state(env, valuation),
       d_im(*this, d_state, d_pnm),
@@ -539,129 +538,151 @@ void TheoryDatatypes::eqNotifyMerge(TNode t1, TNode t2)
 }
 
 void TheoryDatatypes::merge( Node t1, Node t2 ){
-  if (!d_state.isInConflict())
+  if (d_state.isInConflict())
   {
-    Trace("datatypes-merge") << "Merge " << t1 << " " << t2 << std::endl;
-    Assert(areEqual(t1, t2));
-    TNode trep1 = t1;
-    TNode trep2 = t2;
-    EqcInfo* eqc2 = getOrMakeEqcInfo( t2 );
-    if( eqc2 ){
-      bool checkInst = false;
-      if( !eqc2->d_constructor.get().isNull() ){
-        trep2 = eqc2->d_constructor.get();
-      }
-      EqcInfo* eqc1 = getOrMakeEqcInfo( t1 );
-      if( eqc1 ){
-        Trace("datatypes-debug") << "  merge eqc info " << eqc2 << " into " << eqc1 << std::endl;
-        if( !eqc1->d_constructor.get().isNull() ){
-          trep1 = eqc1->d_constructor.get();
-        }
-        //check for clash
-        TNode cons1 = eqc1->d_constructor.get();
-        TNode cons2 = eqc2->d_constructor.get();
-        //if both have constructor, then either clash or unification
-        if( !cons1.isNull() && !cons2.isNull() ){
-          Trace("datatypes-debug") << "  constructors : " << cons1 << " " << cons2 << std::endl;
-          Node unifEq = cons1.eqNode( cons2 );
-          std::vector< Node > rew;
-          if (utils::checkClash(cons1, cons2, rew))
-          {
-            std::vector<Node> conf;
-            conf.push_back(unifEq);
-            Trace("dt-conflict")
-                << "CONFLICT: Clash conflict : " << conf << std::endl;
-            d_im.sendDtConflict(conf, InferenceId::DATATYPES_CLASH_CONFLICT);
-            return;
-          }
-          else
-          {
-            Assert(areEqual(cons1, cons2));
-            //do unification
-            for( int i=0; i<(int)cons1.getNumChildren(); i++ ) {
-              if( !areEqual( cons1[i], cons2[i] ) ){
-                Node eq = cons1[i].eqNode( cons2[i] );
-                d_im.addPendingInference(
-                    eq, InferenceId::DATATYPES_UNIF, unifEq);
-                Trace("datatypes-infer") << "DtInfer : cons-inj : " << eq << " by " << unifEq << std::endl;
-              }
-            }
-          }
-        }
-        Trace("datatypes-debug") << "  instantiated : " << eqc1->d_inst << " " << eqc2->d_inst << std::endl;
-        eqc1->d_inst = eqc1->d_inst || eqc2->d_inst;
-        if( !cons2.isNull() ){
-          if( cons1.isNull() ){
-            Trace("datatypes-debug") << "  must check if it is okay to set the constructor." << std::endl;
-            checkInst = true;
-            addConstructor( eqc2->d_constructor.get(), eqc1, t1 );
-            if (d_state.isInConflict())
-            {
-              return;
-            }
-          }
-        }
-      }else{
-        Trace("datatypes-debug") << "  no eqc info for " << t1 << ", must create" << std::endl;
-        //just copy the equivalence class information
-        eqc1 = getOrMakeEqcInfo( t1, true );
-        eqc1->d_inst.set( eqc2->d_inst );
-        eqc1->d_constructor.set( eqc2->d_constructor );
-        eqc1->d_selectors.set( eqc2->d_selectors );
+    return;
+  }
+  Trace("datatypes-merge") << "Merge " << t1 << " " << t2 << std::endl;
+  Assert(areEqual(t1, t2));
+  TNode trep1 = t1;
+  TNode trep2 = t2;
+  EqcInfo* eqc2 = getOrMakeEqcInfo(t2);
+  if (eqc2 == nullptr)
+  {
+    return;
+  }
+  bool checkInst = false;
+  if (!eqc2->d_constructor.get().isNull())
+  {
+    trep2 = eqc2->d_constructor.get();
+  }
+  EqcInfo* eqc1 = getOrMakeEqcInfo(t1);
+  if (eqc1)
+  {
+    Trace("datatypes-debug")
+        << "  merge eqc info " << eqc2 << " into " << eqc1 << std::endl;
+    if (!eqc1->d_constructor.get().isNull())
+    {
+      trep1 = eqc1->d_constructor.get();
+    }
+    // check for clash
+    TNode cons1 = eqc1->d_constructor.get();
+    TNode cons2 = eqc2->d_constructor.get();
+    // if both have constructor, then either clash or unification
+    if (!cons1.isNull() && !cons2.isNull())
+    {
+      Trace("datatypes-debug")
+          << "  constructors : " << cons1 << " " << cons2 << std::endl;
+      Node unifEq = cons1.eqNode(cons2);
+      std::vector<Node> rew;
+      if (utils::checkClash(cons1, cons2, rew))
+      {
+        std::vector<Node> conf;
+        conf.push_back(unifEq);
+        Trace("dt-conflict")
+            << "CONFLICT: Clash conflict : " << conf << std::endl;
+        d_im.sendDtConflict(conf, InferenceId::DATATYPES_CLASH_CONFLICT);
+        return;
       }
-
-
-      //merge labels
-      NodeUIntMap::iterator lbl_i = d_labels.find(t2);
-      if( lbl_i != d_labels.end() ){
-        Trace("datatypes-debug") << "  merge labels from " << eqc2 << " " << t2 << std::endl;
-        size_t n_label = (*lbl_i).second;
-        for (size_t i = 0; i < n_label; i++)
+      else
+      {
+        Assert(areEqual(cons1, cons2));
+        // do unification
+        for (size_t i = 0, nchild = cons1.getNumChildren(); i < nchild; i++)
         {
-          Assert(i < d_labels_data[t2].size());
-          Node t = d_labels_data[ t2 ][i];
-          Node t_arg = d_labels_args[t2][i];
-          unsigned tindex = d_labels_tindex[t2][i];
-          addTester( tindex, t, eqc1, t1, t_arg );
-          if (d_state.isInConflict())
+          if (!areEqual(cons1[i], cons2[i]))
           {
-            Trace("datatypes-debug") << "  conflict!" << std::endl;
-            return;
+            Node eq = cons1[i].eqNode(cons2[i]);
+            d_im.addPendingInference(eq, InferenceId::DATATYPES_UNIF, unifEq);
+            Trace("datatypes-infer") << "DtInfer : cons-inj : " << eq << " by "
+                                     << unifEq << std::endl;
           }
         }
-
       }
-      //merge selectors
-      if( !eqc1->d_selectors && eqc2->d_selectors ){
-        eqc1->d_selectors = true;
+    }
+    Trace("datatypes-debug") << "  instantiated : " << eqc1->d_inst << " "
+                             << eqc2->d_inst << std::endl;
+    eqc1->d_inst = eqc1->d_inst || eqc2->d_inst;
+    if (!cons2.isNull())
+    {
+      if (cons1.isNull())
+      {
+        Trace("datatypes-debug")
+            << "  must check if it is okay to set the constructor."
+            << std::endl;
         checkInst = true;
-      }
-      NodeUIntMap::iterator sel_i = d_selector_apps.find(t2);
-      if( sel_i != d_selector_apps.end() ){
-        Trace("datatypes-debug") << "  merge selectors from " << eqc2 << " " << t2 << std::endl;
-        size_t n_sel = (*sel_i).second;
-        for (size_t j = 0; j < n_sel; j++)
-        {
-          addSelector( d_selector_apps_data[t2][j], eqc1, t1, eqc2->d_constructor.get().isNull() );
-        }
-      }
-      if( checkInst ){
-        Trace("datatypes-debug") << "  checking instantiate" << std::endl;
-        instantiate( eqc1, t1 );
+        addConstructor(eqc2->d_constructor.get(), eqc1, t1);
         if (d_state.isInConflict())
         {
           return;
         }
       }
     }
-    Trace("datatypes-debug") << "Finished Merge " << t1 << " " << t2 << std::endl;
   }
+  else
+  {
+    Trace("datatypes-debug")
+        << "  no eqc info for " << t1 << ", must create" << std::endl;
+    // just copy the equivalence class information
+    eqc1 = getOrMakeEqcInfo(t1, true);
+    eqc1->d_inst.set(eqc2->d_inst);
+    eqc1->d_constructor.set(eqc2->d_constructor);
+    eqc1->d_selectors.set(eqc2->d_selectors);
+  }
+
+  // merge labels
+  NodeUIntMap::iterator lbl_i = d_labels.find(t2);
+  if (lbl_i != d_labels.end())
+  {
+    Trace("datatypes-debug")
+        << "  merge labels from " << eqc2 << " " << t2 << std::endl;
+    size_t n_label = (*lbl_i).second;
+    for (size_t i = 0; i < n_label; i++)
+    {
+      Assert(i < d_labels_data[t2].size());
+      Node t = d_labels_data[t2][i];
+      Node t_arg = d_labels_args[t2][i];
+      unsigned tindex = d_labels_tindex[t2][i];
+      addTester(tindex, t, eqc1, t1, t_arg);
+      if (d_state.isInConflict())
+      {
+        Trace("datatypes-debug") << "  conflict!" << std::endl;
+        return;
+      }
+    }
+  }
+  // merge selectors
+  if (!eqc1->d_selectors && eqc2->d_selectors)
+  {
+    eqc1->d_selectors = true;
+    checkInst = true;
+  }
+  NodeUIntMap::iterator sel_i = d_selector_apps.find(t2);
+  if (sel_i != d_selector_apps.end())
+  {
+    Trace("datatypes-debug")
+        << "  merge selectors from " << eqc2 << " " << t2 << std::endl;
+    size_t n_sel = (*sel_i).second;
+    for (size_t j = 0; j < n_sel; j++)
+    {
+      addSelector(d_selector_apps_data[t2][j],
+                  eqc1,
+                  t1,
+                  eqc2->d_constructor.get().isNull());
+    }
+  }
+  if (checkInst)
+  {
+    Trace("datatypes-debug") << "  checking instantiate" << std::endl;
+    instantiate(eqc1, t1);
+  }
+  Trace("datatypes-debug") << "Finished Merge " << t1 << " " << t2 << std::endl;
 }
 
-TheoryDatatypes::EqcInfo::EqcInfo( context::Context* c )
-    : d_inst( c, false )
-    , d_constructor( c, Node::null() )
-    , d_selectors( c, false )
+TheoryDatatypes::EqcInfo::EqcInfo(context::Context* c)
+    : d_inst(c, false),
+      d_constructor(c, Node::null()),
+      d_selectors(c, false)
 {}
 
 bool TheoryDatatypes::hasLabel( EqcInfo* eqc, Node n ){
@@ -823,7 +844,7 @@ void TheoryDatatypes::addTester(
       const DType& dt = t_arg.getType().getDType();
       Debug("datatypes-labels") << "Labels at " << n_lbl << " / " << dt.getNumConstructors() << std::endl;
       if( tpolarity ){
-        instantiate( eqc, n );
+        instantiate(eqc, n);
         // We could propagate is-C1(x) => not is-C2(x) here for all other
         // constructors, but empirically this hurts performance.
       }else{
@@ -1409,13 +1430,14 @@ Node TheoryDatatypes::getInstantiateCons(Node n, const DType& dt, int index)
   return n_ic;
 }
 
-void TheoryDatatypes::instantiate( EqcInfo* eqc, Node n ){
+bool TheoryDatatypes::instantiate(EqcInfo* eqc, Node n)
+{
   Trace("datatypes-debug") << "Instantiate: " << n << std::endl;
   //add constructor to equivalence class if not done so already
   int index = getLabelIndex( eqc, n );
   if (index == -1 || eqc->d_inst)
   {
-    return;
+    return false;
   }
   Node exp;
   Node tt;
@@ -1437,7 +1459,8 @@ void TheoryDatatypes::instantiate( EqcInfo* eqc, Node n ){
   Node eq;
   if (tt == tt_cons)
   {
-    return;
+    // not necessary
+    return false;
   }
   eq = tt.eqNode(tt_cons);
   // Determine if the equality must be sent out as a lemma. Notice that
@@ -1460,9 +1483,10 @@ void TheoryDatatypes::instantiate( EqcInfo* eqc, Node n ){
   }
   Trace("datatypes-infer-debug") << "DtInstantiate : " << eqc << " " << eq
                                  << " forceLemma = " << forceLemma << std::endl;
-  d_im.addPendingInference(eq, InferenceId::DATATYPES_INST, exp, forceLemma);
   Trace("datatypes-infer") << "DtInfer : instantiate : " << eq << " by " << exp
                            << std::endl;
+  d_im.addPendingInference(eq, InferenceId::DATATYPES_INST, exp, forceLemma);
+  return true;
 }
 
 void TheoryDatatypes::checkCycles() {
index ecfa6f02ae0587500159b8cf2fdc1745d27a3655..68dedb6f319baa59c3c8c6ca3d00d41b381f94b8 100644 (file)
@@ -158,8 +158,6 @@ private:
   std::map< TypeNode, Node > d_singleton_lemma[2];
   /** Cache for singleton equalities processed */
   BoolMap d_singleton_eq;
-  /** list of all lemmas produced */
-  BoolMap d_lemmas_produced_c;
 private:
   /** assert fact */
   void assertFact( Node fact, Node exp );
@@ -268,9 +266,10 @@ private:
   void collectTerms( Node n );
   /** get instantiate cons */
   Node getInstantiateCons(Node n, const DType& dt, int index);
-  /** check instantiate */
-  void instantiate( EqcInfo* eqc, Node n );
-private:
+  /** check instantiate, return true if an inference was generated. */
+  bool instantiate(EqcInfo* eqc, Node n);
+
+ private:
   //equality queries
   bool hasTerm( TNode a );
   bool areEqual( TNode a, TNode b );